diff --git a/.env.example b/.env.example index 77ca0f3a3..d2f3f4a8c 100644 --- a/.env.example +++ b/.env.example @@ -23,8 +23,9 @@ # Without a provider key, agentmemory runs in noop mode: observations are # indexed via zero-LLM synthetic compression, hybrid search still works, # but LLM-backed summarisation / reflection / consolidation are disabled. -# The detection order is OPENAI_API_KEY → MINIMAX_API_KEY → ANTHROPIC_API_KEY -# → GEMINI_API_KEY → OPENROUTER_API_KEY → noop. +# The detection order is AWS_BEDROCK → OPENAI_API_KEY → MINIMAX_API_KEY → +# ANTHROPIC_API_KEY → GEMINI_API_KEY → OPENROUTER_API_KEY → noop. Bedrock is +# first but only fires on the explicit AWS_BEDROCK=true opt-in flag. # OPENAI_API_KEY=sk-... # Used for OpenAI-compatible embeddings today. PR #307 will extend this to chat completions (DeepSeek, SiliconFlow, vLLM, LM Studio, Ollama via `/v1`). # OPENAI_BASE_URL=https://api.openai.com # Override for OpenAI-compatible providers @@ -43,6 +44,32 @@ # MINIMAX_API_KEY=... # MINIMAX_MODEL=MiniMax-M2.7 +# AWS Bedrock (Anthropic models on Bedrock). Opt in with AWS_BEDROCK=true; takes +# precedence over the keys above when set. Credentials come from the standard AWS +# provider chain — environment creds, IAM roles, or an SSO profile cached under +# ~/.aws/sso/cache/ (select with AWS_PROFILE). NOTE: agentmemory reads the cached +# SSO token but cannot perform the login — run `aws sso login --profile <profile>` +# first, and re-run it when the session expires. +# AWS_BEDROCK=true +# AWS_REGION=us-east-1 # Required for Bedrock +# AWS_PROFILE=my-sso-profile # Optional; consumed by the AWS SDK directly +# AWS_BEDROCK_MODEL=anthropic.claude-haiku-4-5-20251001-v1:0 # Default: Claude Haiku 4.5 (bare on-demand ID) +# The bare ID above only works in Regions that offer the model on-demand AND +# where model access is enabled in the Bedrock console. 
In other Regions, use +# the geo-prefixed cross-region inference profile, e.g.: +# AWS_BEDROCK_MODEL=us.anthropic.claude-haiku-4-5-20251001-v1:0 (or eu.…) +# AWS_ACCESS_KEY_ID=... # Optional explicit static creds (CI escape hatch); both must be set +# AWS_SECRET_ACCESS_KEY=... # to take effect, else the provider chain is used +# Optional auth-refresh hook: when a Bedrock call fails with an expired-token +# error, agentmemory runs this command (no shell — argv split on whitespace, +# quotes honored) and retries once. Use it to re-establish an expired SSO +# session unattended. SECURITY: only the literal string below is ever executed; +# no model/memory data is interpolated. Note `aws sso login` is interactive +# (opens a browser) — in a headless daemon there is no approver, so the command +# is bounded by AWS_AUTH_REFRESH_TIMEOUT_MS. +# AWS_AUTH_REFRESH=aws sso login --profile my-sso-profile +# AWS_AUTH_REFRESH_TIMEOUT_MS=120000 # Default: 120 000 ms (2 min) + # MAX_TOKENS=4096 # Cap LLM completion tokens for compression / summarise calls # Outbound LLM / embedding timeout — shared across every raw-fetch provider @@ -67,7 +94,7 @@ # OPENAI_API_KEY → VOYAGE_API_KEY → COHERE_API_KEY → OPENROUTER_API_KEY → # local (Xenova/all-MiniLM-L6-v2, 384-dim). -# EMBEDDING_PROVIDER=local # local | openai | voyage | cohere | gemini | openrouter +# EMBEDDING_PROVIDER=local # local | openai | voyage | cohere | gemini | openrouter | bedrock # VOYAGE_API_KEY=pa-... # Optimised for code embeddings @@ -79,6 +106,22 @@ # OPENROUTER_EMBEDDING_MODEL=openai/text-embedding-3-small # When EMBEDDING_PROVIDER=openrouter +# AWS Bedrock embeddings (Cohere / Amazon Titan). Set EMBEDDING_PROVIDER=bedrock +# to use it — NOT auto-selected by AWS_BEDROCK=true (that opts into the Bedrock +# LLM only; embeddings stay on their current provider). Credentials come from the +# AWS provider chain (env / IAM role / SSO cache via AWS_PROFILE) — no key var — +# and the region is the shared AWS_REGION. 
+# AWS_BEDROCK_EMBEDDING_MODEL=cohere.embed-v4:0 # Default. Also: amazon.titan-embed-text-v2:0, cohere.embed-*-v3 +# Some models are INFERENCE_PROFILE-only in a given region (e.g. cohere.embed-v4:0 +# is not on-demand in us-east-2) and must use the geo-prefixed profile ID, e.g.: +# AWS_BEDROCK_EMBEDDING_MODEL=us.cohere.embed-v4:0 (or global.…). Titan v2 is +# on-demand and works with the bare ID. The us./eu./apac./global. prefix is +# stripped for model-family + known-dimensions detection. +# AWS_BEDROCK_EMBEDDING_DIMENSIONS=1024 # Default 1024. Cohere v4: 256/512/1024/1536; Titan v2: 256/512/1024. +# NOTE: the dimension is baked into the vector index — changing it later +# requires re-embedding all stored memories. Required for models not in the +# built-in known-dimensions table. + # ----------------------------------------------------------------------------- # 3. Auth & security # ----------------------------------------------------------------------------- diff --git a/README.md b/README.md index d5f4c414d..a1789ca17 100644 --- a/README.md +++ b/README.md @@ -884,6 +884,7 @@ npm install @xenova/transformers | Voyage AI | `voyage-code-3` | Paid | Optimized for code | | Cohere | `embed-english-v3.0` | Free trial | General purpose | | OpenRouter | Any model | Varies | Multi-model proxy | +| AWS Bedrock | `cohere.embed-v4:0` (default), `amazon.titan-embed-text-v2:0` | Paid (AWS) | Set `EMBEDDING_PROVIDER=bedrock`; creds via AWS chain / SSO; default 1024-dim. See [AWS Bedrock](#aws-bedrock). | --- @@ -1155,6 +1156,7 @@ agentmemory auto-detects from your environment. By default, no LLM calls are mad |----------|--------|-------| | **No-op (default)** | No config needed | LLM-backed compress/summarize is DISABLED. Synthetic BM25 compression + recall still work. See `AGENTMEMORY_ALLOW_AGENT_SDK` below if you used to rely on the Claude-subscription fallback. 
| | Anthropic API | `ANTHROPIC_API_KEY` | Per-token billing | +| AWS Bedrock | `AWS_BEDROCK=true` + `AWS_REGION` | Anthropic models on Bedrock. Opt-in flag, takes precedence when set. Creds from the AWS provider chain — env / IAM role / SSO cache (`AWS_PROFILE`). Default model Claude Haiku 4.5; see [AWS Bedrock](#aws-bedrock) below. | | MiniMax | `MINIMAX_API_KEY` | Anthropic-compatible | | Gemini | `GEMINI_API_KEY` | Also enables embeddings | | OpenRouter | `OPENROUTER_API_KEY` | Any model | @@ -1162,6 +1164,28 @@ agentmemory auto-detects from your environment. By default, no LLM calls are mad | **Local (Ollama / LM Studio / vLLM / llama.cpp)** | `OPENAI_API_KEY=local` + `OPENAI_BASE_URL=http://localhost:11434/v1` (Ollama) or `http://localhost:1234/v1` (LM Studio) + `OPENAI_MODEL=` | Anything OpenAI-API-compatible. Zero cost, runs on your hardware. See [Local models](#local-models-ollama-lm-studio-vllm) below. | | Claude subscription fallback | `AGENTMEMORY_ALLOW_AGENT_SDK=true` | Opt-in only. Spawns `@anthropic-ai/claude-agent-sdk` sessions — used to cause unbounded Stop-hook recursion (#149 follow-up) so it is no longer the default. | +### AWS Bedrock + +Run Anthropic models hosted on AWS Bedrock as the LLM provider. Opt in with `AWS_BEDROCK=true`; when set it takes precedence over the other provider keys. + +```bash +AWS_BEDROCK=true +AWS_REGION=us-east-1 +AWS_PROFILE=my-sso-profile # optional +AWS_BEDROCK_MODEL=anthropic.claude-haiku-4-5-20251001-v1:0 # optional; this is the default +``` + +- **Credentials** come from the standard AWS credential provider chain — environment credentials, IAM roles, or an SSO profile cached under `~/.aws/sso/cache/` (select the profile with `AWS_PROFILE`). No static keys are required. To force static keys (e.g. in CI), set **both** `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY`. +- **SSO** works out of the box, but agentmemory only *reads* the cached token — it cannot perform the login. Run `aws sso login --profile <profile>` first. 
When the session expires, either re-run it manually or configure the auth-refresh hook (below) to automate re-authentication. +- **Auth-refresh hook** (optional): when a Bedrock call fails with an expired-token error, agentmemory can run a command of your choosing and retry once: + ```bash + AWS_AUTH_REFRESH=aws sso login --profile my-sso-profile + AWS_AUTH_REFRESH_TIMEOUT_MS=120000 # optional, default 2 min + ``` + The command is single-flighted (concurrent calls trigger it once), rate-limited by a short cooldown, and bounded by the timeout. **Security:** only the literal configured string is executed — via `execFile`, no shell, and no model or memory data is ever interpolated into it. Note that `aws sso login` is interactive (opens a browser), so this is best suited to setups where someone can approve the login or where the configured command refreshes credentials non-interactively. +- **Model ID** defaults to Claude Haiku 4.5 (`anthropic.claude-haiku-4-5-20251001-v1:0`) — fast and cost-efficient for background compression. The bare on-demand ID only works in Regions that offer the model on-demand and where model access is enabled in the Bedrock console. In other Regions, set `AWS_BEDROCK_MODEL` to the geo-prefixed cross-region inference profile, e.g. `us.anthropic.claude-haiku-4-5-20251001-v1:0` (or `eu.…`). +- **Embeddings on Bedrock** (separate from the LLM): set `EMBEDDING_PROVIDER=bedrock` to use Cohere / Titan embeddings via the same AWS credentials. It is *not* auto-enabled by `AWS_BEDROCK=true` — so you can run the Bedrock LLM with local (or any other) embeddings. Defaults to `cohere.embed-v4:0` at 1024 dims; override with `AWS_BEDROCK_EMBEDDING_MODEL` / `AWS_BEDROCK_EMBEDDING_DIMENSIONS`. As with the LLM, some embedding models aren't available on-demand in every Region — `cohere.embed-v4:0` is inference-profile-only in several Regions, so set the geo-prefixed ID there, e.g. 
`AWS_BEDROCK_EMBEDDING_MODEL=us.cohere.embed-v4:0` (Titan v2 works on-demand with the bare ID). The dimension is baked into the vector index, so changing it later means re-embedding stored memories. + ### Local models (Ollama / LM Studio / vLLM) agentmemory talks to any OpenAI-API-compatible server, so anything that exposes `/v1/chat/completions` works without code changes. No paid keys, no cloud, no rate limits — runs entirely on your hardware. diff --git a/package.json b/package.json index bbc88d70d..caef0fc71 100644 --- a/package.json +++ b/package.json @@ -18,6 +18,7 @@ }, "scripts": { "build": "tsdown && (cp iii-config.yaml dist/ 2>/dev/null || true) && (cp iii-config.docker.yaml dist/ 2>/dev/null || true) && (cp docker-compose.yml dist/ 2>/dev/null || true) && (cp .env.example dist/ 2>/dev/null || true) && mkdir -p dist/viewer && cp src/viewer/index.html dist/viewer/ && cp src/viewer/favicon.svg dist/viewer/", + "prepare": "npm run build", "dev": "tsx src/index.ts", "start": "node dist/cli.mjs", "migrate": "node dist/functions/migrate.js", @@ -58,8 +59,10 @@ "url": "https://github.com/rohitg00/agentmemory" }, "dependencies": { + "@anthropic-ai/bedrock-sdk": "^0.29.2", "@anthropic-ai/claude-agent-sdk": "^0.3.142", "@anthropic-ai/sdk": "^0.100.1", + "@aws-sdk/client-bedrock-runtime": "^3.1057.0", "@clack/prompts": "^1.2.0", "dotenv": "^17.4.2", "iii-sdk": "0.11.2", diff --git a/src/config.ts b/src/config.ts index b3f8882d2..92b167763 100644 --- a/src/config.ts +++ b/src/config.ts @@ -49,9 +49,46 @@ function hasRealValue(v: string | undefined): v is string { return typeof v === "string" && v.trim().length > 0; } -function detectProvider(env: Record): ProviderConfig { +/** Prevents AWS_BEDROCK=True / TRUE from silently disabling Bedrock. */ +function isEnvTrue(v: string | undefined): boolean { + return typeof v === "string" && v.trim().toLowerCase() === "true"; +} + +/** Shared so detectProvider and isBedrockUsable gate on the same opt-in values. 
*/ +function isBedrockOptIn(env: Record): boolean { + return isEnvTrue(env["AWS_BEDROCK"]); +} + +/** A region is required to construct the client, so capability detection never reports an unbuildable config. */ +function isBedrockUsable(env: Record): boolean { + return isBedrockOptIn(env) && hasRealValue(env["AWS_REGION"]); +} + +export function detectProvider(env: Record): ProviderConfig { const maxTokens = parseInt(env["MAX_TOKENS"] || "4096", 10); + // AWS Bedrock: explicit opt-in via AWS_BEDROCK=true. Placed first so a machine + // with both Ollama and Bedrock configured prefers Bedrock when opted in; the + // strict flag gate means it never fires for existing OpenAI/Ollama users. + // Credentials come from the AWS provider chain (env / IAM role / SSO cache), + // so we do NOT key detection on credential env vars — only the flag + region. + // Region is mandatory: without it Bedrock cannot be constructed, so we reject + // here and fall through rather than returning an unusable bedrock config. + if (isBedrockOptIn(env)) { + if (isBedrockUsable(env)) { + return { + provider: "bedrock", + model: env["AWS_BEDROCK_MODEL"] || "anthropic.claude-haiku-4-5-20251001-v1:0", + maxTokens, + }; + } + process.stderr.write( + "[agentmemory] AWS_BEDROCK=true but AWS_REGION is unset — ignoring Bedrock " + + "and falling through to the next provider. 
Set AWS_REGION in " + + "~/.agentmemory/.env to enable Bedrock.\n", + ); + } + // OpenAI-compatible: supports OpenAI, DeepSeek, SiliconFlow, Azure, vLLM, LM Studio if (hasRealValue(env["OPENAI_API_KEY"]) && env["OPENAI_API_KEY_FOR_LLM"] !== "false") { return { @@ -191,6 +228,7 @@ export function isDropStaleIndexEnabled(): boolean { export function detectLlmProviderKind(): "llm" | "noop" { const env = getMergedEnv(); if ( + isBedrockUsable(env) || hasRealValue(env["ANTHROPIC_API_KEY"]) || hasRealValue(env["GEMINI_API_KEY"]) || hasRealValue(env["GOOGLE_API_KEY"]) || @@ -403,6 +441,7 @@ export function getStandalonePersistPath(): string { const VALID_PROVIDERS = new Set([ "anthropic", + "bedrock", "gemini", "openrouter", "agent-sdk", diff --git a/src/functions/compress-file.ts b/src/functions/compress-file.ts index 0a54452b2..b2e816236 100644 --- a/src/functions/compress-file.ts +++ b/src/functions/compress-file.ts @@ -5,6 +5,7 @@ import type { ISdk } from "iii-sdk"; import type { MemoryProvider } from "../types.js"; import type { StateKV } from "../state/kv.js"; import { recordAudit } from "./audit.js"; +import { logger } from "../logger.js"; const SENSITIVE_PATH_TERMS = [ "secret", @@ -133,10 +134,24 @@ export function registerCompressFileFunction( return { success: true, skipped: true, reason: "file is empty" }; } - const response = await provider.summarize( - COMPRESS_FILE_SYSTEM_PROMPT, - `Compress this markdown file while preserving structure and code blocks:\n\n${original}`, - ); + let response: string; + try { + response = await provider.summarize( + COMPRESS_FILE_SYSTEM_PROMPT, + `Compress this markdown file while preserving structure and code blocks:\n\n${original}`, + ); + } catch (err) { + // Surface the provider's message as a structured error. Without this the + // throw escapes the function and the engine serializes it as the opaque + // "[object Object]", hiding actionable hints (e.g. 
the Bedrock provider's + // model-access / inference-profile guidance). + const msg = err instanceof Error ? err.message : String(err); + logger.error("compress-file provider call failed", { + filePath: absolutePath, + error: msg, + }); + return { success: false, error: msg }; + } const compressed = stripMarkdownFence(response); const validationErrors = validateCompression(original, compressed); if (validationErrors.length > 0) { diff --git a/src/providers/auth-refresh.ts b/src/providers/auth-refresh.ts new file mode 100644 index 000000000..24b436926 --- /dev/null +++ b/src/providers/auth-refresh.ts @@ -0,0 +1,175 @@ +import { execFile } from "node:child_process"; +import { logger } from "../logger.js"; + +/** + * Conservative classifier for "credentials/token expired" errors from Bedrock + * or the underlying AWS STS / SSO layer. Kept to a narrow allow-list so that + * genuine errors (bad request, throttling, model-access denials) are NOT + * mistaken for an expiry and do not trigger a refresh. + */ +export function isAuthExpiry(err: unknown): boolean { + const name = (err as { name?: string })?.name ?? ""; + const code = (err as { code?: string })?.code ?? ""; + const message = err instanceof Error ? err.message : String(err ?? ""); + const haystack = `${name} ${code} ${message}`; + return ( + // STS / signed-request side: the token literally "expired". + /ExpiredToken|ExpiredTokenException|(?:security )?token (?:included in the request )?(?:is |has )?expired|credentials? (?:have )?expired/i.test( + haystack, + ) || + // SSO-cache side: the cached session token may be reported as expired OR + // (after `aws sso logout` / first run) "not found or is invalid" — the word + // "expired" never appears. Match an SSO-session phrase paired with any of + // those states, bounded so it can't run away across the whole message. 
+ /SSO session[\w\s=.,'"-]*?(?:has expired|not found|is invalid|invalid|expired)/i.test( + haystack, + ) || + // AWS's own remediation hint: when it tells you to re-run `aws sso login`, + // the situation is by definition a credential refresh. Strong, version- + // stable signal that complements the message-state matching above. + /\baws sso login\b/i.test(haystack) + ); +} + +/** + * Parse a configured command string into argv WITHOUT a shell. Supports simple + * single/double quoting so `--profile "my profile"` works; intentionally does + * NOT support shell features (pipes, expansion, substitution) — the command is + * run via execFile, not a shell, which is the trust boundary. + */ +export function tokenizeCommand(command: string): string[] { + const tokens: string[] = []; + const re = /"([^"]*)"|'([^']*)'|(\S+)/g; + let m: RegExpExecArray | null; + while ((m = re.exec(command)) !== null) { + tokens.push(m[1] ?? m[2] ?? m[3]); + } + return tokens; +} + +export interface AuthRefreshOptions { + /** Full command string, e.g. `aws sso login --profile my-sso-profile`. */ + command: string; + /** Hard timeout for the spawned command (ms). */ + timeoutMs?: number; + /** Minimum interval between refresh attempts (ms) — prevents login storms. */ + cooldownMs?: number; + /** + * Suppression window (ms) applied AFTER a timed-out attempt. A timeout means an + * interactive login (e.g. a browser device-auth page) was almost certainly left + * open awaiting approval; re-running would stack up more stale login pages. We + * back off for much longer than the ordinary cooldown so the user has time to + * complete (or abandon) the pending login. Default: 15 min. + */ + postTimeoutCooldownMs?: number; +} + +/** + * Runs a user-configured credential-refresh command (e.g. `aws sso login`) when + * a provider call fails with an expired-token error. Equivalent in spirit to + * Claude Code's `awsAuthRefresh` setting. 
+ * + * Safeguards: + * - Single-flight: concurrent callers share one in-flight run. + * - Cooldown: refuses to re-run within `cooldownMs` of the last attempt. + * - Post-timeout backoff: after a timeout, suppresses re-runs for + * `postTimeoutCooldownMs` so a hung interactive login isn't relaunched on + * every background trigger (which would fill the browser with stale pages). + * - Timeout: the spawned command is killed after `timeoutMs`. + * - No shell: the command is tokenized and executed via execFile, and only the + * configured string is ever run — no untrusted data is interpolated. + */ +export class AuthRefresh { + private readonly argv: string[]; + private readonly timeoutMs: number; + private readonly cooldownMs: number; + private readonly postTimeoutCooldownMs: number; + private inFlight: Promise | null = null; + private lastAttemptAt: number | null = null; + /** Set when the previous attempt timed out — gates re-runs for longer. */ + private suppressedUntil: number | null = null; + + constructor(opts: AuthRefreshOptions) { + this.argv = tokenizeCommand(opts.command); + this.timeoutMs = opts.timeoutMs ?? 120_000; + this.cooldownMs = opts.cooldownMs ?? 10_000; + this.postTimeoutCooldownMs = opts.postTimeoutCooldownMs ?? 900_000; + } + + /** + * Run the refresh command. Single-flight + cooldown guarded. Resolves when the + * command exits 0; rejects on non-zero exit, timeout, empty command, or while a + * post-timeout suppression window is active. + */ + async run(): Promise { + if (this.inFlight) return this.inFlight; + + const now = Date.now(); + + // Post-timeout backoff: a prior attempt timed out, so an interactive login is + // likely still pending. Don't launch another until the window elapses. 
+ if (this.suppressedUntil !== null && now < this.suppressedUntil) { + const waitMs = this.suppressedUntil - now; + logger.warn("auth refresh suppressed after a prior timeout", { + command: this.argv[0], + retryInMs: waitMs, + }); + throw new Error( + `auth refresh suppressed: a previous attempt timed out; not retrying for ` + + `another ${waitMs}ms (a pending interactive login may still be open)`, + ); + } + + if (this.lastAttemptAt !== null && now - this.lastAttemptAt < this.cooldownMs) { + logger.info("auth refresh skipped (cooldown)", { + sinceLastMs: now - this.lastAttemptAt, + cooldownMs: this.cooldownMs, + }); + throw new Error( + `auth refresh skipped: last attempt was ${now - this.lastAttemptAt}ms ago ` + + `(cooldown ${this.cooldownMs}ms)`, + ); + } + this.lastAttemptAt = now; + + if (this.argv.length === 0) { + throw new Error("auth refresh command is empty"); + } + + const [cmd, ...args] = this.argv; + logger.info("auth refresh: running credential command", { command: cmd }); + this.inFlight = new Promise((resolve, reject) => { + execFile(cmd, args, { timeout: this.timeoutMs }, (err) => { + if (err) { + // execFile flags a timeout kill via `killed` + the configured signal. 
+ const timedOut = + (err as { killed?: boolean }).killed === true || + (err as { signal?: string }).signal === "SIGTERM"; + if (timedOut) { + this.suppressedUntil = Date.now() + this.postTimeoutCooldownMs; + logger.error("auth refresh command timed out", { + command: cmd, + timeoutMs: this.timeoutMs, + suppressForMs: this.postTimeoutCooldownMs, + }); + } else { + logger.error("auth refresh command failed", { + command: cmd, + error: err.message, + }); + } + reject(err); + } else { + logger.info("auth refresh: credential command succeeded", { + command: cmd, + }); + resolve(); + } + }); + }).finally(() => { + this.inFlight = null; + }); + + return this.inFlight; + } +} diff --git a/src/providers/bedrock.ts b/src/providers/bedrock.ts new file mode 100644 index 000000000..50b199fb2 --- /dev/null +++ b/src/providers/bedrock.ts @@ -0,0 +1,133 @@ +import { AnthropicBedrock } from '@anthropic-ai/bedrock-sdk' +import type { MemoryProvider } from '../types.js' +import { getEnvVar } from '../config.js' + +/** + * AWS Bedrock LLM provider (Anthropic models on Bedrock). + * + * Wraps `@anthropic-ai/bedrock-sdk`, which speaks the same + * `messages.create(...)` surface as the first-party Anthropic SDK but + * authenticates with AWS SigV4 instead of an `x-api-key` header. + * + * Credentials: by default NO explicit keys are passed, so the AWS SDK v3 + * default credential provider chain resolves them — environment creds, IAM + * roles, and crucially **SSO profiles** cached under `~/.aws/sso/cache/` + * (select with `AWS_PROFILE`). The SDK reads a cached SSO token; it cannot + * perform the interactive `aws sso login` itself, so the session must already + * be valid. Static keys (`AWS_ACCESS_KEY_ID` / `AWS_SECRET_ACCESS_KEY`) are an + * opt-in escape hatch for CI. + * + * Required env: + * AWS_REGION — Bedrock region (also consumed by the SDK directly). + * + * Optional: + * AWS_BEDROCK_MODEL — model / inference-profile ID (default below). 
+ * AWS_PROFILE — SSO/credentials profile, consumed by the AWS SDK. + * AWS_ACCESS_KEY_ID — explicit static key (escape hatch / CI). + * AWS_SECRET_ACCESS_KEY — explicit static secret (escape hatch / CI). + * AWS_SESSION_TOKEN — explicit session token for temporary creds. + * + * Model IDs are Bedrock-style (e.g. `anthropic.claude-haiku-4-5-20251001-v1:0`), + * NOT the bare Anthropic model name. In Regions where the model is not offered + * on-demand it is reachable only via a cross-region inference profile, whose ID + * is geo-prefixed: `us.anthropic.claude-haiku-4-5-20251001-v1:0` (or `eu.`). + */ +export class BedrockProvider implements MemoryProvider { + name = 'bedrock' + private client: AnthropicBedrock + private model: string + private maxTokens: number + + constructor(model: string, maxTokens: number, awsRegion: string) { + const awsAccessKey = getEnvVar('AWS_ACCESS_KEY_ID') + const awsSecretKey = getEnvVar('AWS_SECRET_ACCESS_KEY') + const awsSessionToken = getEnvVar('AWS_SESSION_TOKEN') + + // Only pass explicit keys when BOTH are present — otherwise omit them so the + // AWS credential provider chain (env / IAM role / SSO cache) resolves creds. + this.client = + awsAccessKey && awsSecretKey + ? new AnthropicBedrock({ + awsRegion, + awsAccessKey, + awsSecretKey, + ...(awsSessionToken ? 
{ awsSessionToken } : {}), + }) + : new AnthropicBedrock({ awsRegion }) + this.model = model + this.maxTokens = maxTokens + } + + async compress(systemPrompt: string, userPrompt: string): Promise { + return this.call(systemPrompt, userPrompt) + } + + async summarize(systemPrompt: string, userPrompt: string): Promise { + return this.call(systemPrompt, userPrompt) + } + + async describeImage(imageData: string, mimeType: string, prompt: string): Promise { + try { + const response = await this.client.messages.create({ + model: this.model, + max_tokens: this.maxTokens, + messages: [{ + role: 'user', + content: [ + { + type: 'image', + source: { type: 'base64', media_type: mimeType as 'image/png' | 'image/jpeg' | 'image/gif' | 'image/webp', data: imageData }, + }, + { type: 'text', text: prompt }, + ], + }], + }) + + const textBlock = response.content.find((b) => b.type === 'text') + return textBlock?.text ?? '' + } catch (err) { + throw this.explainError(err) + } + } + + private async call(systemPrompt: string, userPrompt: string): Promise { + try { + const response = await this.client.messages.create({ + model: this.model, + max_tokens: this.maxTokens, + system: systemPrompt, + messages: [{ role: 'user', content: userPrompt }], + }) + + const textBlock = response.content.find((b) => b.type === 'text') + return textBlock?.text ?? '' + } catch (err) { + throw this.explainError(err) + } + } + + /** + * Turn an opaque Bedrock model-access / validation 4xx into an actionable + * error. The bare on-demand model ID only works in Regions that offer the + * model on-demand; elsewhere callers must enable model access or switch to a + * `us.`/`eu.`-prefixed cross-region inference profile. + */ + private explainError(err: unknown): unknown { + const status = (err as { status?: number })?.status + const message = err instanceof Error ? 
err.message : String(err) + if ( + status === 403 || + status === 400 || + /access|not authorized|inference profile|on-demand|ValidationException|AccessDenied/i.test(message) + ) { + return new Error( + `Bedrock model "${this.model}" could not be invoked (${message}). ` + + `Check that: (1) model access is enabled for this account in the Bedrock console, ` + + `(2) AWS_REGION (${this.client.awsRegion}) offers this model, and ` + + `(3) for Regions without on-demand access, AWS_BEDROCK_MODEL is set to the ` + + `"us."/"eu."-prefixed cross-region inference profile ID.`, + ) + } + return err + } +} diff --git a/src/providers/embedding/bedrock.ts b/src/providers/embedding/bedrock.ts new file mode 100644 index 000000000..c0fb3e767 --- /dev/null +++ b/src/providers/embedding/bedrock.ts @@ -0,0 +1,248 @@ +import { + BedrockRuntimeClient, + InvokeModelCommand, +} from "@aws-sdk/client-bedrock-runtime"; +import type { EmbeddingProvider } from "../../types.js"; +import { getEnvVar } from "../../config.js"; + +const DEFAULT_MODEL = "cohere.embed-v4:0"; + +/** + * Known embedding dimensions by Bedrock model ID. Override in any case via + * AWS_BEDROCK_EMBEDDING_DIMENSIONS. Models not listed here REQUIRE that override + * — we refuse to guess, because a wrong dimension silently corrupts the vector + * index (see withDimensionGuard). + * + * Cohere v4 + Titan v2 are Matryoshka models (selectable output dims); the + * default of 1024 is sent in the request body, not just reported. + */ +const MODEL_DIMENSIONS: Record = { + "cohere.embed-v4:0": 1024, + "cohere.embed-english-v3": 1024, + "cohere.embed-multilingual-v3": 1024, + "amazon.titan-embed-text-v2:0": 1024, + "amazon.titan-embed-text-v1": 1536, +}; + +// Titan has no native batch endpoint — embedBatch fans out one InvokeModel call +// per input. Bound the in-flight count to stay within Bedrock rate limits while +// keeping throughput reasonable (mirrors summarize.ts's chunk concurrency). 
const TITAN_BATCH_CONCURRENCY = 6;

// Cohere caps texts at 96 per InvokeModel call.
const COHERE_MAX_BATCH = 96;

/**
 * Strip a leading cross-region inference-profile geo prefix (`us.`, `eu.`,
 * `apac.`, `global.`) so model-family detection and the known-dimensions lookup
 * work against the underlying model ID. Bedrock requires the prefixed profile ID
 * for models that don't support on-demand throughput (e.g. cohere.embed-v4:0 in
 * us-east-2 → us.cohere.embed-v4:0), but the family/dims are the same model.
 *
 * @param model - Model ID, with or without a geo prefix.
 * @returns The model ID without the geo prefix (unchanged when there is none).
 */
function stripInferenceProfilePrefix(model: string): string {
  return model.replace(/^(?:us|eu|apac|global)\./, "");
}

/**
 * Resolve the embedding dimension for `model`.
 *
 * A non-blank `override` wins and must be a plain positive decimal integer.
 * It is validated strictly (not with `parseInt`, which would quietly read
 * "512px" as 512 or "1e3" as 1) because a wrong dimension silently corrupts the
 * vector index. Without an override the value comes from the known-dimensions
 * table, keyed by the un-prefixed model ID.
 *
 * @param model - Bedrock embedding model ID (geo prefix allowed).
 * @param override - Raw AWS_BEDROCK_EMBEDDING_DIMENSIONS value, if any.
 * @returns The output dimension to use.
 * @throws Error when the override is not a positive integer, or when the model
 *   is not in the known-dimensions table and no override was given.
 */
function resolveDimensions(model: string, override: string | undefined): number {
  const raw = override?.trim() ?? "";
  if (raw.length > 0) {
    const parsed = /^\+?\d+$/.test(raw) ? Number(raw) : Number.NaN;
    if (!Number.isSafeInteger(parsed) || parsed <= 0) {
      throw new Error(
        `AWS_BEDROCK_EMBEDDING_DIMENSIONS must be a positive integer, got: ${override}`,
      );
    }
    return parsed;
  }
  // typeof guard (rather than `=== undefined`) so an inherited object property
  // can never be returned as a dimension.
  const known: unknown = MODEL_DIMENSIONS[stripInferenceProfilePrefix(model)];
  if (typeof known !== "number") {
    throw new Error(
      `Unknown Bedrock embedding model "${model}" — set AWS_BEDROCK_EMBEDDING_DIMENSIONS ` +
        `to its output dimension (a wrong value silently corrupts the vector index).`,
    );
  }
  return known;
}

type ModelFamily = "cohere" | "titan";

/**
 * Classify a Bedrock embedding model ID by request/response family.
 *
 * @param model - Bedrock embedding model ID (geo prefix allowed).
 * @returns "cohere" for `cohere.*`, "titan" for `amazon.titan-embed*`.
 * @throws Error for any other model ID.
 */
function familyOf(model: string): ModelFamily {
  const base = stripInferenceProfilePrefix(model);
  if (base.startsWith("cohere.")) return "cohere";
  if (base.startsWith("amazon.titan-embed")) return "titan";
  throw new Error(
    `Unsupported Bedrock embedding model "${model}" — expected a "cohere." or ` +
      `"amazon.titan-embed" model ID (optionally with a us./eu./apac./global. ` +
      `inference-profile prefix).`,
  );
}

/**
 * AWS Bedrock embedding provider (Cohere / Amazon Titan embeddings on Bedrock).
 *
 * Uses the AWS Bedrock Runtime InvokeModel API (not the Anthropic bedrock-sdk,
 * which has no embeddings). Credentials resolve via the AWS default provider
 * chain — env / IAM role / SSO cache (select with AWS_PROFILE) — exactly like
 * the Bedrock LLM provider, so no key env var is needed. Static keys are honored
 * only when both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are set.
 *
 * Required env:
 *   AWS_REGION — Bedrock region (shared with the LLM provider).
 *
 * Optional:
 *   AWS_BEDROCK_EMBEDDING_MODEL      — model ID (default: cohere.embed-v4:0).
 *   AWS_BEDROCK_EMBEDDING_DIMENSIONS — output dims (default 1024; required for
 *                                      models not in the known-dims table).
 *   AWS_PROFILE / AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY / AWS_SESSION_TOKEN
 *                                    — same credential knobs as the LLM provider.
 */
export class BedrockEmbeddingProvider implements EmbeddingProvider {
  readonly name = "bedrock";
  readonly dimensions: number;
  private client: BedrockRuntimeClient;
  private model: string;
  private family: ModelFamily;

  /**
   * Reads configuration from the environment and builds the runtime client.
   *
   * @throws Error when AWS_REGION is missing, the model family is unsupported,
   *   or the dimensions cannot be resolved.
   */
  constructor() {
    const region = getEnvVar("AWS_REGION");
    if (!region) {
      throw new Error("AWS_REGION is required for the bedrock embedding provider");
    }
    this.model = getEnvVar("AWS_BEDROCK_EMBEDDING_MODEL") || DEFAULT_MODEL;
    this.family = familyOf(this.model);
    this.dimensions = resolveDimensions(
      this.model,
      getEnvVar("AWS_BEDROCK_EMBEDDING_DIMENSIONS"),
    );

    const accessKeyId = getEnvVar("AWS_ACCESS_KEY_ID");
    const secretAccessKey = getEnvVar("AWS_SECRET_ACCESS_KEY");
    const sessionToken = getEnvVar("AWS_SESSION_TOKEN");
    // Pass explicit creds only when both halves are present; otherwise omit so
    // the AWS provider chain (env / IAM role / SSO cache) resolves them.
    this.client =
      accessKeyId && secretAccessKey
        ? new BedrockRuntimeClient({
            region,
            credentials: {
              accessKeyId,
              secretAccessKey,
              ...(sessionToken ? { sessionToken } : {}),
            },
          })
        : new BedrockRuntimeClient({ region });
  }

  /** Embed a single text; delegates to {@link embedBatch}. */
  async embed(text: string): Promise<Float32Array> {
    const [result] = await this.embedBatch([text]);
    return result;
  }

  /**
   * Embed many texts, preserving input order.
   *
   * @returns One vector per input text, in the same order.
   */
  async embedBatch(texts: string[]): Promise<Float32Array[]> {
    return this.family === "cohere"
      ? this.embedCohere(texts)
      : this.embedTitan(texts);
  }

  // Cohere: native batch, up to 96 texts per call. Request a single float
  // embedding type, which yields the keyed-by-type response shape
  // { embeddings: { float: [[...]] } }.
  private async embedCohere(texts: string[]): Promise<Float32Array[]> {
    const out: Float32Array[] = [];
    for (let i = 0; i < texts.length; i += COHERE_MAX_BATCH) {
      const slice = texts.slice(i, i + COHERE_MAX_BATCH);
      const body: Record<string, unknown> = {
        input_type: "search_document",
        texts: slice,
        embedding_types: ["float"],
      };
      // Only Cohere v4 accepts output_dimension; v3 is fixed at 1024.
      if (this.model.includes("embed-v4")) body.output_dimension = this.dimensions;

      const json = await this.invoke(body);
      // v4 (embedding_types specified) → { embeddings: { float: [[...]] } }.
      // v3 → { embeddings: [[...]] }.
      const embeddings =
        (json.embeddings as { float?: number[][] } | number[][] | undefined) ?? [];
      const rows = Array.isArray(embeddings)
        ? (embeddings as number[][])
        : (embeddings.float ?? []);
      // Fail fast on a cardinality mismatch: fewer rows than inputs would
      // silently misalign texts to vectors downstream (withDimensionGuard only
      // checks each vector's length, not the batch count).
      if (rows.length !== slice.length) {
        throw new Error(
          `Bedrock embedding returned ${rows.length} vectors for ${slice.length} inputs ` +
            `(model "${this.model}") — refusing to misalign texts to vectors.`,
        );
      }
      for (const row of rows) out.push(new Float32Array(row));
    }
    return out;
  }

  // Titan: one input per call, no batch endpoint — fan out with bounded concurrency.
  private async embedTitan(texts: string[]): Promise<Float32Array[]> {
    const results = new Array<Float32Array>(texts.length);
    // Only Titan Text Embeddings V2 accepts `dimensions` / `normalize`; the
    // older Titan models reject extraneous request keys, so they get
    // `inputText` alone.
    const acceptsTuning = stripInferenceProfilePrefix(this.model).startsWith(
      "amazon.titan-embed-text-v2",
    );
    let next = 0;
    // Set on the first failure so the remaining workers stop issuing calls whose
    // results would be thrown away.
    let failed = false;
    const worker = async (): Promise<void> => {
      while (!failed && next < texts.length) {
        const idx = next++;
        const body: Record<string, unknown> = { inputText: texts[idx] };
        if (acceptsTuning) {
          body.dimensions = this.dimensions;
          body.normalize = true;
        }
        try {
          const json = await this.invoke(body);
          const embedding = json.embedding;
          // Fail fast (like the Cohere path) instead of storing an empty vector.
          if (!Array.isArray(embedding) || embedding.length === 0) {
            throw new Error(
              `Bedrock Titan response for input ${idx} had no embedding vector ` +
                `(model "${this.model}").`,
            );
          }
          results[idx] = new Float32Array(embedding as number[]);
        } catch (err) {
          failed = true;
          throw err;
        }
      }
    };
    const workers = Array.from(
      { length: Math.min(TITAN_BATCH_CONCURRENCY, texts.length) },
      () => worker(),
    );
    await Promise.all(workers);
    return results;
  }

  /**
   * Send one InvokeModel request and return the decoded JSON object.
   *
   * @throws The SDK error, rewritten by {@link explainError} when it looks like
   *   a model-access / validation failure.
   */
  private async invoke(body: Record<string, unknown>): Promise<Record<string, unknown>> {
    try {
      const response = await this.client.send(
        new InvokeModelCommand({
          modelId: this.model,
          contentType: "application/json",
          accept: "application/json",
          body: JSON.stringify(body),
        }),
      );
      const parsed: unknown = JSON.parse(new TextDecoder().decode(response.body));
      // Callers index into the result, so reject null / scalars / arrays here
      // with a clear message rather than a later TypeError.
      if (parsed === null || typeof parsed !== "object" || Array.isArray(parsed)) {
        throw new Error(
          `Bedrock embedding model "${this.model}" returned a non-object response body.`,
        );
      }
      return parsed as Record<string, unknown>;
    } catch (err) {
      throw this.explainError(err);
    }
  }

  /**
   * Turn an opaque Bedrock model-access / validation 4xx into an actionable
   * error, mirroring the LLM provider's guidance. The original error is kept as
   * `cause` so its name and SDK metadata stay available to callers and logs.
   * Errors that do not look like access/validation problems pass through as-is.
   */
  private explainError(err: unknown): unknown {
    const status = (
      err as { $metadata?: { httpStatusCode?: number } } | null | undefined
    )?.$metadata?.httpStatusCode;
    const message = err instanceof Error ? err.message : String(err);
    const looksLikeAccessProblem =
      status === 403 ||
      status === 400 ||
      /access|not authorized|inference profile|on-demand|ValidationException|AccessDenied/i.test(
        message,
      );
    if (!looksLikeAccessProblem) return err;
    const explained = new Error(
      `Bedrock embedding model "${this.model}" could not be invoked (${message}). ` +
        `Check that: (1) model access is enabled for this account in the Bedrock console, ` +
        `(2) AWS_REGION offers this embedding model, ` +
        `(3) AWS_BEDROCK_EMBEDDING_MODEL is a valid Bedrock embedding model ID, and ` +
        `(4) the AWS credentials are valid and unexpired (for SSO, re-run 'aws sso login').`,
    );
    return Object.assign(explained, { cause: err });
  }
}
/**
 * Build the optional credential-refresh hook. Only the bedrock provider uses it
 * today, and only when AWS_AUTH_REFRESH is set; the mechanism itself is generic.
 * Accepts every provider type that may be invoked (primary + fallback chain) so
 * a bedrock provider reachable only via the fallback path still gets the hook.
 *
 * AWS_AUTH_REFRESH_TIMEOUT_MS is honoured only when it is a plain, strictly
 * positive decimal integer. Anything else (unset, blank, "0", negative, "2m",
 * "1.5e5", …) yields `timeoutMs: undefined`, i.e. the AuthRefresh default. A
 * lenient `parseInt` would read "2m" as 2 ms and let 0 / negative values through.
 * NOTE(review): presumably AuthRefresh forwards timeoutMs to the spawned
 * command's timeout, where 0 would mean "no bound" — confirm in auth-refresh.ts.
 *
 * @param providerTypes - Types of every provider that was actually built.
 * @returns The refresh hook, or `undefined` when bedrock is not among the
 *   providers or no refresh command is configured.
 */
function createAuthRefresh(providerTypes: ProviderType[]): AuthRefresh | undefined {
  if (!providerTypes.includes("bedrock")) return undefined;
  const command = getEnvVar("AWS_AUTH_REFRESH");
  if (!command || !command.trim()) return undefined;
  const timeoutRaw = getEnvVar("AWS_AUTH_REFRESH_TIMEOUT_MS")?.trim() ?? "";
  const timeoutMs = /^\d+$/.test(timeoutRaw) ? Number(timeoutRaw) : Number.NaN;
  return new AuthRefresh({
    command,
    timeoutMs: Number.isSafeInteger(timeoutMs) && timeoutMs > 0 ? timeoutMs : undefined,
  });
}
/**
 * Create the primary provider for `config`, wrapped in a ResilientProvider.
 * The auth-refresh hook is derived from the primary provider type alone, so it
 * is only present when that type is "bedrock" and a refresh command is set.
 *
 * @param config - Provider configuration for the primary provider.
 * @returns The base provider wrapped in a ResilientProvider.
 */
export function createProvider(config: ProviderConfig): ResilientProvider {
  const base = createBaseProvider(config);
  const authRefresh = createAuthRefresh([config.provider]);
  return new ResilientProvider(base, authRefresh);
}
+ const authRefresh = createAuthRefresh(builtTypes); if (providers.length > 1) { - return new ResilientProvider(new FallbackChainProvider(providers)); + return new ResilientProvider( + new FallbackChainProvider(providers), + authRefresh, + ); } - return new ResilientProvider(providers[0]); + return new ResilientProvider(providers[0], authRefresh); } function createBaseProvider(config: ProviderConfig): MemoryProvider { @@ -107,6 +142,15 @@ function createBaseProvider(config: ProviderConfig): MemoryProvider { config.maxTokens, config.baseURL, ); + case "bedrock": + // No requireEnvVar for a key: creds may come from the AWS credential + // provider chain (SSO cache / IAM role) with no env var set. A region is + // mandatory for Bedrock, though. + return new BedrockProvider( + config.model, + config.maxTokens, + requireEnvVar("AWS_REGION"), + ); case "gemini": { const geminiKey = getEnvVar("GEMINI_API_KEY") || getEnvVar("GOOGLE_API_KEY"); diff --git a/src/providers/resilient.ts b/src/providers/resilient.ts index 95ece40c9..6ca37bbfd 100644 --- a/src/providers/resilient.ts +++ b/src/providers/resilient.ts @@ -1,15 +1,23 @@ import type { MemoryProvider, CircuitBreakerState } from "../types.js"; import { CircuitBreaker } from "./circuit-breaker.js"; +import { AuthRefresh, isAuthExpiry } from "./auth-refresh.js"; +import { logger } from "../logger.js"; export class ResilientProvider implements MemoryProvider { private breaker = new CircuitBreaker(); name: string; - constructor(private inner: MemoryProvider) { + constructor( + private inner: MemoryProvider, + private authRefresh?: AuthRefresh, + ) { this.name = `resilient(${inner.name})`; } - private async call(fn: () => Promise): Promise { + private async call( + fn: () => Promise, + alreadyRetried = false, + ): Promise { if (!this.breaker.isAllowed) { throw new Error("circuit_breaker_open"); } @@ -18,6 +26,37 @@ export class ResilientProvider implements MemoryProvider { this.breaker.recordSuccess(); return result; } catch 
(err) { + // On an expired-credential error, run the configured refresh command and + // retry once — BEFORE recording a breaker failure, so a recoverable + // token expiry doesn't count toward opening the circuit. + if (!alreadyRetried && this.authRefresh && isAuthExpiry(err)) { + logger.warn("provider call failed with expired credentials — attempting auth refresh", { + provider: this.inner.name, + error: err instanceof Error ? err.message : String(err), + }); + // Scope this catch to the refresh command ONLY. If refresh succeeds, the + // retry runs outside the try so its own error (and breaker accounting) + // propagates normally — otherwise a failed retry would surface the stale + // auth-expiry error and double-count the breaker (once in the retried + // call, once here). + let refreshed = false; + try { + await this.authRefresh.run(); + refreshed = true; + } catch (refreshErr) { + logger.error("auth refresh command did not run", { + provider: this.inner.name, + reason: refreshErr instanceof Error ? 
refreshErr.message : String(refreshErr), + }); + } + if (refreshed) { + const result = await this.call(fn, true); + logger.info("auth refresh recovered the provider call", { + provider: this.inner.name, + }); + return result; + } + } this.breaker.recordFailure(); throw err; } diff --git a/src/types.ts b/src/types.ts index 1b53d1b61..83fd5b87a 100644 --- a/src/types.ts +++ b/src/types.ts @@ -147,7 +147,7 @@ export interface ProviderConfig { baseURL?: string; } -export type ProviderType = "agent-sdk" | "anthropic" | "gemini" | "openrouter" | "minimax" | "openai" | "noop"; +export type ProviderType = "agent-sdk" | "anthropic" | "bedrock" | "gemini" | "openrouter" | "minimax" | "openai" | "noop"; export interface MemoryProvider { name: string; diff --git a/test/auth-refresh.test.ts b/test/auth-refresh.test.ts new file mode 100644 index 000000000..c00c979bc --- /dev/null +++ b/test/auth-refresh.test.ts @@ -0,0 +1,211 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; + +// Controllable execFile mock: records call count and lets each test decide how +// the spawned command resolves (success / error / timeout), so single-flight and +// cooldown can be asserted on the REAL underlying call count rather than a spy on +// the public method — and without depending on a host `true`/`sleep` binary. +const execFileCalls: Array<{ cmd: string; args: string[] }> = []; +let execFileBehavior: (cmd: string) => Error | null = () => null; + +vi.mock("node:child_process", () => ({ + execFile: ( + cmd: string, + args: string[], + _opts: unknown, + cb: (err: Error | null) => void, + ) => { + execFileCalls.push({ cmd, args }); + // Resolve on a microtask so concurrent run() calls share one in-flight promise. 
+ queueMicrotask(() => cb(execFileBehavior(cmd))); + }, +})); + +import { + AuthRefresh, + isAuthExpiry, + tokenizeCommand, +} from "../src/providers/auth-refresh.js"; +import { ResilientProvider } from "../src/providers/resilient.js"; +import type { MemoryProvider } from "../src/types.js"; + +describe("isAuthExpiry", () => { + it("matches AWS / SSO expiry signals", () => { + expect(isAuthExpiry(new Error("ExpiredTokenException: token expired"))).toBe(true); + expect(isAuthExpiry(new Error("The SSO session has expired"))).toBe(true); + expect(isAuthExpiry(new Error("Token is expired"))).toBe(true); + expect(isAuthExpiry({ name: "ExpiredToken", message: "" })).toBe(true); + expect(isAuthExpiry(new Error("The security token included in the request is expired"))).toBe(true); + // Real message from @aws-sdk after `aws sso logout` — note it says + // "not found or is invalid", never "expired", and includes the remediation + // hint. Both the SSO-session matcher and the `aws sso login` matcher catch it. + expect( + isAuthExpiry( + new Error( + "The SSO session token associated with profile=default was not found or is invalid. 
" + + "To refresh this SSO session run 'aws sso login' with the corresponding profile.", + ), + ), + ).toBe(true); + expect( + isAuthExpiry( + new Error( + "The SSO session associated with this profile has expired or is otherwise invalid.", + ), + ), + ).toBe(true); + }); + + it("does NOT match unrelated errors", () => { + expect(isAuthExpiry(new Error("ValidationException: model not found"))).toBe(false); + expect(isAuthExpiry(new Error("ThrottlingException"))).toBe(false); + expect(isAuthExpiry(new Error("AccessDeniedException: no model access"))).toBe(false); + expect(isAuthExpiry(new Error("connection reset"))).toBe(false); + expect(isAuthExpiry(undefined)).toBe(false); + }); +}); + +describe("tokenizeCommand", () => { + it("splits on whitespace", () => { + expect(tokenizeCommand("aws sso login --profile foo")).toEqual([ + "aws", "sso", "login", "--profile", "foo", + ]); + }); + + it("honors double and single quotes", () => { + expect(tokenizeCommand('aws sso login --profile "my profile"')).toEqual([ + "aws", "sso", "login", "--profile", "my profile", + ]); + expect(tokenizeCommand("cmd --x 'a b c'")).toEqual(["cmd", "--x", "a b c"]); + }); + + it("returns empty array for an empty command", () => { + expect(tokenizeCommand(" ")).toEqual([]); + }); +}); + +// A controllable fake provider + fake AuthRefresh so no real `aws` is spawned. 
+function fakeProvider(fn: () => Promise): MemoryProvider { + return { + name: "fake", + compress: fn, + summarize: fn, + }; +} + +function fakeRefresh(run: () => Promise): AuthRefresh { + return { run } as unknown as AuthRefresh; +} + +describe("ResilientProvider — auth-refresh retry", () => { + it("refreshes once and retries on an expired-token error, then succeeds", async () => { + let calls = 0; + const inner = fakeProvider(async () => { + calls += 1; + if (calls === 1) throw new Error("ExpiredTokenException"); + return "ok"; + }); + const run = vi.fn(async () => {}); + const provider = new ResilientProvider(inner, fakeRefresh(run)); + + const result = await provider.compress("s", "u"); + expect(result).toBe("ok"); + expect(calls).toBe(2); + expect(run).toHaveBeenCalledTimes(1); + }); + + it("does NOT refresh on a non-expiry error", async () => { + const inner = fakeProvider(async () => { + throw new Error("ValidationException"); + }); + const run = vi.fn(async () => {}); + const provider = new ResilientProvider(inner, fakeRefresh(run)); + + await expect(provider.compress("s", "u")).rejects.toThrow("ValidationException"); + expect(run).not.toHaveBeenCalled(); + }); + + it("retries at most once — propagates if the post-refresh call also expires", async () => { + let calls = 0; + const inner = fakeProvider(async () => { + calls += 1; + throw new Error("ExpiredTokenException"); + }); + const run = vi.fn(async () => {}); + const provider = new ResilientProvider(inner, fakeRefresh(run)); + + await expect(provider.compress("s", "u")).rejects.toThrow("ExpiredTokenException"); + expect(calls).toBe(2); // original + one retry, no more + expect(run).toHaveBeenCalledTimes(1); + }); + + it("propagates the original error if the refresh command itself fails", async () => { + const inner = fakeProvider(async () => { + throw new Error("ExpiredTokenException"); + }); + const run = vi.fn(async () => { + throw new Error("aws sso login failed"); + }); + const provider = new 
ResilientProvider(inner, fakeRefresh(run)); + + await expect(provider.compress("s", "u")).rejects.toThrow("ExpiredTokenException"); + expect(run).toHaveBeenCalledTimes(1); + }); + + it("behaves exactly as before when no AuthRefresh is configured (regression guard)", async () => { + const inner = fakeProvider(async () => { + throw new Error("ExpiredTokenException"); + }); + const provider = new ResilientProvider(inner); // no refresh + await expect(provider.compress("s", "u")).rejects.toThrow("ExpiredTokenException"); + }); +}); + +describe("AuthRefresh — single-flight + cooldown", () => { + beforeEach(() => { + execFileCalls.length = 0; + execFileBehavior = () => null; // default: command succeeds + }); + afterEach(() => { + execFileBehavior = () => null; + }); + + it("coalesces concurrent calls into a single command run (single-flight)", async () => { + const refresh = new AuthRefresh({ command: "aws sso login --profile p" }); + // Fire three concurrently; they share one in-flight promise, so execFile — + // the REAL underlying spawn — must run exactly once. 
+ await Promise.all([refresh.run(), refresh.run(), refresh.run()]); + expect(execFileCalls).toHaveLength(1); + expect(execFileCalls[0]).toEqual({ cmd: "aws", args: ["sso", "login", "--profile", "p"] }); + }); + + it("rejects an empty command", async () => { + const refresh = new AuthRefresh({ command: " " }); + await expect(refresh.run()).rejects.toThrow(/empty/); + expect(execFileCalls).toHaveLength(0); + }); + + it("enforces a cooldown between sequential attempts (no second spawn)", async () => { + const refresh = new AuthRefresh({ command: "aws sso login", cooldownMs: 60_000 }); + await refresh.run(); // first succeeds → 1 spawn + await expect(refresh.run()).rejects.toThrow(/cooldown/); + expect(execFileCalls).toHaveLength(1); // cooldown blocked the second spawn + }); + + it("does NOT relaunch after a timeout (post-timeout suppression window)", async () => { + // Make the mocked command resolve with a timeout-shaped error (execFile sets + // killed:true + SIGTERM on timeout). cooldownMs:0 isolates the suppression + // path: any rejection on the next run() must come from post-timeout backoff. + execFileBehavior = () => + Object.assign(new Error("timed out"), { killed: true, signal: "SIGTERM" }); + const refresh = new AuthRefresh({ + command: "aws sso login", + timeoutMs: 50, + cooldownMs: 0, + postTimeoutCooldownMs: 60_000, + }); + await expect(refresh.run()).rejects.toThrow(); // times out → 1 spawn + // Second attempt must be suppressed, not relaunched (no new stale login). 
+ await expect(refresh.run()).rejects.toThrow(/suppress|timed out/i); + expect(execFileCalls).toHaveLength(1); // suppression blocked the relaunch + }); +}); diff --git a/test/bedrock-embedding-provider.test.ts b/test/bedrock-embedding-provider.test.ts new file mode 100644 index 000000000..3c3e43837 --- /dev/null +++ b/test/bedrock-embedding-provider.test.ts @@ -0,0 +1,190 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; + +// Capture the bodies sent to InvokeModel and return canned responses, so no +// real AWS call is made. The mock records each request body for assertions. +const sentBodies: Array> = []; +let cannedResponse: (body: Record) => unknown; + +vi.mock("@aws-sdk/client-bedrock-runtime", () => { + class InvokeModelCommand { + input: { body: string; modelId: string }; + constructor(input: { body: string; modelId: string }) { + this.input = input; + } + } + class BedrockRuntimeClient { + config: unknown; + constructor(config: unknown) { + this.config = config; + } + async send(cmd: InvokeModelCommand) { + const body = JSON.parse(cmd.input.body) as Record; + sentBodies.push(body); + const payload = cannedResponse(body); + return { body: new TextEncoder().encode(JSON.stringify(payload)) }; + } + } + return { BedrockRuntimeClient, InvokeModelCommand }; +}); + +import { BedrockEmbeddingProvider } from "../src/providers/embedding/bedrock.js"; +import { detectEmbeddingProvider } from "../src/config.js"; + +const ENV_KEYS = [ + "AWS_REGION", + "AWS_BEDROCK_EMBEDDING_MODEL", + "AWS_BEDROCK_EMBEDDING_DIMENSIONS", + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "EMBEDDING_PROVIDER", + "AWS_BEDROCK", + "OPENAI_API_KEY", +] as const; + +describe("BedrockEmbeddingProvider", () => { + const saved: Record = {}; + + beforeEach(() => { + sentBodies.length = 0; + for (const k of ENV_KEYS) { + saved[k] = process.env[k]; + delete process.env[k]; + } + process.env["AWS_REGION"] = "us-east-2"; + // Default canned response: float vectors of the 
right length, one per text. + cannedResponse = (body) => { + const dim = (body.output_dimension as number) ?? 1024; + const texts = (body.texts as string[]) ?? [body.inputText as string]; + return { embeddings: { float: texts.map(() => new Array(dim).fill(0.1)) } }; + }; + }); + + afterEach(() => { + for (const k of ENV_KEYS) { + if (saved[k] === undefined) delete process.env[k]; + else process.env[k] = saved[k]; + } + }); + + it("defaults to cohere.embed-v4:0 at 1024 dimensions", () => { + const p = new BedrockEmbeddingProvider(); + expect(p.name).toBe("bedrock"); + expect(p.dimensions).toBe(1024); + }); + + // Note: the AWS_REGION-required guard is not unit-tested here because + // getEnvVar merges the real ~/.agentmemory/.env (which may set AWS_REGION), + // so the absence can't be reliably simulated through the merged-env path. + + it("honors AWS_BEDROCK_EMBEDDING_DIMENSIONS override", () => { + process.env["AWS_BEDROCK_EMBEDDING_DIMENSIONS"] = "512"; + const p = new BedrockEmbeddingProvider(); + expect(p.dimensions).toBe(512); + }); + + it("throws for an unknown model with no dimensions override", () => { + process.env["AWS_BEDROCK_EMBEDDING_MODEL"] = "cohere.embed-future-v9:0"; + expect(() => new BedrockEmbeddingProvider()).toThrow(/AWS_BEDROCK_EMBEDDING_DIMENSIONS/); + }); + + it("rejects a non-cohere/non-titan model family", () => { + process.env["AWS_BEDROCK_EMBEDDING_MODEL"] = "meta.llama-embed"; + process.env["AWS_BEDROCK_EMBEDDING_DIMENSIONS"] = "1024"; + expect(() => new BedrockEmbeddingProvider()).toThrow(/cohere\.|titan/); + }); + + it("accepts a us.-prefixed cross-region inference profile ID (family + dims resolve)", () => { + // cohere.embed-v4:0 is INFERENCE_PROFILE-only in some regions, so users set + // us.cohere.embed-v4:0 — family detection and known-dims must see through the + // geo prefix rather than demanding a dimensions override or throwing. 
+ process.env["AWS_BEDROCK_EMBEDDING_MODEL"] = "us.cohere.embed-v4:0"; + const p = new BedrockEmbeddingProvider(); + expect(p.dimensions).toBe(1024); + }); + + it("uses the Cohere body shape for a global.-prefixed profile ID", async () => { + process.env["AWS_BEDROCK_EMBEDDING_MODEL"] = "global.cohere.embed-v4:0"; + const p = new BedrockEmbeddingProvider(); + await p.embedBatch(["x"]); + expect(sentBodies[0]).toMatchObject({ + input_type: "search_document", + embedding_types: ["float"], + output_dimension: 1024, + }); + }); + + it("uses the Cohere body shape and reads embeddings.float (v4)", async () => { + const p = new BedrockEmbeddingProvider(); + const vecs = await p.embedBatch(["hello", "world"]); + expect(vecs).toHaveLength(2); + expect(vecs[0]).toBeInstanceOf(Float32Array); + expect(vecs[0].length).toBe(1024); + // v4 request: input_type required, float type, explicit output_dimension. + expect(sentBodies[0]).toMatchObject({ + input_type: "search_document", + texts: ["hello", "world"], + embedding_types: ["float"], + output_dimension: 1024, + }); + }); + + it("throws when the response returns fewer vectors than inputs (no silent misalignment)", async () => { + // Two inputs, but the model returns one vector — must fail fast rather than + // misalign texts to vectors downstream. + cannedResponse = () => ({ embeddings: { float: [new Array(1024).fill(0.1)] } }); + const p = new BedrockEmbeddingProvider(); + await expect(p.embedBatch(["one", "two"])).rejects.toThrow(/1 vectors for 2 inputs|misalign/); + }); + + it("parses the bare-array response shape for Cohere v3", async () => { + process.env["AWS_BEDROCK_EMBEDDING_MODEL"] = "cohere.embed-english-v3"; + cannedResponse = (body) => { + const texts = (body.texts as string[]) ?? 
[]; + return { embeddings: texts.map(() => new Array(1024).fill(0.2)) }; + }; + const p = new BedrockEmbeddingProvider(); + const vecs = await p.embedBatch(["a"]); + expect(vecs[0].length).toBe(1024); + // v3 does not send output_dimension. + expect(sentBodies[0].output_dimension).toBeUndefined(); + }); + + it("uses the Titan body shape (inputText) and fans out one call per text", async () => { + process.env["AWS_BEDROCK_EMBEDDING_MODEL"] = "amazon.titan-embed-text-v2:0"; + cannedResponse = (body) => ({ + embedding: new Array((body.dimensions as number) ?? 1024).fill(0.3), + }); + const p = new BedrockEmbeddingProvider(); + const vecs = await p.embedBatch(["one", "two", "three"]); + expect(vecs).toHaveLength(3); + expect(vecs[0].length).toBe(1024); + expect(sentBodies).toHaveLength(3); // one InvokeModel call per input + expect(sentBodies[0]).toMatchObject({ + inputText: expect.any(String), + dimensions: 1024, + normalize: true, + }); + }); + + it("passes explicit static creds only when both halves are set", () => { + process.env["AWS_ACCESS_KEY_ID"] = "AKIA"; + process.env["AWS_SECRET_ACCESS_KEY"] = "secret"; + const p = new BedrockEmbeddingProvider(); + const cfg = (p as unknown as { client: { config: { credentials?: unknown } } }) + .client.config; + expect(cfg.credentials).toMatchObject({ accessKeyId: "AKIA", secretAccessKey: "secret" }); + }); +}); + +describe("detectEmbeddingProvider — bedrock", () => { + it("selects bedrock when EMBEDDING_PROVIDER=bedrock", () => { + expect(detectEmbeddingProvider({ EMBEDDING_PROVIDER: "bedrock" })).toBe("bedrock"); + }); + + it("does NOT auto-select bedrock from AWS_BEDROCK=true (local-embeddings stays)", () => { + // AWS_BEDROCK opts into the LLM provider only; embeddings need an explicit + // EMBEDDING_PROVIDER. With no embedding key set, detection returns null + // (caller falls back to local). 
+ expect(detectEmbeddingProvider({ AWS_BEDROCK: "true" })).toBeNull(); + }); +}); diff --git a/test/bedrock-provider.test.ts b/test/bedrock-provider.test.ts new file mode 100644 index 000000000..437ebc8e7 --- /dev/null +++ b/test/bedrock-provider.test.ts @@ -0,0 +1,119 @@ +import { describe, expect, it, afterEach, beforeEach } from "vitest"; +import { BedrockProvider } from "../src/providers/bedrock.js"; +import { detectProvider } from "../src/config.js"; + +// Env keys this suite mutates — saved/restored so tests don't leak into each +// other or pick up the developer's real ~/.agentmemory/.env values. +const ENV_KEYS = [ + "AWS_BEDROCK", + "AWS_REGION", + "AWS_BEDROCK_MODEL", + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_SESSION_TOKEN", + "OPENAI_API_KEY", + "OPENAI_API_KEY_FOR_LLM", +] as const; + +describe("BedrockProvider", () => { + const saved: Record = {}; + + beforeEach(() => { + for (const k of ENV_KEYS) { + saved[k] = process.env[k]; + delete process.env[k]; + } + }); + + afterEach(() => { + for (const k of ENV_KEYS) { + if (saved[k] === undefined) delete process.env[k]; + else process.env[k] = saved[k]; + } + }); + + it("constructs with only a region (no explicit keys) — relies on the credential chain", () => { + expect( + () => new BedrockProvider("anthropic.claude-haiku-4-5-20251001-v1:0", 800, "us-east-1"), + ).not.toThrow(); + }); + + it("constructs with explicit static keys when present", () => { + process.env["AWS_ACCESS_KEY_ID"] = "AKIAEXAMPLE"; + process.env["AWS_SECRET_ACCESS_KEY"] = "secret"; + const provider = new BedrockProvider("model-id", 800, "eu-west-1"); + const client = (provider as unknown as { client: { awsAccessKey: string | null } }).client; + expect(client.awsAccessKey).toBe("AKIAEXAMPLE"); + }); + + it("ignores a lone access key (omits both, falls back to the credential chain)", () => { + // Only one of the pair set — must NOT pass it through (the SDK deprecates + // partial static creds); the provider chain handles it 
instead. + process.env["AWS_ACCESS_KEY_ID"] = "AKIAEXAMPLE"; + const provider = new BedrockProvider("model-id", 800, "us-east-1"); + const client = (provider as unknown as { client: { awsAccessKey: string | null } }).client; + expect(client.awsAccessKey).toBeNull(); + }); + + it("threads the region through to the client", () => { + const provider = new BedrockProvider("model-id", 800, "ap-southeast-2"); + const client = (provider as unknown as { client: { awsRegion: string } }).client; + expect(client.awsRegion).toBe("ap-southeast-2"); + }); +}); + +describe("detectProvider — bedrock branch", () => { + // Tests the pure detection function with explicit env maps, so they are + // independent of the developer's real ~/.agentmemory/.env. + it("selects bedrock when AWS_BEDROCK=true and AWS_REGION is set", () => { + const config = detectProvider({ AWS_BEDROCK: "true", AWS_REGION: "us-east-1" }); + expect(config.provider).toBe("bedrock"); + }); + + it("defaults the model to Claude Haiku 4.5 when AWS_BEDROCK_MODEL is unset", () => { + const config = detectProvider({ AWS_BEDROCK: "true", AWS_REGION: "us-east-1" }); + expect(config.model).toBe("anthropic.claude-haiku-4-5-20251001-v1:0"); + }); + + it("honors an explicit AWS_BEDROCK_MODEL (e.g. 
a us.-prefixed inference profile)", () => { + const config = detectProvider({ + AWS_BEDROCK: "true", + AWS_REGION: "us-east-1", + AWS_BEDROCK_MODEL: "us.anthropic.claude-haiku-4-5-20251001-v1:0", + }); + expect(config.model).toBe("us.anthropic.claude-haiku-4-5-20251001-v1:0"); + }); + + it("does NOT select bedrock when AWS_BEDROCK is unset, even with an OpenAI key (regression guard)", () => { + const config = detectProvider({ OPENAI_API_KEY: "sk-test" }); + expect(config.provider).toBe("openai"); + }); + + it("does NOT select bedrock when AWS_BEDROCK has any value other than the literal 'true'", () => { + const config = detectProvider({ + AWS_BEDROCK: "1", + AWS_REGION: "us-east-1", + OPENAI_API_KEY: "sk-test", + }); + expect(config.provider).not.toBe("bedrock"); + }); + + it("selects bedrock for a case-insensitive AWS_BEDROCK (True/TRUE/ true )", () => { + for (const flag of ["True", "TRUE", " true "]) { + const config = detectProvider({ AWS_BEDROCK: flag, AWS_REGION: "us-east-1" }); + expect(config.provider).toBe("bedrock"); + } + }); + + it("rejects bedrock when AWS_REGION is unset and falls through to the next provider", () => { + // AWS_BEDROCK=true but no region → Bedrock can't be constructed, so detection + // must not return an unusable bedrock config; it falls through to OpenAI here. 
+ const config = detectProvider({ AWS_BEDROCK: "true", OPENAI_API_KEY: "sk-test" }); + expect(config.provider).toBe("openai"); + }); + + it("falls through to noop when AWS_BEDROCK=true but no region and no other provider", () => { + const config = detectProvider({ AWS_BEDROCK: "true" }); + expect(config.provider).toBe("noop"); + }); +}); diff --git a/test/compress-file.test.ts b/test/compress-file.test.ts index 9b6820b3e..5efdae444 100644 --- a/test/compress-file.test.ts +++ b/test/compress-file.test.ts @@ -193,6 +193,27 @@ describe("mem::compress-file", () => { expect(fileStore.get("/tmp/guide.original.md")).toBeUndefined(); }); + it("surfaces the provider error message instead of letting it escape (Bedrock hint)", async () => { + const path = "/tmp/notes.md"; + fileStore.set(path, "# Title\n\nLong original body."); + summarize.mockRejectedValue( + new Error( + 'Bedrock model "anthropic.claude-haiku-4-5-20251001-v1:0" could not be invoked: ' + + "set AWS_BEDROCK_MODEL to the us./eu.-prefixed cross-region inference profile ID.", + ), + ); + + const result = (await sdk.trigger("mem::compress-file", { + filePath: path, + })) as { success: boolean; error: string }; + + expect(result.success).toBe(false); + expect(result.error).toContain("cross-region inference profile"); + // The original file is untouched and no backup is written on provider failure. + expect(fileStore.get(path)).toBe("# Title\n\nLong original body."); + expect(fileStore.get("/tmp/notes.original.md")).toBeUndefined(); + }); + it("uses a distinct backup path for *.original.md inputs", async () => { const path = "/tmp/notes.original.md"; fileStore.set(path, "# Title\n\nLong original body.");