Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/functions/export-import.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import type {
} from "../types.js";
import { normalizeAccessLog } from "./access-tracker.js";
import { KV } from "../state/schema.js";
import { indexGraphEdge, indexGraphNode } from "../state/graph-indexes.js";
import { StateKV } from "../state/kv.js";
import { VERSION } from "../version.js";
import { recordAudit } from "./audit.js";
Expand Down Expand Up @@ -404,6 +405,7 @@ export function registerExportImportFunction(sdk: ISdk, kv: StateKV): void {
if (existing) { stats.skipped++; continue; }
}
await kv.set(KV.graphNodes, node.id, node);
await indexGraphNode(kv, node);
}
}
if (importData.graphEdges) {
Expand All @@ -413,6 +415,7 @@ export function registerExportImportFunction(sdk: ISdk, kv: StateKV): void {
if (existing) { stats.skipped++; continue; }
}
await kv.set(KV.graphEdges, edge.id, edge);
await indexGraphEdge(kv, edge);
}
}
if (importData.semanticMemories) {
Expand Down
201 changes: 167 additions & 34 deletions src/functions/graph-retrieval.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ import type {
} from "../types.js";
import { KV } from "../state/schema.js";
import type { StateKV } from "../state/kv.js";
import {
GraphIndexReader,
graphIndexesReady,
loadNameCatalog,
loadNodeIdsForObservations,
} from "../state/graph-indexes.js";

export interface GraphRetrievalResult {
obsId: string;
Expand All @@ -13,6 +19,10 @@ export interface GraphRetrievalResult {
pathLength: number;
}

type NeighborProvider = (
nodeId: string,
) => Promise<Array<{ node: GraphNode; edge: GraphEdge }>>;

function buildGraphContext(
path: Array<{ node: GraphNode; edge?: GraphEdge }>,
): string {
Expand All @@ -38,6 +48,31 @@ function buildGraphContext(
return parts.join(" ");
}

function neighborsFromArrays(
allNodes: GraphNode[],
allEdges: GraphEdge[],
): NeighborProvider {
const nodeIndex = new Map<string, GraphNode>();
for (const n of allNodes) nodeIndex.set(n.id, n);

const adjacency = new Map<
string,
Array<{ node: GraphNode; edge: GraphEdge }>
>();
const append = (from: string, to: string, edge: GraphEdge): void => {
const node = nodeIndex.get(to);
if (!node) return;
if (!adjacency.has(from)) adjacency.set(from, []);
adjacency.get(from)!.push({ node, edge });
};
for (const edge of allEdges) {
append(edge.sourceNodeId, edge.targetNodeId, edge);
append(edge.targetNodeId, edge.sourceNodeId, edge);
}

return async (nodeId) => adjacency.get(nodeId) ?? [];
}

export class GraphRetrieval {
constructor(private kv: StateKV) {}

Expand All @@ -46,6 +81,28 @@ export class GraphRetrieval {
maxDepth = 2,
maxResults = 20,
): Promise<GraphRetrievalResult[]> {
if (await graphIndexesReady(this.kv)) {
const reader = await GraphIndexReader.open(this.kv);
const catalog = await loadNameCatalog(this.kv);
const lowered = entityNames.map((e) => e.toLowerCase());
const matchingNodes: GraphNode[] = [];
for (const entry of catalog) {
const nameLower = entry.name.toLowerCase();
const matched = lowered.some(
(e) => nameLower.includes(e) || e.includes(nameLower),
);
if (!matched) continue;
const node = await reader.getNode(entry.id);
if (node) matchingNodes.push(node);
}
return this.scoreEntityMatches(
matchingNodes,
(id) => reader.getNeighbors(id),
maxDepth,
maxResults,
);
}

const allNodes = (await this.kv.list<GraphNode>(KV.graphNodes)).filter((n) => !n.stale);
const allEdges = (await this.kv.list<GraphEdge>(KV.graphEdges)).filter((e) => !e.stale);

Expand All @@ -58,16 +115,36 @@ export class GraphRetrieval {
);
});

return this.scoreEntityMatches(
matchingNodes,
neighborsFromArrays(allNodes, allEdges),
maxDepth,
maxResults,
);
}

private async scoreEntityMatches(
matchingNodes: GraphNode[],
getNeighbors: NeighborProvider,
maxDepth: number,
maxResults: number,
): Promise<GraphRetrievalResult[]> {
if (matchingNodes.length === 0) return [];

// Which start node first claims an observation decides its score,
// so iterate in a deterministic order regardless of whether the
// matches came from enumeration or the sharded name catalog.
const orderedMatches = [...matchingNodes].sort((a, b) =>
a.id.localeCompare(b.id),
);

const results: GraphRetrievalResult[] = [];
const visitedObs = new Set<string>();

for (const startNode of matchingNodes) {
const paths = this.dijkstraTraversal(
for (const startNode of orderedMatches) {
const paths = await this.dijkstraTraversal(
startNode,
allNodes,
allEdges,
getNeighbors,
maxDepth,
);

Expand Down Expand Up @@ -119,18 +196,60 @@ export class GraphRetrieval {
maxDepth = 1,
maxResults = 10,
): Promise<GraphRetrievalResult[]> {
if (await graphIndexesReady(this.kv)) {
const reader = await GraphIndexReader.open(this.kv);
const candidateIds = await loadNodeIdsForObservations(this.kv, obsIds);
const linkedNodes: GraphNode[] = [];
for (const nodeId of candidateIds) {
const node = await reader.getNode(nodeId);
if (
node &&
(node.sourceObservationIds ?? []).some((id) => obsIds.includes(id))
) {
linkedNodes.push(node);
}
}
return this.scoreExpansion(
linkedNodes,
(id) => reader.getNeighbors(id),
obsIds,
maxDepth,
maxResults,
);
}

const allNodes = (await this.kv.list<GraphNode>(KV.graphNodes)).filter((n) => !n.stale);
const allEdges = (await this.kv.list<GraphEdge>(KV.graphEdges)).filter((e) => !e.stale);

const linkedNodes = allNodes.filter((n) =>
n.sourceObservationIds.some((id) => obsIds.includes(id)),
);

return this.scoreExpansion(
linkedNodes,
neighborsFromArrays(allNodes, allEdges),
obsIds,
maxDepth,
maxResults,
);
}

private async scoreExpansion(
linkedNodes: GraphNode[],
getNeighbors: NeighborProvider,
obsIds: string[],
maxDepth: number,
maxResults: number,
): Promise<GraphRetrievalResult[]> {
const orderedLinked = [...linkedNodes].sort((a, b) =>
a.id.localeCompare(b.id),
);

const results: GraphRetrievalResult[] = [];
const visitedObs = new Set<string>(obsIds);

for (const node of linkedNodes) {
const paths = this.dijkstraTraversal(node, allNodes, allEdges, maxDepth);
for (const node of orderedLinked) {
const paths = await this.dijkstraTraversal(node, getNeighbors, maxDepth);
for (const path of paths) {
const lastNode = path[path.length - 1].node;
for (const obsId of lastNode.sourceObservationIds) {
Expand Down Expand Up @@ -163,6 +282,25 @@ export class GraphRetrieval {
currentState: GraphEdge[];
history: GraphEdge[];
}> {
if (await graphIndexesReady(this.kv)) {
const reader = await GraphIndexReader.open(this.kv);
const catalog = await loadNameCatalog(this.kv);
const lower = entityName.toLowerCase();
let entity: GraphNode | null = null;
for (const entry of catalog) {
if (entry.name.toLowerCase() !== lower) continue;
const node = await reader.getNode(entry.id);
if (node) {
entity = node;
break;
}
}
if (!entity) return { entity: null, currentState: [], history: [] };

const relatedEdges = await reader.getIncidentEdges(entity.id);
return this.partitionTemporalEdges(entity, relatedEdges, asOf);
}

const allNodes = (await this.kv.list<GraphNode>(KV.graphNodes)).filter((n) => !n.stale);
const allEdges = (await this.kv.list<GraphEdge>(KV.graphEdges)).filter((e) => !e.stale);

Expand All @@ -175,6 +313,18 @@ export class GraphRetrieval {
(e) => e.sourceNodeId === entity.id || e.targetNodeId === entity.id,
);

return this.partitionTemporalEdges(entity, relatedEdges, asOf);
}

private partitionTemporalEdges(
entity: GraphNode,
relatedEdges: GraphEdge[],
asOf?: string,
): {
entity: GraphNode | null;
currentState: GraphEdge[];
history: GraphEdge[];
} {
if (!asOf) {
const latestEdges = this.getLatestEdges(relatedEdges);
const historicalEdges = relatedEdges.filter(
Expand Down Expand Up @@ -231,32 +381,16 @@ export class GraphRetrieval {
// which fell back to edge-count order and ignored the 0.1-1.0 weight
// attached to every graph edge. Dijkstra over `cost = 1/weight`
// (cheaper edges = stronger relationships) returns the
// highest-weighted path to each reachable node within maxDepth. Also
// tightens the perf profile:
// - Adjacency built once in O(V+E) (previous BFS re-filtered
// allEdges per visited node, O(V·E) overall).
// - Min-heap dequeue is O(log V) per pop (previous queue.shift()
// was O(n) — the dominant cost on graphs above ~200 nodes per
// the contributor's benchmark in #328).
private dijkstraTraversal(
// highest-weighted path to each reachable node within maxDepth.
// Neighbor expansion is delegated to the provider so the same
// traversal serves both the enumeration fallback (prebuilt adjacency
// over kv.list arrays) and the side-index path (targeted adjacency
// gets bounded by degree x maxDepth).
private async dijkstraTraversal(
startNode: GraphNode,
allNodes: GraphNode[],
allEdges: GraphEdge[],
getNeighbors: NeighborProvider,
maxDepth: number,
): Array<Array<{ node: GraphNode; edge?: GraphEdge }>> {
const nodeIndex = new Map<string, GraphNode>();
for (const n of allNodes) nodeIndex.set(n.id, n);

const adjacency = new Map<string, Array<{ neighborId: string; edge: GraphEdge }>>();
for (const edge of allEdges) {
const a = edge.sourceNodeId;
const b = edge.targetNodeId;
if (!adjacency.has(a)) adjacency.set(a, []);
if (!adjacency.has(b)) adjacency.set(b, []);
adjacency.get(a)!.push({ neighborId: b, edge });
adjacency.get(b)!.push({ neighborId: a, edge });
}

): Promise<Array<Array<{ node: GraphNode; edge?: GraphEdge }>>> {
const dist = new Map<string, number>();
const pathTo = new Map<string, Array<{ node: GraphNode; edge?: GraphEdge }>>();
dist.set(startNode.id, 0);
Expand All @@ -273,10 +407,9 @@ export class GraphRetrieval {
if (cost > (dist.get(nodeId) ?? Infinity)) continue;
if (depth >= maxDepth) continue;

const neighbors = adjacency.get(nodeId) ?? [];
for (const { neighborId, edge } of neighbors) {
const nextNode = nodeIndex.get(neighborId);
if (!nextNode) continue;
const neighbors = await getNeighbors(nodeId);
for (const { node: nextNode, edge } of neighbors) {
const neighborId = nextNode.id;
// Clamp weight to avoid division-by-zero on malformed edges;
// 0.01 is below the documented 0.1 floor.
const edgeCost = 1 / Math.max(edge.weight, 0.01);
Expand Down
Loading