diff --git a/.changeset/hot-trees-sing.md b/.changeset/hot-trees-sing.md new file mode 100644 index 000000000..74df1dd00 --- /dev/null +++ b/.changeset/hot-trees-sing.md @@ -0,0 +1,12 @@ +--- +'@modelcontextprotocol/express': patch +'@modelcontextprotocol/hono': patch +'@modelcontextprotocol/node': patch +'@modelcontextprotocol/eslint-config': patch +'@modelcontextprotocol/test-integration': patch +'@modelcontextprotocol/client': patch +'@modelcontextprotocol/server': patch +'@modelcontextprotocol/core': patch +--- + +add context API to tool, prompt, resource callbacks, linting diff --git a/CLAUDE.md b/CLAUDE.md index 16c7058a5..b2fc20148 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -29,8 +29,7 @@ When making breaking changes, document them in **both**: - `docs/migration.md` — human-readable guide with before/after code examples - `docs/migration-SKILL.md` — LLM-optimized mapping tables for mechanical migration -Include what changed, why, and how to migrate. Search for related sections and group -related changes together rather than adding new standalone sections. +Include what changed, why, and how to migrate. Search for related sections and group related changes together rather than adding new standalone sections. ## Code Style Guidelines @@ -146,7 +145,7 @@ When a request arrives from the remote side: 2. **`Protocol.connect()`** routes to `_onrequest()`, `_onresponse()`, or `_onnotification()` 3. **`Protocol._onrequest()`**: - Looks up handler in `_requestHandlers` map (keyed by method name) - - Creates `RequestHandlerExtra` with `signal`, `sessionId`, `sendNotification`, `sendRequest` + - Creates a context object (`ServerContext` or `ClientContext`) via `createRequestContext()` - Invokes handler, sends JSON-RPC response back via transport 4. **Handler** was registered via `setRequestHandler('method', handler)` @@ -154,29 +153,55 @@ When a request arrives from the remote side: ```typescript // In Client (for server→client requests like sampling, elicitation) -client.setRequestHandler('sampling/createMessage', async (request, extra) => { +client.setRequestHandler('sampling/createMessage', async (request, ctx) => { // Handle sampling request from server return { role: "assistant", content: {...}, model: "..." }; }); // In Server (for client→server requests like tools/call) -server.setRequestHandler('tools/call', async (request, extra) => { +server.setRequestHandler('tools/call', async (request, ctx) => { // Handle tool call from client return { content: [...] }; }); ``` -### Request Handler Extra - -The `extra` parameter in handlers (`RequestHandlerExtra`) provides: - -- `signal`: AbortSignal for cancellation -- `sessionId`: Transport session identifier -- `authInfo`: Validated auth token info (if authenticated) -- `requestId`: JSON-RPC message ID -- `sendNotification(notification)`: Send related notification back -- `sendRequest(request, schema)`: Send related request (for bidirectional flows) -- `taskStore`: Task storage interface (if tasks enabled) +### Request Handler Context + +The `ctx` parameter in handlers provides a structured context with grouped fields: + +**Common structure (both Client and Server)**: + +- `ctx.sessionId`: Transport session identifier (top-level) +- `ctx.mcpReq`: MCP protocol context + - `id`: JSON-RPC message ID + - `method`: The method being called + - `_meta`: Request metadata + - `signal`: AbortSignal for cancellation + - `send(request, schema, options?)`: Send request (for bidirectional flows) +- `ctx.http`: HTTP request context (optional, present for HTTP transports) + - `authInfo`: Validated auth token info (if authenticated) +- `ctx.task`: Task context (when tasks are enabled) + - `id`: Current task ID (updates after `store.createTask()`) + - `store`: Request-scoped task store (`RequestTaskStore`) + - `requestedTtl`: Requested TTL for the task +- `ctx.notification`: Notification context + - `send(notification)`: Send notification back + +**Server-specific additions**: + +- `ctx.http`: Extended with additional fields + - `req`: Raw fetch Request object (access to URL, headers, etc.) + - `closeSSE?()`: Close SSE stream for polling + - `closeStandaloneSSE?()`: Close standalone SSE stream +- `ctx.mcpReq`: Extended with server-to-client request methods + - `requestSampling(params, options?)`: Request sampling from client + - `elicitInput(params, options?)`: Request user input from client +- `ctx.notification`: Extended with logging methods + - `log(params)`: Send logging notification + - `debug(message, extraLogData?)`: Send debug log + - `info(message, extraLogData?)`: Send info log + - `warning(message, extraLogData?)`: Send warning log + - `error(message, extraLogData?)`: Send error log ### Capability Checking @@ -207,7 +232,7 @@ const result = await server.createMessage({ }); // Client must have registered handler: -client.setRequestHandler('sampling/createMessage', async (request, extra) => { +client.setRequestHandler('sampling/createMessage', async (request, ctx) => { // Client-side LLM call return { role: "assistant", content: {...} }; }); @@ -218,8 +243,8 @@ client.setRequestHandler('sampling/createMessage', async (request, extra) => { ### Request Handler Registration (Low-Level Server) ```typescript -server.setRequestHandler('tools/call', async (request, extra) => { - // extra contains sessionId, authInfo, sendNotification, etc. +server.setRequestHandler('tools/call', async (request, ctx) => { + // ctx provides mcpReq, http, task, notification return { /* result */ }; @@ -229,7 +254,7 @@ server.setRequestHandler('tools/call', async (request, extra) => { ### Tool Registration (High-Level McpServer) ```typescript -mcpServer.tool('tool-name', { param: z.string() }, async ({ param }, extra) => { +mcpServer.tool('tool-name', { param: z.string() }, async ({ param }, ctx) => { return { content: [{ type: 'text', text: 'result' }] }; }); ``` diff --git a/docs/migration-SKILL.md b/docs/migration-SKILL.md index 53573a7bc..57e07c69f 100644 --- a/docs/migration-SKILL.md +++ b/docs/migration-SKILL.md @@ -20,13 +20,13 @@ Remove the old package and install only what you need: npm uninstall @modelcontextprotocol/sdk ``` -| You need | Install | -|----------|---------| -| Client only | `npm install @modelcontextprotocol/client` | -| Server only | `npm install @modelcontextprotocol/server` | -| Server + Node.js HTTP | `npm install @modelcontextprotocol/server @modelcontextprotocol/node` | -| Server + Express | `npm install @modelcontextprotocol/server @modelcontextprotocol/express` | -| Server + Hono | `npm install @modelcontextprotocol/server @modelcontextprotocol/hono` | +| You need | Install | +| --------------------- | ------------------------------------------------------------------------ | +| Client only | `npm install @modelcontextprotocol/client` | +| Server only | `npm install @modelcontextprotocol/server` | +| Server + Node.js HTTP | `npm install @modelcontextprotocol/server @modelcontextprotocol/node` | +| Server + Express | `npm install @modelcontextprotocol/server @modelcontextprotocol/express` | +| Server + Hono | `npm install @modelcontextprotocol/server @modelcontextprotocol/hono` | `@modelcontextprotocol/core` is installed automatically as a dependency. @@ -36,65 +36,67 @@ Replace all `@modelcontextprotocol/sdk/...` imports using this table. ### Client imports -| v1 import path | v2 package | -|----------------|------------| -| `@modelcontextprotocol/sdk/client/index.js` | `@modelcontextprotocol/client` | -| `@modelcontextprotocol/sdk/client/auth.js` | `@modelcontextprotocol/client` | +| v1 import path | v2 package | +| ---------------------------------------------------- | ------------------------------ | +| `@modelcontextprotocol/sdk/client/index.js` | `@modelcontextprotocol/client` | +| `@modelcontextprotocol/sdk/client/auth.js` | `@modelcontextprotocol/client` | | `@modelcontextprotocol/sdk/client/streamableHttp.js` | `@modelcontextprotocol/client` | -| `@modelcontextprotocol/sdk/client/sse.js` | `@modelcontextprotocol/client` | -| `@modelcontextprotocol/sdk/client/stdio.js` | `@modelcontextprotocol/client` | -| `@modelcontextprotocol/sdk/client/websocket.js` | `@modelcontextprotocol/client` | +| `@modelcontextprotocol/sdk/client/sse.js` | `@modelcontextprotocol/client` | +| `@modelcontextprotocol/sdk/client/stdio.js` | `@modelcontextprotocol/client` | +| `@modelcontextprotocol/sdk/client/websocket.js` | `@modelcontextprotocol/client` | ### Server imports -| v1 import path | v2 package | -|----------------|------------| -| `@modelcontextprotocol/sdk/server/mcp.js` | `@modelcontextprotocol/server` | -| `@modelcontextprotocol/sdk/server/index.js` | `@modelcontextprotocol/server` | -| `@modelcontextprotocol/sdk/server/stdio.js` | `@modelcontextprotocol/server` | +| v1 import path | v2 package | +| ---------------------------------------------------- | ----------------------------------------------------------------------------------- | +| `@modelcontextprotocol/sdk/server/mcp.js` | `@modelcontextprotocol/server` | +| `@modelcontextprotocol/sdk/server/index.js` | `@modelcontextprotocol/server` | +| `@modelcontextprotocol/sdk/server/stdio.js` | `@modelcontextprotocol/server` | | `@modelcontextprotocol/sdk/server/streamableHttp.js` | `@modelcontextprotocol/node` (class renamed to `NodeStreamableHTTPServerTransport`) | -| `@modelcontextprotocol/sdk/server/sse.js` | REMOVED (migrate to Streamable HTTP) | -| `@modelcontextprotocol/sdk/server/auth/*` | REMOVED (use external auth library) | -| `@modelcontextprotocol/sdk/server/middleware.js` | `@modelcontextprotocol/express` (signature changed, see section 8) | +| `@modelcontextprotocol/sdk/server/sse.js` | REMOVED (migrate to Streamable HTTP) | +| `@modelcontextprotocol/sdk/server/auth/*` | REMOVED (use external auth library) | +| `@modelcontextprotocol/sdk/server/middleware.js` | `@modelcontextprotocol/express` (signature changed, see section 8) | ### Types / shared imports -| v1 import path | v2 package | -|----------------|------------| -| `@modelcontextprotocol/sdk/types.js` | `@modelcontextprotocol/core` | -| `@modelcontextprotocol/sdk/shared/protocol.js` | `@modelcontextprotocol/core` | -| `@modelcontextprotocol/sdk/shared/transport.js` | `@modelcontextprotocol/core` | -| `@modelcontextprotocol/sdk/shared/stdio.js` | `@modelcontextprotocol/core` | +| v1 import path | v2 package | +| ------------------------------------------------- | ---------------------------- | +| `@modelcontextprotocol/sdk/types.js` | `@modelcontextprotocol/core` | +| `@modelcontextprotocol/sdk/shared/protocol.js` | `@modelcontextprotocol/core` | +| `@modelcontextprotocol/sdk/shared/transport.js` | `@modelcontextprotocol/core` | +| `@modelcontextprotocol/sdk/shared/stdio.js` | `@modelcontextprotocol/core` | | `@modelcontextprotocol/sdk/shared/uriTemplate.js` | `@modelcontextprotocol/core` | -| `@modelcontextprotocol/sdk/shared/auth.js` | `@modelcontextprotocol/core` | +| `@modelcontextprotocol/sdk/shared/auth.js` | `@modelcontextprotocol/core` | Notes: + - `@modelcontextprotocol/client` and `@modelcontextprotocol/server` both re-export everything from `@modelcontextprotocol/core`, so you can import types from whichever package you already depend on. - When multiple v1 imports map to the same v2 package, consolidate them into a single import statement. - If code imports from `sdk/client/...`, install `@modelcontextprotocol/client`. If from `sdk/server/...`, install `@modelcontextprotocol/server`. If from `sdk/types.js` or `sdk/shared/...` only, install `@modelcontextprotocol/core`. ## 4. Renamed Symbols -| v1 symbol | v2 symbol | v2 package | -|-----------|-----------|------------| +| v1 symbol | v2 symbol | v2 package | +| ------------------------------- | ----------------------------------- | ---------------------------- | | `StreamableHTTPServerTransport` | `NodeStreamableHTTPServerTransport` | `@modelcontextprotocol/node` | ## 5. Removed / Renamed Type Aliases and Symbols -| v1 (removed) | v2 (replacement) | -|--------------|------------------| -| `JSONRPCError` | `JSONRPCErrorResponse` | -| `JSONRPCErrorSchema` | `JSONRPCErrorResponseSchema` | -| `isJSONRPCError` | `isJSONRPCErrorResponse` | -| `isJSONRPCResponse` | `isJSONRPCResultResponse` | -| `ResourceReference` | `ResourceTemplateReference` | -| `ResourceReferenceSchema` | `ResourceTemplateReferenceSchema` | -| `IsomorphicHeaders` | REMOVED (use Web Standard `Headers`) | +| v1 (removed) | v2 (replacement) | +| ---------------------------------------- | ------------------------------------------------ | +| `JSONRPCError` | `JSONRPCErrorResponse` | +| `JSONRPCErrorSchema` | `JSONRPCErrorResponseSchema` | +| `isJSONRPCError` | `isJSONRPCErrorResponse` | +| `isJSONRPCResponse` | `isJSONRPCResultResponse` | +| `ResourceReference` | `ResourceTemplateReference` | +| `ResourceReferenceSchema` | `ResourceTemplateReferenceSchema` | +| `IsomorphicHeaders` | REMOVED (use Web Standard `Headers`) | | `AuthInfo` (from `server/auth/types.js`) | `AuthInfo` (now in `@modelcontextprotocol/core`) | All other symbols from `@modelcontextprotocol/sdk/types.js` retain their original names (e.g., `CallToolResultSchema`, `ListToolsResultSchema`, etc.). -**Unchanged APIs** (only import paths changed): `Client` constructor and methods, `McpServer` constructor, `server.connect()`, `server.close()`, all client transports (`StreamableHTTPClientTransport`, `SSEClientTransport`, `StdioClientTransport`), `StdioServerTransport`, all Zod schemas, all callback return types. +**Unchanged APIs** (only import paths changed): `Client` constructor and methods, `McpServer` constructor, `server.connect()`, `server.close()`, all client transports (`StreamableHTTPClientTransport`, `SSEClientTransport`, `StdioClientTransport`), `StdioServerTransport`, all Zod +schemas, all callback return types. ## 6. McpServer API Changes @@ -105,21 +107,25 @@ The variadic `.tool()`, `.prompt()`, `.resource()` methods are removed. Use the ```typescript // v1: server.tool(name, schema, callback) server.tool('greet', { name: z.string() }, async ({ name }) => { - return { content: [{ type: 'text', text: `Hello, ${name}!` }] }; + return { content: [{ type: 'text', text: `Hello, ${name}!` }] }; }); // v1: server.tool(name, description, schema, callback) server.tool('greet', 'Greet a user', { name: z.string() }, async ({ name }) => { - return { content: [{ type: 'text', text: `Hello, ${name}!` }] }; + return { content: [{ type: 'text', text: `Hello, ${name}!` }] }; }); // v2: server.registerTool(name, config, callback) -server.registerTool('greet', { - description: 'Greet a user', - inputSchema: { name: z.string() }, -}, async ({ name }) => { - return { content: [{ type: 'text', text: `Hello, ${name}!` }] }; -}); +server.registerTool( + 'greet', + { + description: 'Greet a user', + inputSchema: { name: z.string() } + }, + async ({ name }) => { + return { content: [{ type: 'text', text: `Hello, ${name}!` }] }; + } +); ``` Config object fields: `title?`, `description?`, `inputSchema?`, `outputSchema?`, `annotations?`, `_meta?` @@ -129,15 +135,19 @@ Config object fields: `title?`, `description?`, `inputSchema?`, `outputSchema?`, ```typescript // v1: server.prompt(name, schema, callback) server.prompt('summarize', { text: z.string() }, async ({ text }) => { - return { messages: [{ role: 'user', content: { type: 'text', text } }] }; + return { messages: [{ role: 'user', content: { type: 'text', text } }] }; }); // v2: server.registerPrompt(name, config, callback) -server.registerPrompt('summarize', { - argsSchema: { text: z.string() }, -}, async ({ text }) => { - return { messages: [{ role: 'user', content: { type: 'text', text } }] }; -}); +server.registerPrompt( + 'summarize', + { + argsSchema: { text: z.string() } + }, + async ({ text }) => { + return { messages: [{ role: 'user', content: { type: 'text', text } }] }; + } +); ``` Config object fields: `title?`, `description?`, `argsSchema?` @@ -146,13 +156,13 @@ Config object fields: `title?`, `description?`, `argsSchema?` ```typescript // v1: server.resource(name, uri, callback) -server.resource('config', 'config://app', async (uri) => { - return { contents: [{ uri: uri.href, text: '{}' }] }; +server.resource('config', 'config://app', async uri => { + return { contents: [{ uri: uri.href, text: '{}' }] }; }); // v2: server.registerResource(name, uri, metadata, callback) -server.registerResource('config', 'config://app', {}, async (uri) => { - return { contents: [{ uri: uri.href, text: '{}' }] }; +server.registerResource('config', 'config://app', {}, async uri => { + return { contents: [{ uri: uri.href, text: '{}' }] }; }); ``` @@ -180,7 +190,8 @@ extra.requestInfo?.headers.get('mcp-session-id') ### Server-side auth -All server OAuth exports removed: `mcpAuthRouter`, `OAuthServerProvider`, `OAuthTokenVerifier`, `requireBearerAuth`, `authenticateClient`, `ProxyOAuthServerProvider`, `allowedMethods`, and associated types. Use an external auth library (e.g., `better-auth`). See `examples/server/src/` for demos. +All server OAuth exports removed: `mcpAuthRouter`, `OAuthServerProvider`, `OAuthTokenVerifier`, `requireBearerAuth`, `authenticateClient`, `ProxyOAuthServerProvider`, `allowedMethods`, and associated types. Use an external auth library (e.g., `better-auth`). See +`examples/server/src/` for demos. ### Host header validation (Express) @@ -214,26 +225,26 @@ server.setNotificationHandler('notifications/message', (notification) => { ... } Schema to method string mapping: -| v1 Schema | v2 Method String | -|-----------|-----------------| -| `InitializeRequestSchema` | `'initialize'` | -| `CallToolRequestSchema` | `'tools/call'` | -| `ListToolsRequestSchema` | `'tools/list'` | -| `ListPromptsRequestSchema` | `'prompts/list'` | -| `GetPromptRequestSchema` | `'prompts/get'` | -| `ListResourcesRequestSchema` | `'resources/list'` | -| `ReadResourceRequestSchema` | `'resources/read'` | -| `CreateMessageRequestSchema` | `'sampling/createMessage'` | -| `ElicitRequestSchema` | `'elicitation/create'` | -| `SetLevelRequestSchema` | `'logging/setLevel'` | -| `PingRequestSchema` | `'ping'` | -| `LoggingMessageNotificationSchema` | `'notifications/message'` | -| `ToolListChangedNotificationSchema` | `'notifications/tools/list_changed'` | +| v1 Schema | v2 Method String | +| --------------------------------------- | ---------------------------------------- | +| `InitializeRequestSchema` | `'initialize'` | +| `CallToolRequestSchema` | `'tools/call'` | +| `ListToolsRequestSchema` | `'tools/list'` | +| `ListPromptsRequestSchema` | `'prompts/list'` | +| `GetPromptRequestSchema` | `'prompts/get'` | +| `ListResourcesRequestSchema` | `'resources/list'` | +| `ReadResourceRequestSchema` | `'resources/read'` | +| `CreateMessageRequestSchema` | `'sampling/createMessage'` | +| `ElicitRequestSchema` | `'elicitation/create'` | +| `SetLevelRequestSchema` | `'logging/setLevel'` | +| `PingRequestSchema` | `'ping'` | +| `LoggingMessageNotificationSchema` | `'notifications/message'` | +| `ToolListChangedNotificationSchema` | `'notifications/tools/list_changed'` | | `ResourceListChangedNotificationSchema` | `'notifications/resources/list_changed'` | -| `PromptListChangedNotificationSchema` | `'notifications/prompts/list_changed'` | -| `ProgressNotificationSchema` | `'notifications/progress'` | -| `CancelledNotificationSchema` | `'notifications/cancelled'` | -| `InitializedNotificationSchema` | `'notifications/initialized'` | +| `PromptListChangedNotificationSchema` | `'notifications/prompts/list_changed'` | +| `ProgressNotificationSchema` | `'notifications/progress'` | +| `CancelledNotificationSchema` | `'notifications/cancelled'` | +| `InitializedNotificationSchema` | `'notifications/initialized'` | Request/notification params remain fully typed. Remove unused schema imports after migration. @@ -241,6 +252,87 @@ Request/notification params remain fully typed. Remove unused schema imports aft `Client.listPrompts()`, `listResources()`, `listResourceTemplates()`, `listTools()` now return empty results when the server lacks the corresponding capability (instead of sending the request). Set `enforceStrictCapabilities: true` in `ClientOptions` to throw an error instead. +## 10. Context API (replaces `RequestHandlerExtra`) + +Tool, prompt, and resource callbacks now receive a structured context object (`ctx`) instead of the flat `extra` parameter. + +### Property Mapping + +| v1 (`extra.`) | v2 (`ctx.`) | +| ------------------------------------ | ---------------------------------- | +| `extra.requestId` | `ctx.mcpReq.id` | +| `extra.sessionId` | `ctx.sessionId` | +| `extra._meta` | `ctx.mcpReq._meta` | +| `extra.signal` | `ctx.mcpReq.signal` | +| `extra.authInfo` | `ctx.http?.authInfo` | +| `extra.requestInfo?.headers` | `ctx.http?.req.headers` | +| `extra.sendNotification(n)` | `ctx.notification.send(n)` | +| `extra.sendRequest(r, s, o)` | `ctx.mcpReq.send(r, s, o)` | +| `extra.taskId` | `ctx.task?.id` | +| `extra.taskStore` | `ctx.task?.store` | +| `extra.taskRequestedTtl` | `ctx.task?.requestedTtl` | +| `extra.closeSSEStream?.()` | `ctx.http?.closeSSE?.()` | +| `extra.closeStandaloneSSEStream?.()` | `ctx.http?.closeStandaloneSSE?.()` | + +### Server-specific context methods + +| Method | Description | +| ------------------------------------------- | ---------------------------------- | +| `ctx.notification.log(params)` | Send logging message | +| `ctx.notification.debug(msg, data?)` | Send debug log | +| `ctx.notification.info(msg, data?)` | Send info log | +| `ctx.notification.warning(msg, data?)` | Send warning log | +| `ctx.notification.error(msg, data?)` | Send error log | +| `ctx.mcpReq.elicitInput(params, opts?)` | Request user input via elicitation | +| `ctx.mcpReq.requestSampling(params, opts?)` | Request LLM sampling from client | +| `ctx.http?.req` | Raw fetch `Request` object | + +### Context structure overview + +```typescript +// Common structure (client and server) +ctx.sessionId // top-level +ctx.mcpReq.id // request ID +ctx.mcpReq.method // request method +ctx.mcpReq._meta // request metadata +ctx.mcpReq.signal // abort signal +ctx.mcpReq.send(req, schema) // send related request +ctx.http?.authInfo // auth info (HTTP transports) +ctx.task?.id // task ID (if task-augmented) +ctx.task?.store // task store +ctx.task?.requestedTtl // task TTL +ctx.notification.send(n) // send notification + +// Server additions +ctx.http?.req // raw fetch Request object +ctx.http?.closeSSE?.() // close SSE stream +ctx.http?.closeStandaloneSSE?.() // close standalone SSE +ctx.notification.log(params) // logging +ctx.notification.debug/info/warning/error(msg, data?) +ctx.mcpReq.elicitInput(params, opts?) +ctx.mcpReq.requestSampling(params, opts?) +``` + +### Example migration + +```typescript +// v1 +server.registerTool('my-tool', { inputSchema: { q: z.string() } }, async ({ q }, extra) => { + if (extra.signal.aborted) throw new Error('Cancelled'); + await extra.sendNotification({ method: 'n', params: {} }); + if (extra.taskStore) await extra.taskStore.updateTaskStatus(extra.taskId!, 'running'); + return { content: [{ type: 'text', text: 'Done' }] }; +}); + +// v2 +server.registerTool('my-tool', { inputSchema: { q: z.string() } }, async ({ q }, ctx) => { + if (ctx.mcpReq.signal.aborted) throw new Error('Cancelled'); + await ctx.notification.send({ method: 'n', params: {} }); + if (ctx.task) await ctx.task.store.updateTaskStatus(ctx.task.id, 'running'); + return { content: [{ type: 'text', text: 'Done' }] }; +}); +``` + ## 11. Migration Steps (apply in this order) 1. Update `package.json`: `npm uninstall @modelcontextprotocol/sdk`, install the appropriate v2 packages @@ -252,4 +344,5 @@ Request/notification params remain fully typed. Remove unused schema imports aft 7. If using server SSE transport, migrate to Streamable HTTP 8. If using server auth from the SDK, migrate to an external auth library 9. If relying on `listTools()`/`listPrompts()`/etc. throwing on missing capabilities, set `enforceStrictCapabilities: true` -10. Verify: build with `tsc` / run tests +10. Update tool/prompt/resource callbacks: rename `extra` parameter to `ctx` and update property access per section 10 +11. Verify: build with `tsc` / run tests diff --git a/docs/migration.md b/docs/migration.md index 3ea4a3085..c6eb39295 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -4,7 +4,8 @@ This guide covers the breaking changes introduced in v2 of the MCP TypeScript SD ## Overview -Version 2 of the MCP TypeScript SDK introduces several breaking changes to improve modularity, reduce dependency bloat, and provide a cleaner API surface. The biggest change is the split from a single `@modelcontextprotocol/sdk` package into separate `@modelcontextprotocol/core`, `@modelcontextprotocol/client`, and `@modelcontextprotocol/server` packages. +Version 2 of the MCP TypeScript SDK introduces several breaking changes to improve modularity, reduce dependency bloat, and provide a cleaner API surface. The biggest change is the split from a single `@modelcontextprotocol/sdk` package into separate `@modelcontextprotocol/core`, +`@modelcontextprotocol/client`, and `@modelcontextprotocol/server` packages. ## Breaking Changes @@ -12,11 +13,11 @@ Version 2 of the MCP TypeScript SDK introduces several breaking changes to impro The single `@modelcontextprotocol/sdk` package has been split into three packages: -| v1 | v2 | -|----|-----| +| v1 | v2 | +| --------------------------- | ---------------------------------------------------------- | | `@modelcontextprotocol/sdk` | `@modelcontextprotocol/core` (types, protocol, transports) | -| | `@modelcontextprotocol/client` (client implementation) | -| | `@modelcontextprotocol/server` (server implementation) | +| | `@modelcontextprotocol/client` (client implementation) | +| | `@modelcontextprotocol/server` (server implementation) | Remove the old package and install only the packages you need: @@ -64,6 +65,7 @@ Note: `@modelcontextprotocol/client` and `@modelcontextprotocol/server` both re- v2 requires **Node.js 20+** and ships **ESM only** (no more CommonJS builds). If your project uses CommonJS (`require()`), you will need to either: + - Migrate to ESM (`import`/`export`) - Use dynamic `import()` to load the SDK @@ -71,11 +73,11 @@ If your project uses CommonJS (`require()`), you will need to either: The server package no longer depends on Express or Hono. HTTP framework integrations are now separate middleware packages: -| v1 | v2 | -|----|-----| +| v1 | v2 | +| -------------------------------------- | ------------------------------------------- | | Built into `@modelcontextprotocol/sdk` | `@modelcontextprotocol/node` (Node.js HTTP) | -| | `@modelcontextprotocol/express` (Express) | -| | `@modelcontextprotocol/hono` (Hono) | +| | `@modelcontextprotocol/express` (Express) | +| | `@modelcontextprotocol/hono` (Hono) | Install the middleware package for your framework: @@ -128,12 +130,12 @@ This affects both transport constructors and request handler code that reads hea ```typescript // Transport headers const transport = new StreamableHTTPClientTransport(url, { - requestInit: { - headers: { - 'Authorization': 'Bearer token', - 'X-Custom': 'value', - }, - }, + requestInit: { + headers: { + Authorization: 'Bearer token', + 'X-Custom': 'value' + } + } }); // Reading headers in a request handler @@ -145,12 +147,12 @@ const sessionId = extra.requestInfo?.headers['mcp-session-id']; ```typescript // Transport headers const transport = new StreamableHTTPClientTransport(url, { - requestInit: { - headers: new Headers({ - 'Authorization': 'Bearer token', - 'X-Custom': 'value', - }), - }, + requestInit: { + headers: new Headers({ + Authorization: 'Bearer token', + 'X-Custom': 'value' + }) + } }); // Reading headers in a request handler @@ -170,22 +172,22 @@ const server = new McpServer({ name: 'demo', version: '1.0.0' }); // Tool with schema server.tool('greet', { name: z.string() }, async ({ name }) => { - return { content: [{ type: 'text', text: `Hello, ${name}!` }] }; + return { content: [{ type: 'text', text: `Hello, ${name}!` }] }; }); // Tool with description server.tool('greet', 'Greet a user', { name: z.string() }, async ({ name }) => { - return { content: [{ type: 'text', text: `Hello, ${name}!` }] }; + return { content: [{ type: 'text', text: `Hello, ${name}!` }] }; }); // Prompt server.prompt('summarize', { text: z.string() }, async ({ text }) => { - return { messages: [{ role: 'user', content: { type: 'text', text: `Summarize: ${text}` } }] }; + return { messages: [{ role: 'user', content: { type: 'text', text: `Summarize: ${text}` } }] }; }); // Resource -server.resource('config', 'config://app', async (uri) => { - return { contents: [{ uri: uri.href, text: '{}' }] }; +server.resource('config', 'config://app', async uri => { + return { contents: [{ uri: uri.href, text: '{}' }] }; }); ``` @@ -198,28 +200,29 @@ const server = new McpServer({ name: 'demo', version: '1.0.0' }); // Tool with schema server.registerTool('greet', { inputSchema: { name: z.string() } }, async ({ name }) => { - return { content: [{ type: 'text', text: `Hello, ${name}!` }] }; + return { content: [{ type: 'text', text: `Hello, ${name}!` }] }; }); // Tool with description server.registerTool('greet', { description: 'Greet a user', inputSchema: { name: z.string() } }, async ({ name }) => { - return { content: [{ type: 'text', text: `Hello, ${name}!` }] }; + return { content: [{ type: 'text', text: `Hello, ${name}!` }] }; }); // Prompt server.registerPrompt('summarize', { argsSchema: { text: z.string() } }, async ({ text }) => { - return { messages: [{ role: 'user', content: { type: 'text', text: `Summarize: ${text}` } }] }; + return { messages: [{ role: 'user', content: { type: 'text', text: `Summarize: ${text}` } }] }; }); // Resource -server.registerResource('config', 'config://app', {}, async (uri) => { - return { contents: [{ uri: uri.href, text: '{}' }] }; +server.registerResource('config', 'config://app', {}, async uri => { + return { contents: [{ uri: uri.href, text: '{}' }] }; }); ``` ### Host header validation moved -Express-specific middleware (`hostHeaderValidation()`, `localhostHostValidation()`) moved from the server package to `@modelcontextprotocol/express`. The server package now exports framework-agnostic functions instead: `validateHostHeader()`, `localhostAllowedHostnames()`, `hostHeaderValidationResponse()`. +Express-specific middleware (`hostHeaderValidation()`, `localhostHostValidation()`) moved from the server package to `@modelcontextprotocol/express`. The server package now exports framework-agnostic functions instead: `validateHostHeader()`, `localhostAllowedHostnames()`, +`hostHeaderValidationResponse()`. **Before (v1):** @@ -297,6 +300,103 @@ Common method string replacements: | `ResourceListChangedNotificationSchema` | `'notifications/resources/list_changed'` | | `PromptListChangedNotificationSchema` | `'notifications/prompts/list_changed'` | +### Context API replaces `RequestHandlerExtra` + +The `extra` parameter in tool, prompt, and resource callbacks has been replaced with a structured context object (`ctx`). The old flat `RequestHandlerExtra` type is replaced by `ServerContext` (for server callbacks) and `ClientContext` (for client callbacks). + +**Before (v1):** + +```typescript +server.registerTool('my-tool', { inputSchema: { query: z.string() } }, async ({ query }, extra) => { + // Flat properties + console.log(extra.requestId); + console.log(extra.sessionId); + console.log(extra.authInfo?.token); + + // Check cancellation + if (extra.signal.aborted) throw new Error('Cancelled'); + + // Send notification + await extra.sendNotification({ method: 'notifications/progress', params: { progress: 50 } }); + + // Task access + if (extra.taskStore) { + await extra.taskStore.updateTaskStatus(extra.taskId!, 'running'); + } + + // SSE stream control + extra.closeSSEStream?.(); + + return { content: [{ type: 'text', text: 'Done' }] }; +}); +``` + +**After (v2):** + +```typescript +server.registerTool('my-tool', { inputSchema: { query: z.string() } }, async ({ query }, ctx) => { + // Grouped into nested objects + console.log(ctx.mcpReq.id); // was extra.requestId + console.log(ctx.sessionId); // still at top level + console.log(ctx.http?.authInfo?.token); // was extra.authInfo + + // Check cancellation (signal moved to mcpReq) + if (ctx.mcpReq.signal.aborted) throw new Error('Cancelled'); + + // Send notification (via notification.send) + await ctx.notification.send({ method: 'notifications/progress', params: { progress: 50 } }); + + // Task access (grouped in task object) + if (ctx.task) { + await ctx.task.store.updateTaskStatus(ctx.task.id, 'running'); + } + + // SSE stream control (moved to http) + ctx.http?.closeSSE?.(); + + // Server-specific: logging methods on notification + await ctx.notification.info('Processing query'); + await ctx.notification.debug('Debug details', { query }); + + // Server-specific: elicitation and sampling on mcpReq + const userInput = await ctx.mcpReq.elicitInput({ message: 'Confirm?', mode: 'form', form: {} }); + const message = await ctx.mcpReq.requestSampling({ messages: [...], maxTokens: 100 }); + + return { content: [{ type: 'text', text: 'Done' }] }; +}); +``` + +#### Property mapping + +| v1 (`extra.`) | v2 (`ctx.`) | +| ------------------------------------ | ---------------------------------- | +| `extra.requestId` | `ctx.mcpReq.id` | +| `extra.sessionId` | `ctx.sessionId` | +| `extra._meta` | `ctx.mcpReq._meta` | +| `extra.signal` | `ctx.mcpReq.signal` | +| `extra.authInfo` | `ctx.http?.authInfo` | +| `extra.requestInfo?.headers` | `ctx.http?.req.headers` | +| `extra.sendNotification(n)` | `ctx.notification.send(n)` | +| `extra.sendRequest(r, s, o)` | `ctx.mcpReq.send(r, s, o)` | +| `extra.taskId` | `ctx.task?.id` | +| `extra.taskStore` | `ctx.task?.store` | +| `extra.taskRequestedTtl` | `ctx.task?.requestedTtl` | +| `extra.closeSSEStream?.()` | `ctx.http?.closeSSE?.()` | +| `extra.closeStandaloneSSEStream?.()` | `ctx.http?.closeStandaloneSSE?.()` | + +#### Server-specific additions on context + +| Method | Description | +| ------------------------------------------- | ------------------------------------------------------ | +| `ctx.notification.log(params)` | Send logging message | +| `ctx.notification.debug(msg, data?)` | Send debug log | +| `ctx.notification.info(msg, data?)` | Send info log | +| `ctx.notification.warning(msg, data?)` | Send warning log | +| `ctx.notification.error(msg, data?)` | Send error log | +| `ctx.mcpReq.elicitInput(params, opts?)` | Request user input via elicitation | +| `ctx.mcpReq.requestSampling(params, opts?)` | Request LLM sampling from client | +| `ctx.http?.req` | Raw fetch `Request` object (access URL, headers, etc.) | + ### Client list methods return empty results for missing capabilities `Client.listPrompts()`, `listResources()`, `listResourceTemplates()`, and `listTools()` now return empty results when the server didn't advertise the corresponding capability, instead of sending the request. This respects the MCP spec's capability negotiation. @@ -307,21 +407,21 @@ To restore v1 behavior (throw an error when capabilities are missing), set `enfo const client = new Client({ name: 'my-client', version: '1.0.0' }, { enforceStrictCapabilities: true, }); -``` +```` ### Removed type aliases and deprecated exports The following deprecated type aliases have been removed from `@modelcontextprotocol/core`: -| Removed | Replacement | -|---------|-------------| -| `JSONRPCError` | `JSONRPCErrorResponse` | -| `JSONRPCErrorSchema` | `JSONRPCErrorResponseSchema` | -| `isJSONRPCError` | `isJSONRPCErrorResponse` | -| `isJSONRPCResponse` | `isJSONRPCResultResponse` | -| `ResourceReferenceSchema` | `ResourceTemplateReferenceSchema` | -| `ResourceReference` | `ResourceTemplateReference` | -| `IsomorphicHeaders` | Use Web Standard `Headers` | +| Removed | Replacement | +| ---------------------------------------- | ------------------------------------------------ | +| `JSONRPCError` | `JSONRPCErrorResponse` | +| `JSONRPCErrorSchema` | `JSONRPCErrorResponseSchema` | +| `isJSONRPCError` | `isJSONRPCErrorResponse` | +| `isJSONRPCResponse` | `isJSONRPCResultResponse` | +| `ResourceReferenceSchema` | `ResourceTemplateReferenceSchema` | +| `ResourceReference` | `ResourceTemplateReference` | +| `IsomorphicHeaders` | Use Web Standard `Headers` | | `AuthInfo` (from `server/auth/types.js`) | `AuthInfo` (now in `@modelcontextprotocol/core`) | All other types and schemas exported from `@modelcontextprotocol/sdk/types.js` retain their original names in `@modelcontextprotocol/core`. @@ -352,7 +452,8 @@ The following APIs are unchanged between v1 and v2 (only the import paths change ## Using an LLM to migrate your code -An LLM-optimized version of this guide is available at [`docs/migration-SKILL.md`](migration-SKILL.md). It contains dense mapping tables designed for tools like Claude Code to mechanically apply all the changes described above. You can paste it into your LLM context or load it as a skill. +An LLM-optimized version of this guide is available at [`docs/migration-SKILL.md`](migration-SKILL.md). It contains dense mapping tables designed for tools like Claude Code to mechanically apply all the changes described above. You can paste it into your LLM context or load it as +a skill. ## Need Help? diff --git a/examples/client/src/simpleStreamableHttp.ts b/examples/client/src/simpleStreamableHttp.ts index d3144c164..eea12f859 100644 --- a/examples/client/src/simpleStreamableHttp.ts +++ b/examples/client/src/simpleStreamableHttp.ts @@ -268,10 +268,12 @@ async function connect(url?: string): Promise { }; // Set up elicitation request handler with proper validation - client.setRequestHandler('elicitation/create', async request => { + client.setRequestHandler('elicitation/create', async (request, ctx) => { if (request.params.mode !== 'form') { throw new McpError(ErrorCode.InvalidParams, `Unsupported elicitation mode: ${request.params.mode}`); } + + console.log(`${ctx.mcpReq.method} elicitation request received`); console.log('\n🔔 Elicitation (form) Request Received:'); console.log(`Message: ${request.params.message}`); console.log(`Related Task: ${request.params._meta?.[RELATED_TASK_META_KEY]?.taskId}`); diff --git a/examples/server/src/elicitationUrlExample.ts b/examples/server/src/elicitationUrlExample.ts index 2cc6695c7..5cbf80f01 100644 --- a/examples/server/src/elicitationUrlExample.ts +++ b/examples/server/src/elicitationUrlExample.ts @@ -46,12 +46,12 @@ const getServer = () => { cartId: z.string().describe('The ID of the cart to confirm') } }, - async ({ cartId }, extra): Promise => { + async ({ cartId }, ctx): Promise => { /* In a real world scenario, there would be some logic here to check if the user has the provided cartId. For the purposes of this example, we'll throw an error (-> elicits the client to open a URL to confirm payment) */ - const sessionId = extra.sessionId; + const sessionId = ctx.sessionId; if (!sessionId) { throw new Error('Expected a Session ID'); } @@ -79,15 +79,15 @@ const getServer = () => { param1: z.string().describe('First parameter') } }, - async (_, extra): Promise => { + async (_, ctx): Promise => { /* In a real world scenario, there would be some logic here to check if we already have a valid access token for the user. - Auth info (with a subject or `sub` claim) can be typically be found in `extra.authInfo`. + Auth info (with a subject or `sub` claim) can be typically be found in `ctx.requestCtx.authInfo`. If we do, we can just return the result of the tool call. If we don't, we can throw an ElicitationRequiredError to request the user to authenticate. For the purposes of this example, we'll throw an error (-> elicits the client to open a URL to authenticate). */ - const sessionId = extra.sessionId; + const sessionId = ctx.sessionId; if (!sessionId) { throw new Error('Expected a Session ID'); } diff --git a/examples/server/src/jsonResponseStreamableHttp.ts b/examples/server/src/jsonResponseStreamableHttp.ts index fa0e3a300..3a6daf82b 100644 --- a/examples/server/src/jsonResponseStreamableHttp.ts +++ b/examples/server/src/jsonResponseStreamableHttp.ts @@ -51,7 +51,7 @@ const getServer = () => { name: z.string().describe('Name to greet') } }, - async ({ name }, extra): Promise => { + async ({ name }, ctx): Promise => { const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); await server.sendLoggingMessage( @@ -59,7 +59,7 @@ const getServer = () => { level: 'debug', data: `Starting multi-greet for ${name}` }, - extra.sessionId + ctx.sessionId ); await sleep(1000); // Wait 1 second before first greeting @@ -69,7 +69,7 @@ const getServer = () => { level: 'info', data: `Sending first greeting to ${name}` }, - extra.sessionId + ctx.sessionId ); await sleep(1000); // Wait another second before second greeting @@ -79,7 +79,7 @@ const getServer = () => { level: 'info', data: `Sending second greeting to ${name}` }, - extra.sessionId + ctx.sessionId ); return { diff --git a/examples/server/src/simpleStatelessStreamableHttp.ts b/examples/server/src/simpleStatelessStreamableHttp.ts index 15669131a..cd7ef4944 100644 --- a/examples/server/src/simpleStatelessStreamableHttp.ts +++ b/examples/server/src/simpleStatelessStreamableHttp.ts @@ -49,7 +49,7 @@ const getServer = () => { count: z.number().describe('Number of notifications to send (0 for 100)').default(10) } }, - async ({ interval, count }, extra): Promise => { + async ({ interval, count }, ctx): Promise => { const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); let counter = 0; @@ -61,7 +61,7 @@ const getServer = () => { level: 'info', data: `Periodic notification #${counter} at ${new Date().toISOString()}` }, - extra.sessionId + ctx.sessionId ); } catch (error) { console.error('Error sending notification:', error); diff --git a/examples/server/src/simpleStreamableHttp.ts b/examples/server/src/simpleStreamableHttp.ts index e016c217f..8ba4a5737 100644 --- a/examples/server/src/simpleStreamableHttp.ts +++ b/examples/server/src/simpleStreamableHttp.ts @@ -62,7 +62,12 @@ const getServer = () => { name: z.string().describe('Name to greet') } }, - async ({ name }): Promise => { + async ({ name }, ctx): Promise => { + await ctx.notification.log({ + level: 'debug', + data: `Starting greet for ${name}` + }); + return { content: [ { @@ -88,7 +93,7 @@ const getServer = () => { openWorldHint: false } }, - async ({ name }, extra): Promise => { + async ({ name }, ctx): Promise => { const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); await server.sendLoggingMessage( @@ -96,7 +101,7 @@ const getServer = () => { level: 'debug', data: `Starting multi-greet for ${name}` }, - extra.sessionId + ctx.sessionId ); await sleep(1000); // Wait 1 second before first greeting @@ -106,7 +111,7 @@ const getServer = () => { level: 'info', data: `Sending first greeting to ${name}` }, - extra.sessionId + ctx.sessionId ); await sleep(1000); // Wait another second before second greeting @@ -116,7 +121,7 @@ const getServer = () => { level: 'info', data: `Sending second greeting to ${name}` }, - extra.sessionId + ctx.sessionId ); return { @@ -139,7 +144,7 @@ const getServer = () => { infoType: z.enum(['contact', 'preferences', 'feedback']).describe('Type of information to collect') } }, - async ({ infoType }, extra): Promise => { + async ({ infoType }, ctx): Promise => { let message: string; let requestedSchema: { type: 'object'; @@ -238,8 +243,8 @@ const getServer = () => { } try { - // Use sendRequest through the extra parameter to elicit input - const result = await extra.sendRequest( + // Use sendRequest through the ctx parameter to elicit input + const result = await ctx.mcpReq.send( { method: 'elicitation/create', params: { @@ -302,7 +307,12 @@ const getServer = () => { name: z.string().describe('Name to include in greeting') } }, - async ({ name }): Promise => { + async ({ name }, ctx): Promise => { + await ctx.notification.log({ + level: 'debug', + data: `Starting greeting template for ${name}` + }); + return { messages: [ { @@ -327,7 +337,7 @@ const getServer = () => { count: z.number().describe('Number of notifications to send (0 for 100)').default(50) } }, - async ({ interval, count }, extra): Promise => { + async ({ interval, count }, ctx): Promise => { const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); let counter = 0; @@ -339,7 +349,7 @@ const getServer = () => { level: 'info', data: `Periodic notification #${counter} at ${new Date().toISOString()}` }, - extra.sessionId + ctx.sessionId ); } catch (error) { console.error('Error sending notification:', error); @@ -389,7 +399,12 @@ const getServer = () => { description: 'First example file for ResourceLink demonstration', mimeType: 'text/plain' }, - async (): Promise => { + async (_, ctx): Promise => { + await ctx.notification.log({ + level: 'debug', + data: `Starting example file 1` + }); + return { contents: [ { @@ -484,10 +499,12 @@ const getServer = () => { } }, { - async createTask({ duration }, { taskStore, taskRequestedTtl }) { + async createTask({ duration }, ctx) { // Create the task + if (!ctx.task?.store) throw new Error('Task store not found'); + const taskStore = ctx.task.store; const task = await taskStore.createTask({ - ttl: taskRequestedTtl + ttl: ctx.task.requestedTtl }); // Simulate out-of-band work @@ -508,11 +525,13 @@ const getServer = () => { task }; }, - async getTask(_args, { taskId, taskStore }) { - return await taskStore.getTask(taskId); + async getTask(_args, ctx) { + if (!ctx.task?.store) throw new Error('Task store not found'); + return await ctx.task.store.getTask(ctx.task.id!); }, - async getTaskResult(_args, { taskId, taskStore }) { - const result = await taskStore.getTaskResult(taskId); + async getTaskResult(_args, ctx) { + if (!ctx.task?.store) throw new Error('Task store not found'); + const result = await ctx.task.store.getTaskResult(ctx.task.id!); return result as CallToolResult; } } diff --git a/examples/server/src/simpleTaskInteractive.ts b/examples/server/src/simpleTaskInteractive.ts index 3a3fe80c4..9092926d9 100644 --- a/examples/server/src/simpleTaskInteractive.ts +++ b/examples/server/src/simpleTaskInteractive.ts @@ -507,7 +507,7 @@ const createServer = (): Server => { }); // Handle tool calls - server.setRequestHandler('tools/call', async (request, extra): Promise => { + server.setRequestHandler('tools/call', async (request, ctx): Promise => { const { name, arguments: args } = request.params; const taskParams = (request.params._meta?.task || request.params.task) as { ttl?: number; pollInterval?: number } | undefined; @@ -522,7 +522,7 @@ const createServer = (): Server => { pollInterval: taskParams.pollInterval ?? 1000 }; - const task = await taskStore.createTask(taskOptions, extra.requestId, request, extra.sessionId); + const task = await taskStore.createTask(taskOptions, ctx.mcpReq.id, request, ctx.sessionId); console.log(`\n[Server] ${name} called, task created: ${task.taskId}`); @@ -600,7 +600,7 @@ const createServer = (): Server => { activeTaskExecutions.set(task.taskId, { promise: taskExecution, server, - sessionId: extra.sessionId ?? '' + sessionId: ctx.sessionId ?? '' }); return { task }; @@ -617,10 +617,10 @@ const createServer = (): Server => { }); // Handle tasks/result - server.setRequestHandler('tasks/result', async (request, extra): Promise => { + server.setRequestHandler('tasks/result', async (request, ctx): Promise => { const { taskId } = request.params; console.log(`[Server] tasks/result called for task ${taskId}`); - return taskResultHandler.handle(taskId, server, extra.sessionId ?? ''); + return taskResultHandler.handle(taskId, server, ctx.sessionId ?? ''); }); return server; diff --git a/examples/server/src/ssePollingExample.ts b/examples/server/src/ssePollingExample.ts index 2416d2ec3..5cdfccd4b 100644 --- a/examples/server/src/ssePollingExample.ts +++ b/examples/server/src/ssePollingExample.ts @@ -7,7 +7,7 @@ * Key features: * - Configures `retryInterval` to tell clients how long to wait before reconnecting * - Uses `eventStore` to persist events for replay after reconnection - * - Uses `extra.closeSSEStream()` callback to gracefully disconnect clients mid-operation + * - Uses `ctx.requestCtx.stream.closeSSEStream()` callback to gracefully disconnect clients mid-operation * * Run with: pnpm tsx src/ssePollingExample.ts * Test with: curl or the MCP Inspector @@ -40,10 +40,11 @@ server.registerTool( { description: 'A long-running task that sends progress updates. Server will disconnect mid-task to demonstrate polling.' }, - async (extra): Promise => { + async (ctx): Promise => { const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); + const sessionId = ctx.sessionId; - console.log(`[${extra.sessionId}] Starting long-task...`); + console.log(`[${sessionId}] Starting long-task...`); // Send first progress notification await server.sendLoggingMessage( @@ -51,7 +52,7 @@ server.registerTool( level: 'info', data: 'Progress: 25% - Starting work...' }, - extra.sessionId + sessionId ); await sleep(1000); @@ -61,16 +62,16 @@ server.registerTool( level: 'info', data: 'Progress: 50% - Halfway there...' }, - extra.sessionId + sessionId ); await sleep(1000); // Server decides to disconnect the client to free resources // Client will reconnect via GET with Last-Event-ID after the transport's retryInterval - // Use extra.closeSSEStream callback - available when eventStore is configured - if (extra.closeSSEStream) { - console.log(`[${extra.sessionId}] Closing SSE stream to trigger client polling...`); - extra.closeSSEStream(); + // Use ctx.stream.closeSSE callback - available when eventStore is configured + if (ctx.http?.closeSSE) { + console.log(`[${sessionId}] Closing SSE stream to trigger client polling...`); + ctx.http.closeSSE(); } // Continue processing while client is disconnected @@ -81,7 +82,7 @@ server.registerTool( level: 'info', data: 'Progress: 75% - Almost done (sent while client disconnected)...' }, - extra.sessionId + sessionId ); await sleep(500); @@ -90,10 +91,10 @@ server.registerTool( level: 'info', data: 'Progress: 100% - Complete!' }, - extra.sessionId + sessionId ); - console.log(`[${extra.sessionId}] Task complete`); + console.log(`[${sessionId}] Task complete`); return { content: [ diff --git a/packages/client/src/client/client.ts b/packages/client/src/client/client.ts index 9c2c87290..220851a88 100644 --- a/packages/client/src/client/client.ts +++ b/packages/client/src/client/client.ts @@ -1,4 +1,6 @@ import type { + AnySchema, + BaseContext, CallToolRequest, ClientCapabilities, ClientNotification, @@ -8,6 +10,7 @@ import type { CompleteRequest, GetPromptRequest, Implementation, + JSONRPCRequest, JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator, @@ -18,18 +21,23 @@ import type { ListResourceTemplatesRequest, ListToolsRequest, LoggingLevel, + MessageExtraInfo, Notification, NotificationMethod, + NotificationOptions, ProtocolOptions, ReadResourceRequest, Request, - RequestHandlerExtra, RequestMethod, RequestOptions, RequestTypeMap, Result, + SchemaOutput, ServerCapabilities, SubscribeRequest, + TaskContext, + TaskCreationParams, + TaskStore, Tool, Transport, UnsubscribeRequest @@ -65,6 +73,7 @@ import { } from '@modelcontextprotocol/core'; import { ExperimentalClientTasks } from '../experimental/tasks/client.js'; +import type { ClientContext } from './context.js'; /** * Elicitation default application helper. Applies defaults to the data based on the schema. @@ -332,13 +341,13 @@ export class Client< method: M, handler: ( request: RequestTypeMap[M], - extra: RequestHandlerExtra + ctx: ClientContext ) => ClientResult | ResultT | Promise ): void { if (method === 'elicitation/create') { const wrappedHandler = async ( request: RequestTypeMap[M], - extra: RequestHandlerExtra + ctx: ClientContext ): Promise => { const validatedRequest = safeParse(ElicitRequestSchema, request); if (!validatedRequest.success) { @@ -360,7 +369,7 @@ export class Client< throw new McpError(ErrorCode.InvalidParams, 'Client does not support URL-mode elicitation requests'); } - const result = await Promise.resolve(handler(request, extra)); + const result = await Promise.resolve(handler(request, ctx)); // When task creation is requested, validate and return CreateTaskResult if (params.task) { @@ -411,7 +420,7 @@ export class Client< if (method === 'sampling/createMessage') { const wrappedHandler = async ( request: RequestTypeMap[M], - extra: RequestHandlerExtra + ctx: ClientContext ): Promise => { const validatedRequest = safeParse(CreateMessageRequestSchema, request); if (!validatedRequest.success) { @@ -422,7 +431,7 @@ export class Client< const { params } = validatedRequest.data; - const result = await Promise.resolve(handler(request, extra)); + const result = await Promise.resolve(handler(request, ctx)); // When task creation is requested, validate and return CreateTaskResult if (params.task) { @@ -458,6 +467,76 @@ export class Client< return super.setRequestHandler(method, handler); } + protected createRequestContext(args: { + request: JSONRPCRequest; + taskStore: TaskStore | undefined; + relatedTaskId: string | undefined; + taskCreationParams: TaskCreationParams | undefined; + abortController: AbortController; + capturedTransport: Transport | undefined; + extra?: MessageExtraInfo; + }): ClientContext { + const { request, taskStore, relatedTaskId, taskCreationParams, abortController, capturedTransport, extra } = args; + const sessionId = capturedTransport?.sessionId; + + // Build the task context using the helper from Protocol + const task: TaskContext | undefined = this.buildTaskContext({ + taskStore, + request, + sessionId, + relatedTaskId, + taskCreationParams + }); + + // Closure helpers for sendRequest and sendNotification + const sendRequest = async ( + req: ClientRequest | RequestT, + resultSchema: U, + options?: RequestOptions + ): Promise> => { + const requestOptions: RequestOptions = { ...options, relatedRequestId: request.id }; + const taskId = task?.id; + if (taskId) { + requestOptions.relatedTask = { taskId }; + if (task?.store) { + await task.store.updateTaskStatus(taskId, 'input_required'); + } + } + return await this.request(req, resultSchema, requestOptions); + }; + + const sendNotification = async (notification: ClientNotification | NotificationT): Promise => { + const notificationOptions: NotificationOptions = { relatedRequestId: request.id }; + if (task && task.id) { + notificationOptions.relatedTask = { taskId: task.id }; + } + return this.notification(notification, notificationOptions); + }; + + // Return a ClientContext POJO — use BaseContext here since ClientContext already unions with Client types + const ctx: BaseContext = { + sessionId, + mcpReq: { + id: request.id, + method: request.method, + _meta: request.params?._meta, + signal: abortController.signal, + send: sendRequest + }, + http: extra?.authInfo + ? { + authInfo: extra.authInfo + } + : undefined, + task, + notification: { + send: sendNotification + } + }; + + return ctx; + } + protected assertCapability(capability: keyof ServerCapabilities, method: string): void { if (!this._serverCapabilities?.[capability]) { throw new Error(`Server does not support ${capability} (required for ${method})`); diff --git a/packages/client/src/client/context.ts b/packages/client/src/client/context.ts new file mode 100644 index 000000000..b6185b181 --- /dev/null +++ b/packages/client/src/client/context.ts @@ -0,0 +1,13 @@ +import type { BaseContext, ClientNotification, ClientRequest, Notification, Request } from '@modelcontextprotocol/core'; + +/** + * Client-specific context type for request handlers. + * Used when the client handles requests from the server (e.g., sampling, elicitation). + * + * @typeParam RequestT - Additional request types beyond ClientRequest + * @typeParam NotificationT - Additional notification types beyond ClientNotification + */ +export type ClientContext = BaseContext< + ClientRequest | RequestT, + ClientNotification | NotificationT +>; diff --git a/packages/core/src/experimental/tasks/interfaces.ts b/packages/core/src/experimental/tasks/interfaces.ts index c1901d70a..4bf11942c 100644 --- a/packages/core/src/experimental/tasks/interfaces.ts +++ b/packages/core/src/experimental/tasks/interfaces.ts @@ -3,7 +3,6 @@ * WARNING: These APIs are experimental and may change without notice. */ -import type { RequestHandlerExtra, RequestTaskStore } from '../../shared/protocol.js'; import type { JSONRPCErrorResponse, JSONRPCNotification, @@ -12,8 +11,6 @@ import type { Request, RequestId, Result, - ServerNotification, - ServerRequest, Task, ToolExecution } from '../../types/types.js'; @@ -22,23 +19,6 @@ import type { // Task Handler Types (for registerToolTask) // ============================================================================ -/** - * Extended handler extra with task store for task creation. - * @experimental - */ -export interface CreateTaskRequestHandlerExtra extends RequestHandlerExtra { - taskStore: RequestTaskStore; -} - -/** - * Extended handler extra with task ID and store for task operations. - * @experimental - */ -export interface TaskRequestHandlerExtra extends RequestHandlerExtra { - taskId: string; - taskStore: RequestTaskStore; -} - /** * Task-specific execution configuration. * taskSupport cannot be 'forbidden' for task-based tools. diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index b7980fadb..ee71063db 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -1,6 +1,7 @@ export * from './auth/errors.js'; export * from './shared/auth.js'; export * from './shared/authUtils.js'; +export * from './shared/context.js'; export * from './shared/metadataUtils.js'; export * from './shared/protocol.js'; export * from './shared/responseMessage.js'; @@ -9,6 +10,7 @@ export * from './shared/toolNameValidation.js'; export * from './shared/transport.js'; export * from './shared/uriTemplate.js'; export * from './types/types.js'; +export * from './types/utility.js'; export * from './util/inMemory.js'; export * from './util/zodCompat.js'; export * from './util/zodJsonSchemaCompat.js'; diff --git a/packages/core/src/shared/context.ts b/packages/core/src/shared/context.ts new file mode 100644 index 000000000..23b99501a --- /dev/null +++ b/packages/core/src/shared/context.ts @@ -0,0 +1,125 @@ +import type { CreateTaskOptions } from '../experimental/tasks/interfaces.js'; +import type { AuthInfo, Notification, Request, RequestId, RequestMeta, Result, Task } from '../types/types.js'; +import type { AnySchema, SchemaOutput } from '../util/zodCompat.js'; +import type { RequestOptions } from './protocol.js'; + +/** + * Request-scoped task store for managing task state within a handler. + */ +export interface RequestTaskStore { + /** + * Creates a new task with the given creation parameters. + */ + createTask(taskParams: CreateTaskOptions): Promise; + + /** + * Gets the current status of a task. + */ + getTask(taskId: string): Promise; + + /** + * Stores the result of a task and sets its final status. + */ + storeTaskResult(taskId: string, status: 'completed' | 'failed', result: Result): Promise; + + /** + * Retrieves the stored result of a task. + */ + getTaskResult(taskId: string): Promise; + + /** + * Updates a task's status (e.g., to 'cancelled', 'failed', 'completed'). + */ + updateTaskStatus(taskId: string, status: Task['status'], statusMessage?: string): Promise; + + /** + * Lists tasks, optionally starting from a pagination cursor. + */ + listTasks(cursor?: string): Promise<{ tasks: Task[]; nextCursor?: string }>; +} + +/** + * Task-related context for task-augmented requests. + */ +export type TaskContext = { + /** + * The ID of the task. + */ + id: string; + /** + * The task store for managing task state. + */ + store: RequestTaskStore; + /** + * The requested TTL for the task, or null if not specified. + */ + requestedTtl: number | null; +}; + +/** + * Base context interface for request handlers. + * Defines the common structure shared by both client and server contexts. + * + * @typeParam RequestT - The type of requests that can be sent from this context + * @typeParam NotificationT - The type of notifications that can be sent from this context + */ +export type BaseContext = { + /** + * The session ID of the request. + */ + sessionId?: string; + + /** + * MCP request context containing protocol-level information. + */ + mcpReq: { + /** + * The JSON-RPC ID of the request being handled. + * This can be useful for tracking or logging purposes. + */ + id: RequestId; + /** + * The method of the request. + */ + method: string; + /** + * The metadata of the request. + */ + _meta?: RequestMeta; + /** + * An abort signal used to communicate if the request was cancelled. + */ + signal: AbortSignal; + /** + * Sends a request that relates to the current request being handled. + * This is used by certain transports to correctly associate related messages. + */ + send: (request: RequestT, resultSchema: U, options?: RequestOptions) => Promise>; + }; + + /** + * HTTP request context with authentication information. + */ + http?: { + /** + * The authentication information, if available. + */ + authInfo?: AuthInfo; + }; + + /** + * Task context if this is a task-augmented request, undefined otherwise. + */ + task: TaskContext | undefined; + + /** + * Notification context with send method. + */ + notification: { + /** + * Sends a notification that relates to the current request being handled. + * This is used by certain transports to correctly associate related messages. + */ + send: (notification: NotificationT) => Promise; + }; +}; diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index f537aa86c..b82134dec 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -1,7 +1,6 @@ -import type { CreateTaskOptions, QueuedMessage, TaskMessageQueue, TaskStore } from '../experimental/tasks/interfaces.js'; +import type { QueuedMessage, TaskMessageQueue, TaskStore } from '../experimental/tasks/interfaces.js'; import { isTerminal } from '../experimental/tasks/interfaces.js'; import type { - AuthInfo, CancelledNotification, ClientCapabilities, GetTaskPayloadRequest, @@ -12,7 +11,6 @@ import type { JSONRPCRequest, JSONRPCResponse, JSONRPCResultResponse, - MessageExtraInfo, Notification, NotificationMethod, NotificationTypeMap, @@ -21,15 +19,11 @@ import type { RelatedTaskMetadata, Request, RequestId, - RequestInfo, - RequestMeta, RequestMethod, RequestTypeMap, Result, ServerCapabilities, - Task, - TaskCreationParams, - TaskStatusNotification + TaskCreationParams } from '../types/types.js'; import { CancelTaskResultSchema, @@ -45,12 +39,13 @@ import { isTaskAugmentedRequestParams, ListTasksResultSchema, McpError, - RELATED_TASK_META_KEY, - TaskStatusNotificationSchema + RELATED_TASK_META_KEY } from '../types/types.js'; +import type { MessageExtraInfo } from '../types/utility.js'; import type { AnySchema, SchemaOutput } from '../util/zodCompat.js'; import { safeParse } from '../util/zodCompat.js'; import { parseWithCompat } from '../util/zodJsonSchemaCompat.js'; +import type { BaseContext, RequestTaskStore, TaskContext } from './context.js'; import type { ResponseMessage } from './responseMessage.js'; import type { Transport, TransportSendOptions } from './transport.js'; @@ -176,134 +171,6 @@ export type NotificationOptions = { */ // relatedTask is excluded as the SDK controls if this is sent according to if the source is a task. export type TaskRequestOptions = Omit; - -/** - * Request-scoped TaskStore interface. - */ -export interface RequestTaskStore { - /** - * Creates a new task with the given creation parameters. - * The implementation generates a unique taskId and createdAt timestamp. - * - * @param taskParams - The task creation parameters from the request - * @returns The created task object - */ - createTask(taskParams: CreateTaskOptions): Promise; - - /** - * Gets the current status of a task. - * - * @param taskId - The task identifier - * @returns The task object - * @throws If the task does not exist - */ - getTask(taskId: string): Promise; - - /** - * Stores the result of a task and sets its final status. - * - * @param taskId - The task identifier - * @param status - The final status: 'completed' for success, 'failed' for errors - * @param result - The result to store - */ - storeTaskResult(taskId: string, status: 'completed' | 'failed', result: Result): Promise; - - /** - * Retrieves the stored result of a task. - * - * @param taskId - The task identifier - * @returns The stored result - */ - getTaskResult(taskId: string): Promise; - - /** - * Updates a task's status (e.g., to 'cancelled', 'failed', 'completed'). - * - * @param taskId - The task identifier - * @param status - The new status - * @param statusMessage - Optional diagnostic message for failed tasks or other status information - */ - updateTaskStatus(taskId: string, status: Task['status'], statusMessage?: string): Promise; - - /** - * Lists tasks, optionally starting from a pagination cursor. - * - * @param cursor - Optional cursor for pagination - * @returns An object containing the tasks array and an optional nextCursor - */ - listTasks(cursor?: string): Promise<{ tasks: Task[]; nextCursor?: string }>; -} - -/** - * Extra data given to request handlers. - */ -export type RequestHandlerExtra = { - /** - * An abort signal used to communicate if the request was cancelled from the sender's side. - */ - signal: AbortSignal; - - /** - * Information about a validated access token, provided to request handlers. - */ - authInfo?: AuthInfo; - - /** - * The session ID from the transport, if available. - */ - sessionId?: string; - - /** - * Metadata from the original request. - */ - _meta?: RequestMeta; - - /** - * The JSON-RPC ID of the request being handled. - * This can be useful for tracking or logging purposes. - */ - requestId: RequestId; - - taskId?: string; - - taskStore?: RequestTaskStore; - - taskRequestedTtl?: number | null; - - /** - * The original HTTP request. - */ - requestInfo?: RequestInfo; - - /** - * Sends a notification that relates to the current request being handled. - * - * This is used by certain transports to correctly associate related messages. - */ - sendNotification: (notification: SendNotificationT) => Promise; - - /** - * Sends a request that relates to the current request being handled. - * - * This is used by certain transports to correctly associate related messages. - */ - sendRequest: (request: SendRequestT, resultSchema: U, options?: TaskRequestOptions) => Promise>; - - /** - * Closes the SSE stream for this request, triggering client reconnection. - * Only available when using a StreamableHTTPServerTransport with eventStore configured. - * Use this to implement polling behavior during long-running operations. - */ - closeSSEStream?: () => void; - - /** - * Closes the standalone GET SSE stream, triggering client reconnection. - * Only available when using a StreamableHTTPServerTransport with eventStore configured. - * Use this to implement polling behavior for server-initiated notifications. - */ - closeStandaloneSSEStream?: () => void; -}; - /** * Information about a request's timeout state */ @@ -325,7 +192,7 @@ export abstract class Protocol) => Promise + (request: JSONRPCRequest, extra: BaseContext) => Promise > = new Map(); private _requestHandlerAbortControllers: Map = new Map(); private _notificationHandlers: Map Promise> = new Map(); @@ -359,7 +226,7 @@ export abstract class Protocol) => Promise; + fallbackRequestHandler?: (request: JSONRPCRequest, extra: BaseContext) => Promise; /** * A handler to invoke for any notification types that do not have their own handler installed. @@ -385,8 +252,8 @@ export abstract class Protocol { - const task = await this._taskStore!.getTask(request.params.taskId, extra.sessionId); + this.setRequestHandler('tasks/get', async (request, ctx) => { + const task = await this._taskStore!.getTask(request.params.taskId, ctx.sessionId); if (!task) { throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); } @@ -398,14 +265,14 @@ export abstract class Protocol { + this.setRequestHandler('tasks/result', async (request, ctx) => { const handleTaskResult = async (): Promise => { const taskId = request.params.taskId; // Deliver queued messages if (this._taskMessageQueue) { let queuedMessage: QueuedMessage | undefined; - while ((queuedMessage = await this._taskMessageQueue.dequeue(taskId, extra.sessionId))) { + while ((queuedMessage = await this._taskMessageQueue.dequeue(taskId, ctx.sessionId))) { // Handle response and error messages by routing them to the appropriate resolver if (queuedMessage.type === 'response' || queuedMessage.type === 'error') { const message = queuedMessage.message; @@ -443,12 +310,12 @@ export abstract class Protocol { + this.setRequestHandler('tasks/list', async (request, ctx) => { try { - const { tasks, nextCursor } = await this._taskStore!.listTasks(request.params?.cursor, extra.sessionId); + const { tasks, nextCursor } = await this._taskStore!.listTasks(request.params?.cursor, ctx.sessionId); return { tasks, nextCursor, @@ -501,10 +368,10 @@ export abstract class Protocol { + this.setRequestHandler('tasks/cancel', async (request, ctx) => { try { // Get the current task to check if it's in a terminal state, in case the implementation is not atomic - const task = await this._taskStore!.getTask(request.params.taskId, extra.sessionId); + const task = await this._taskStore!.getTask(request.params.taskId, ctx.sessionId); if (!task) { throw new McpError(ErrorCode.InvalidParams, `Task not found: ${request.params.taskId}`); @@ -519,12 +386,12 @@ export abstract class Protocol = { - signal: abortController.signal, - sessionId: capturedTransport?.sessionId, - _meta: request.params?._meta, - sendNotification: async notification => { - // Include related-task metadata if this request is part of a task - const notificationOptions: NotificationOptions = { relatedRequestId: request.id }; - if (relatedTaskId) { - notificationOptions.relatedTask = { taskId: relatedTaskId }; - } - await this.notification(notification, notificationOptions); - }, - sendRequest: async (r, resultSchema, options?) => { - // Include related-task metadata if this request is part of a task - const requestOptions: RequestOptions = { ...options, relatedRequestId: request.id }; - if (relatedTaskId && !requestOptions.relatedTask) { - requestOptions.relatedTask = { taskId: relatedTaskId }; - } - // Set task status to input_required when sending a request within a task context - // Use the taskId from options (explicit) or fall back to relatedTaskId (inherited) - const effectiveTaskId = requestOptions.relatedTask?.taskId ?? relatedTaskId; - if (effectiveTaskId && taskStore) { - await taskStore.updateTaskStatus(effectiveTaskId, 'input_required'); - } - - return await this.request(r, resultSchema, requestOptions); - }, - authInfo: extra?.authInfo, - requestId: request.id, - requestInfo: extra?.requestInfo, - taskId: relatedTaskId, - taskStore: taskStore, - taskRequestedTtl: taskCreationParams?.ttl, - closeSSEStream: extra?.closeSSEStream, - closeStandaloneSSEStream: extra?.closeStandaloneSSEStream - }; + const ctx: BaseContext = this.createRequestContext({ + request, + taskStore: this._taskStore, + relatedTaskId, + taskCreationParams, + abortController, + capturedTransport, + extra + }); // Starting with Promise.resolve() puts any synchronous errors into the monad as well. Promise.resolve() @@ -762,7 +600,7 @@ export abstract class Protocol handler(request, fullExtra)) + .then(() => handler(request, ctx)) .then( async result => { if (abortController.signal.aborted) { @@ -825,6 +663,77 @@ export abstract class Protocol { + const task = await taskStore.createTask(taskParams, request.id, request, sessionId); + currentTaskId = task.taskId; + return task; + }, + getTask: async taskId => { + const task = await taskStore.getTask(taskId, sessionId); + if (!task) throw new McpError(ErrorCode.InvalidParams, `Task not found: ${taskId}`); + return task; + }, + storeTaskResult: (taskId, status, result) => { + return taskStore.storeTaskResult(taskId, status, result, sessionId); + }, + getTaskResult: taskId => { + return taskStore.getTaskResult(taskId, sessionId); + }, + updateTaskStatus: (taskId, status, statusMessage) => { + return taskStore.updateTaskStatus(taskId, status, statusMessage, sessionId); + }, + listTasks: cursor => { + return taskStore.listTasks(sessionId, cursor); + } + }; + + return { + get id() { + return currentTaskId; + }, + store, + requestedTtl: taskCreationParams?.ttl ?? null + }; + } + + /** + * Creates the context object passed to request handlers. + * Subclasses must implement this to provide the appropriate context type + * (ClientContext for Client, ServerContext for Server). + */ + protected abstract createRequestContext(args: { + request: JSONRPCRequest; + taskStore: TaskStore | undefined; + relatedTaskId: string | undefined; + taskCreationParams: TaskCreationParams | undefined; + abortController: AbortController; + capturedTransport: Transport | undefined; + extra?: MessageExtraInfo; + }): BaseContext; + private _onprogress(notification: ProgressNotification): void { const { progressToken, ...params } = notification.params; const messageId = Number(progressToken); @@ -1200,9 +1109,9 @@ export abstract class Protocol { const handler = this._responseHandlers.get(messageId); @@ -1227,7 +1136,8 @@ export abstract class Protocol { this._cleanupTimeout(messageId); reject(error); @@ -1290,9 +1200,9 @@ export abstract class Protocol( method: M, - handler: ( - request: RequestTypeMap[M], - extra: RequestHandlerExtra - ) => SendResultT | Promise + handler: (request: RequestTypeMap[M], ctx: BaseContext) => SendResultT | Promise ): void { this.assertRequestHandlerCapability(method); const schema = getRequestSchema(method); - this._requestHandlers.set(method, (request, extra) => { + this._requestHandlers.set(method, (request, ctx) => { const parsed = parseWithCompat(schema, request) as RequestTypeMap[M]; - return Promise.resolve(handler(parsed, extra)); + return Promise.resolve(handler(parsed, ctx)); }); } @@ -1552,95 +1459,6 @@ export abstract class Protocol { - if (!request) { - throw new Error('No request provided'); - } - - return await taskStore.createTask( - taskParams, - request.id, - { - method: request.method, - params: request.params - }, - sessionId - ); - }, - getTask: async taskId => { - const task = await taskStore.getTask(taskId, sessionId); - if (!task) { - throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); - } - - return task; - }, - storeTaskResult: async (taskId, status, result) => { - await taskStore.storeTaskResult(taskId, status, result, sessionId); - - // Get updated task state and send notification - const task = await taskStore.getTask(taskId, sessionId); - if (task) { - const notification: TaskStatusNotification = TaskStatusNotificationSchema.parse({ - method: 'notifications/tasks/status', - params: task - }); - await this.notification(notification as SendNotificationT); - - if (isTerminal(task.status)) { - this._cleanupTaskProgressHandler(taskId); - // Don't clear queue here - it will be cleared after delivery via tasks/result - } - } - }, - getTaskResult: taskId => { - return taskStore.getTaskResult(taskId, sessionId); - }, - updateTaskStatus: async (taskId, status, statusMessage) => { - // Check if task exists - const task = await taskStore.getTask(taskId, sessionId); - if (!task) { - throw new McpError(ErrorCode.InvalidParams, `Task "${taskId}" not found - it may have been cleaned up`); - } - - // Don't allow transitions from terminal states - if (isTerminal(task.status)) { - throw new McpError( - ErrorCode.InvalidParams, - `Cannot update task "${taskId}" from terminal status "${task.status}" to "${status}". Terminal states (completed, failed, cancelled) cannot transition to other states.` - ); - } - - await taskStore.updateTaskStatus(taskId, status, statusMessage, sessionId); - - // Get updated task state and send notification - const updatedTask = await taskStore.getTask(taskId, sessionId); - if (updatedTask) { - const notification: TaskStatusNotification = TaskStatusNotificationSchema.parse({ - method: 'notifications/tasks/status', - params: updatedTask - }); - await this.notification(notification as SendNotificationT); - - if (isTerminal(updatedTask.status)) { - this._cleanupTaskProgressHandler(taskId); - // Don't clear queue here - it will be cleared after delivery via tasks/result - } - } - }, - listTasks: cursor => { - return taskStore.listTasks(cursor, sessionId); - } - }; - } } function isPlainObject(value: unknown): value is Record { diff --git a/packages/core/src/shared/transport.ts b/packages/core/src/shared/transport.ts index 87608f124..e830a4872 100644 --- a/packages/core/src/shared/transport.ts +++ b/packages/core/src/shared/transport.ts @@ -1,4 +1,5 @@ -import type { JSONRPCMessage, MessageExtraInfo, RequestId } from '../types/types.js'; +import type { JSONRPCMessage, RequestId } from '../types/types.js'; +import type { MessageExtraInfo } from '../types/utility.js'; export type FetchLike = (url: string | URL, init?: RequestInit) => Promise; diff --git a/packages/core/src/types/types.ts b/packages/core/src/types/types.ts index 9d659693f..14a51b761 100644 --- a/packages/core/src/types/types.ts +++ b/packages/core/src/types/types.ts @@ -2360,43 +2360,6 @@ type Flatten = T extends Primitive type Infer = Flatten>; -/** - * Information about the incoming request. - */ -export interface RequestInfo { - /** - * The headers of the request. - */ - headers: Headers; -} - -/** - * Extra information about a message. - */ -export interface MessageExtraInfo { - /** - * The request information. - */ - requestInfo?: RequestInfo; - - /** - * The authentication information. - */ - authInfo?: AuthInfo; - - /** - * Callback to close the SSE stream for this request, triggering client reconnection. - * Only available when using NodeStreamableHTTPServerTransport with eventStore configured. - */ - closeSSEStream?: () => void; - - /** - * Callback to close the standalone GET SSE stream, triggering client reconnection. - * Only available when using NodeStreamableHTTPServerTransport with eventStore configured. - */ - closeStandaloneSSEStream?: () => void; -} - /* JSON-RPC types */ export type ProgressToken = Infer; export type Cursor = Infer; diff --git a/packages/core/src/types/utility.ts b/packages/core/src/types/utility.ts new file mode 100644 index 000000000..d33d3222c --- /dev/null +++ b/packages/core/src/types/utility.ts @@ -0,0 +1,29 @@ +import type { AuthInfo } from './types.js'; + +/** + * Extra information about a message. + */ +export interface MessageExtraInfo { + /** + * The raw Request object (fetch API Request). + * Provides access to url, headers, and other request properties. + */ + request?: Request; + + /** + * The authentication information. + */ + authInfo?: AuthInfo; + + /** + * Callback to close the SSE stream for this request, triggering client reconnection. + * Only available when using NodeStreamableHTTPServerTransport with eventStore configured. + */ + closeSSEStream?: () => void; + + /** + * Callback to close the standalone GET SSE stream, triggering client reconnection. + * Only available when using NodeStreamableHTTPServerTransport with eventStore configured. + */ + closeStandaloneSSEStream?: () => void; +} diff --git a/packages/core/test/shared/protocol.test.ts b/packages/core/test/shared/protocol.test.ts index b5f6f40cb..c6abd4070 100644 --- a/packages/core/test/shared/protocol.test.ts +++ b/packages/core/test/shared/protocol.test.ts @@ -11,6 +11,7 @@ import type { TaskStore } from '../../src/experimental/tasks/interfaces.js'; import { InMemoryTaskMessageQueue } from '../../src/experimental/tasks/stores/inMemory.js'; +import type { BaseContext } from '../../src/shared/context.js'; import { mergeCapabilities, Protocol } from '../../src/shared/protocol.js'; import type { ErrorMessage, ResponseMessage } from '../../src/shared/responseMessage.js'; import { toArrayAsync } from '../../src/shared/responseMessage.js'; @@ -30,6 +31,7 @@ import type { TaskCreationParams } from '../../src/types/types.js'; import { ErrorCode, McpError, RELATED_TASK_META_KEY } from '../../src/types/types.js'; +import type { MessageExtraInfo } from '../../src/types/utility.js'; // Type helper for accessing private/protected Protocol properties in tests interface TestProtocol { @@ -38,7 +40,7 @@ interface TestProtocol { _responseHandlers: Map void>; _taskProgressTokens: Map; _clearTaskQueue: (taskId: string, sessionId?: string) => Promise; - requestTaskStore: (request: Request, authInfo: unknown) => TaskStore; + _cleanupTaskProgressHandler: (taskId: string) => void; // Protected task methods (exposed for testing) listTasks: (params?: { cursor?: string }) => Promise<{ tasks: Task[]; nextCursor?: string }>; cancelTask: (params: { taskId: string }) => Promise; @@ -66,7 +68,7 @@ function createMockTaskStore(options?: { return { createTask: vi.fn((taskParams: TaskCreationParams, _1: RequestId, _2: Request) => { // Generate a unique task ID - const taskId = `test-task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + const taskId = `test-task-${Date.now()}-${Math.random().toString(36).slice(2, 11)}`; const createdAt = new Date().toISOString(); const task = (tasks[taskId] = { taskId, @@ -147,6 +149,70 @@ function assertQueuedRequest(o?: QueuedMessage): asserts o is QueuedRequest { expect(o?.type).toBe('request'); } +/** + * Creates a mock Protocol class for testing with all abstract methods implemented. + */ +function createTestProtocolClass(options?: { + taskStore?: TaskStore; + taskMessageQueue?: TaskMessageQueue; + debouncedNotificationMethods?: string[]; + maxTaskQueueSize?: number; + defaultTaskPollInterval?: number; +}) { + return class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + protected createRequestContext(args: { + request: JSONRPCRequest; + taskStore: TaskStore | undefined; + relatedTaskId: string | undefined; + taskCreationParams: TaskCreationParams | undefined; + abortController: AbortController; + capturedTransport: Transport | undefined; + extra?: MessageExtraInfo; + }): BaseContext { + // Create a context that properly delegates to the protocol + const sessionId = args.capturedTransport?.sessionId; + const requestId = args.request.id; + + return { + sessionId, + mcpReq: { + id: requestId, + method: args.request.method, + _meta: args.request.params?._meta, + signal: args.abortController.signal, + send: async () => ({}) as never + }, + http: { + authInfo: undefined + }, + task: undefined, + notification: { + send: async (notification: Notification) => { + // Properly delegate to the protocol's notification method with relatedTask metadata + const notificationOptions: { relatedRequestId?: RequestId; relatedTask?: { taskId: string } } = { + relatedRequestId: requestId + }; + // Extract relatedTask from the original request's _meta if present + if (args.relatedTaskId) { + notificationOptions.relatedTask = { taskId: args.relatedTaskId }; + } + await this.notification(notification, notificationOptions); + } + } + }; + } + + constructor() { + super(options); + } + }; +} + describe('protocol tests', () => { let protocol: Protocol; let transport: MockTransport; @@ -155,13 +221,7 @@ describe('protocol tests', () => { beforeEach(() => { transport = new MockTransport(); sendSpy = vi.spyOn(transport, 'send'); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + protocol = new (createTestProtocolClass())(); }); test('should throw a timeout error if the request exceeds the timeout', async () => { @@ -629,13 +689,7 @@ describe('protocol tests', () => { it('should NOT debounce a notification that has parameters', async () => { // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced_with_params'] }); + protocol = new (createTestProtocolClass({ debouncedNotificationMethods: ['test/debounced_with_params'] }))(); await protocol.connect(transport); // ACT @@ -652,13 +706,7 @@ describe('protocol tests', () => { it('should NOT debounce a notification that has a relatedRequestId', async () => { // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced_with_options'] }); + protocol = new (createTestProtocolClass({ debouncedNotificationMethods: ['test/debounced_with_options'] }))(); await protocol.connect(transport); // ACT @@ -673,13 +721,7 @@ describe('protocol tests', () => { it('should clear pending debounced notifications on connection close', async () => { // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced'] }); + protocol = new (createTestProtocolClass({ debouncedNotificationMethods: ['test/debounced'] }))(); await protocol.connect(transport); // ACT @@ -699,13 +741,7 @@ describe('protocol tests', () => { it('should debounce multiple synchronous calls when params property is omitted', async () => { // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced'] }); + protocol = new (createTestProtocolClass({ debouncedNotificationMethods: ['test/debounced'] }))(); await protocol.connect(transport); // ACT @@ -728,13 +764,7 @@ describe('protocol tests', () => { it('should debounce calls when params is explicitly undefined', async () => { // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced'] }); + protocol = new (createTestProtocolClass({ debouncedNotificationMethods: ['test/debounced'] }))(); await protocol.connect(transport); // ACT @@ -755,13 +785,7 @@ describe('protocol tests', () => { it('should send non-debounced notifications immediately and multiple times', async () => { // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced'] }); // Configure for a different method + protocol = new (createTestProtocolClass({ debouncedNotificationMethods: ['test/debounced'] }))(); // Configure for a different method await protocol.connect(transport); // ACT @@ -790,13 +814,7 @@ describe('protocol tests', () => { it('should handle sequential batches of debounced notifications correctly', async () => { // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced'] }); + protocol = new (createTestProtocolClass({ debouncedNotificationMethods: ['test/debounced'] }))(); await protocol.connect(transport); // ACT (Batch 1) @@ -1009,13 +1027,7 @@ describe('Task-based execution', () => { beforeEach(() => { transport = new MockTransport(); sendSpy = vi.spyOn(transport, 'send'); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: createMockTaskStore(), taskMessageQueue: new InMemoryTaskMessageQueue() }); + protocol = new (createTestProtocolClass({ taskStore: createMockTaskStore(), taskMessageQueue: new InMemoryTaskMessageQueue() }))(); }); describe('request with task metadata', () => { @@ -1034,7 +1046,7 @@ describe('Task-based execution', () => { void protocol .request(request, resultSchema, { task: { - ttl: 30000, + ttl: 30_000, pollInterval: 1000 } }) @@ -1048,7 +1060,7 @@ describe('Task-based execution', () => { params: { name: 'test-tool', task: { - ttl: 30000, + ttl: 30_000, pollInterval: 1000 } } @@ -1077,7 +1089,7 @@ describe('Task-based execution', () => { void protocol .request(request, resultSchema, { task: { - ttl: 60000 + ttl: 60_000 } }) .catch(() => { @@ -1092,7 +1104,7 @@ describe('Task-based execution', () => { customField: 'customValue' }, task: { - ttl: 60000 + ttl: 60_000 } } }), @@ -1114,7 +1126,7 @@ describe('Task-based execution', () => { const resultPromise = protocol.request(request, resultSchema, { task: { - ttl: 30000 + ttl: 30_000 } }); @@ -1204,7 +1216,7 @@ describe('Task-based execution', () => { void protocol .request(request, resultSchema, { task: { - ttl: 60000, + ttl: 60_000, pollInterval: 1000 }, relatedTask: { @@ -1232,7 +1244,7 @@ describe('Task-based execution', () => { expect(queuedMessage.message.params).toMatchObject({ name: 'test-tool', task: { - ttl: 60000, + ttl: 60_000, pollInterval: 1000 }, _meta: { @@ -1256,20 +1268,14 @@ describe('Task-based execution', () => { // rather than in _meta, and that task management is handled by tool implementors const mockTaskStore = createMockTaskStore(); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + protocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); await protocol.connect(transport); protocol.setRequestHandler('tools/call', async request => { // Tool implementor can access task creation parameters from request.params.task expect(request.params.task).toEqual({ - ttl: 60000, + ttl: 60_000, pollInterval: 1000 }); return { content: [{ type: 'text', text: 'success' }] }; @@ -1283,7 +1289,7 @@ describe('Task-based execution', () => { name: 'test', arguments: {}, task: { - ttl: 60000, + ttl: 60_000, pollInterval: 1000 } } @@ -1315,7 +1321,7 @@ describe('Task-based execution', () => { const task2 = await mockTaskStore.createTask( { - ttl: 60000, + ttl: 60_000, pollInterval: 1000 }, 2, @@ -1325,13 +1331,7 @@ describe('Task-based execution', () => { } ); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + protocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); await protocol.connect(transport); @@ -1361,7 +1361,7 @@ describe('Task-based execution', () => { { taskId: task2.taskId, status: 'working', - ttl: 60000, + ttl: 60_000, createdAt: expect.any(String), lastUpdatedAt: expect.any(String), pollInterval: 1000 @@ -1386,13 +1386,7 @@ describe('Task-based execution', () => { } ); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + protocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); await protocol.connect(transport); @@ -1432,13 +1426,7 @@ describe('Task-based execution', () => { onList: () => listedTasks.releaseLatch() }); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + protocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); await protocol.connect(transport); @@ -1465,13 +1453,7 @@ describe('Task-based execution', () => { const mockTaskStore = createMockTaskStore(); mockTaskStore.listTasks.mockRejectedValue(new Error('Invalid cursor: bad-cursor')); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + protocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); await protocol.connect(transport); @@ -1492,7 +1474,7 @@ describe('Task-based execution', () => { expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(4); expect(sentMessage.error).toBeDefined(); - expect(sentMessage.error.code).toBe(-32602); // InvalidParams error code + expect(sentMessage.error.code).toBe(-32_602); // InvalidParams error code expect(sentMessage.error.message).toContain('Failed to list tasks'); expect(sentMessage.error.message).toContain('Invalid cursor'); }); @@ -1552,7 +1534,7 @@ describe('Task-based execution', () => { { taskId: 'task-11', status: 'working', - ttl: 30000, + ttl: 30_000, createdAt: '2024-01-01T00:00:00Z', lastUpdatedAt: '2024-01-01T00:00:00Z', pollInterval: 1000 @@ -1599,13 +1581,7 @@ describe('Task-based execution', () => { throw new Error('Task not found'); }); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1641,13 +1617,7 @@ describe('Task-based execution', () => { mockTaskStore.getTask.mockResolvedValue(null); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1671,7 +1641,7 @@ describe('Task-based execution', () => { expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(6); expect(sentMessage.error).toBeDefined(); - expect(sentMessage.error.code).toBe(-32602); // InvalidParams error code + expect(sentMessage.error.code).toBe(-32_602); // InvalidParams error code expect(sentMessage.error.message).toContain('Task not found'); }); @@ -1689,13 +1659,7 @@ describe('Task-based execution', () => { mockTaskStore.updateTaskStatus.mockClear(); mockTaskStore.getTask.mockResolvedValue(completedTask); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1719,7 +1683,7 @@ describe('Task-based execution', () => { expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(7); expect(sentMessage.error).toBeDefined(); - expect(sentMessage.error.code).toBe(-32602); // InvalidParams error code + expect(sentMessage.error.code).toBe(-32_602); // InvalidParams error code expect(sentMessage.error.message).toContain('Cannot cancel task in terminal status'); }); @@ -1737,7 +1701,7 @@ describe('Task-based execution', () => { _meta: {}, taskId: 'task-to-delete', status: 'cancelled', - ttl: 60000, + ttl: 60_000, createdAt: new Date().toISOString(), lastUpdatedAt: new Date().toISOString() } @@ -1771,13 +1735,7 @@ describe('Task-based execution', () => { params: {} }); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); const serverTransport = new MockTransport(); await serverProtocol.connect(serverTransport); @@ -1806,7 +1764,7 @@ describe('Task-based execution', () => { // Verify that getTask was called after updateTaskStatus // This is done by the RequestTaskStore wrapper to get the updated task for the notification const getTaskCalls = mockTaskStore.getTask.mock.calls; - const lastGetTaskCall = getTaskCalls[getTaskCalls.length - 1]; + const lastGetTaskCall = getTaskCalls.at(-1); expect(lastGetTaskCall?.[0]).toBe(task.taskId); }); }); @@ -1821,13 +1779,7 @@ describe('Task-based execution', () => { params: {} }); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1870,13 +1822,7 @@ describe('Task-based execution', () => { params: {} }); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1907,13 +1853,7 @@ describe('Task-based execution', () => { params: {} }); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1952,13 +1892,7 @@ describe('Task-based execution', () => { await mockTaskStore.storeTaskResult(task.taskId, 'completed', testResult); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1995,13 +1929,10 @@ describe('Task-based execution', () => { it('should propagate related-task metadata to handler sendRequest and sendNotification', async () => { const mockTaskStore = createMockTaskStore(); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const serverProtocol = new (createTestProtocolClass({ + taskStore: mockTaskStore, + taskMessageQueue: new InMemoryTaskMessageQueue() + }))(); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -2009,9 +1940,9 @@ describe('Task-based execution', () => { await serverProtocol.connect(serverTransport); // Set up a handler that uses sendRequest and sendNotification - serverProtocol.setRequestHandler('tools/call', async (_request, extra) => { - // Send a notification using the extra.sendNotification - await extra.sendNotification({ + serverProtocol.setRequestHandler('tools/call', async (_request, ctx) => { + // Send a notification using the ctx.notification.send + await ctx.notification.send({ method: 'notifications/message', params: { level: 'info', data: 'test' } }); @@ -2078,13 +2009,7 @@ describe('Request Cancellation vs Task Cancellation', () => { beforeEach(() => { transport = new MockTransport(); taskStore = createMockTaskStore(); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + protocol = new (createTestProtocolClass({ taskStore }))(); }); describe('notifications/cancelled behavior', () => { @@ -2093,11 +2018,11 @@ describe('Request Cancellation vs Task Cancellation', () => { // Set up a request handler that checks if it was aborted let wasAborted = false; - protocol.setRequestHandler('ping', async (_request, extra) => { + protocol.setRequestHandler('ping', async (_request, ctx) => { // Simulate a long-running operation await new Promise(resolve => setTimeout(resolve, 100)); - wasAborted = extra.signal.aborted; - return {}; + wasAborted = ctx.mcpReq.signal.aborted; + return { _meta: {} } as Result; }); // Simulate an incoming request @@ -2137,7 +2062,7 @@ describe('Request Cancellation vs Task Cancellation', () => { await protocol.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { + const task = await taskStore.createTask({ ttl: 60_000 }, 'req-1', { method: 'test/method', params: {} }); @@ -2169,7 +2094,7 @@ describe('Request Cancellation vs Task Cancellation', () => { await protocol.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { + const task = await taskStore.createTask({ ttl: 60_000 }, 'req-1', { method: 'test/method', params: {} }); @@ -2203,7 +2128,7 @@ describe('Request Cancellation vs Task Cancellation', () => { const sendSpy = vi.spyOn(transport, 'send'); // Create a task and mark it as completed - const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { + const task = await taskStore.createTask({ ttl: 60_000 }, 'req-1', { method: 'test/method', params: {} }); @@ -2275,7 +2200,7 @@ describe('Request Cancellation vs Task Cancellation', () => { await protocol.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { + const task = await taskStore.createTask({ ttl: 60_000 }, 'req-1', { method: 'test/method', params: {} }); @@ -2311,7 +2236,7 @@ describe('Request Cancellation vs Task Cancellation', () => { }); // Create a task (simulating a long-running tools/call) - const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { + const task = await taskStore.createTask({ ttl: 60_000 }, 'req-1', { method: 'tools/call', params: { name: 'long-running-tool', arguments: {} } }); @@ -2363,24 +2288,12 @@ describe('Progress notification support for tasks', () => { beforeEach(() => { transport = new MockTransport(); sendSpy = vi.spyOn(transport, 'send'); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + protocol = new (createTestProtocolClass())(); }); it('should maintain progress token association after CreateTaskResult is returned', async () => { const taskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + const protocol = new (createTestProtocolClass({ taskStore }))(); const transport = new MockTransport(); const sendSpy = vi.spyOn(transport, 'send'); @@ -2404,7 +2317,7 @@ describe('Progress notification support for tasks', () => { // Start a task-augmented request with progress callback void protocol .request(request, resultSchema, { - task: { ttl: 60000 }, + task: { ttl: 60_000 }, onprogress: progressCallback }) .catch(() => { @@ -2431,7 +2344,7 @@ describe('Progress notification support for tasks', () => { task: { taskId, status: 'working', - ttl: 60000, + ttl: 60_000, createdAt: new Date().toISOString() } } @@ -2467,26 +2380,20 @@ describe('Progress notification support for tasks', () => { it('should stop progress notifications when task reaches terminal status (completed)', async () => { const taskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + const protocol = new (createTestProtocolClass({ taskStore }))(); const transport = new MockTransport(); const sendSpy = vi.spyOn(transport, 'send'); await protocol.connect(transport); // Set up a request handler that will complete the task - protocol.setRequestHandler('tools/call', async (_request, extra) => { - if (extra.taskStore) { - const task = await extra.taskStore.createTask({ ttl: 60000 }); + protocol.setRequestHandler('tools/call', async (_request, ctx) => { + if (ctx.task?.store) { + const task = await ctx.task.store.createTask({ ttl: 60_000 }); // Simulate async work then complete the task setTimeout(async () => { - await extra.taskStore!.storeTaskResult(task.taskId, 'completed', { + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text', text: 'Done' }] }); }, 50); @@ -2514,7 +2421,7 @@ describe('Progress notification support for tasks', () => { // Start a task-augmented request with progress callback void protocol .request(request, resultSchema, { - task: { ttl: 60000 }, + task: { ttl: 60_000 }, onprogress: progressCallback }) .catch(() => { @@ -2529,7 +2436,7 @@ describe('Progress notification support for tasks', () => { const progressToken = sentRequest.params._meta.progressToken; // Create a task in the mock store first so it exists when we try to get it later - const createdTask = await taskStore.createTask({ ttl: 60000 }, messageId, request); + const createdTask = await taskStore.createTask({ ttl: 60_000 }, messageId, request); const taskId = createdTask.taskId; // Simulate CreateTaskResult response @@ -2568,11 +2475,12 @@ describe('Progress notification support for tasks', () => { expect(taskProgressTokens.has(taskId)).toBe(true); expect(taskProgressTokens.get(taskId)).toBe(progressToken); - // Simulate task completion by calling through the protocol's task store - // This will trigger the cleanup logic - const mockRequest = { jsonrpc: '2.0' as const, id: 999, method: 'test', params: {} }; - const requestTaskStore = (protocol as unknown as TestProtocol).requestTaskStore(mockRequest, undefined); - await requestTaskStore.storeTaskResult(taskId, 'completed', { content: [] }); + // Simulate task completion by updating task store and triggering cleanup + // First update the task status in the store + await taskStore.storeTaskResult(taskId, 'completed', { content: [] }); + + // Then manually trigger the cleanup (in real usage, this happens via tasks/result polling) + (protocol as unknown as TestProtocol)._cleanupTaskProgressHandler(taskId); // Wait for all async operations including notification sending to complete await new Promise(resolve => setTimeout(resolve, 50)); @@ -2602,13 +2510,7 @@ describe('Progress notification support for tasks', () => { it('should stop progress notifications when task reaches terminal status (failed)', async () => { const taskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + const protocol = new (createTestProtocolClass({ taskStore }))(); const transport = new MockTransport(); const sendSpy = vi.spyOn(transport, 'send'); @@ -2630,7 +2532,7 @@ describe('Progress notification support for tasks', () => { }); void protocol.request(request, resultSchema, { - task: { ttl: 60000 }, + task: { ttl: 60_000 }, onprogress: progressCallback }); @@ -2648,7 +2550,7 @@ describe('Progress notification support for tasks', () => { task: { taskId, status: 'working', - ttl: 60000, + ttl: 60_000, createdAt: new Date().toISOString() } } @@ -2671,7 +2573,7 @@ describe('Progress notification support for tasks', () => { params: { taskId, status: 'failed', - ttl: 60000, + ttl: 60_000, createdAt: new Date().toISOString(), lastUpdatedAt: new Date().toISOString(), statusMessage: 'Task failed' @@ -2700,13 +2602,7 @@ describe('Progress notification support for tasks', () => { it('should stop progress notifications when task is cancelled', async () => { const taskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + const protocol = new (createTestProtocolClass({ taskStore }))(); const transport = new MockTransport(); const sendSpy = vi.spyOn(transport, 'send'); @@ -2728,7 +2624,7 @@ describe('Progress notification support for tasks', () => { }); void protocol.request(request, resultSchema, { - task: { ttl: 60000 }, + task: { ttl: 60_000 }, onprogress: progressCallback }); @@ -2746,7 +2642,7 @@ describe('Progress notification support for tasks', () => { task: { taskId, status: 'working', - ttl: 60000, + ttl: 60_000, createdAt: new Date().toISOString() } } @@ -2766,7 +2662,7 @@ describe('Progress notification support for tasks', () => { params: { taskId, status: 'cancelled', - ttl: 60000, + ttl: 60_000, createdAt: new Date().toISOString(), lastUpdatedAt: new Date().toISOString(), statusMessage: 'User cancelled' @@ -2795,13 +2691,7 @@ describe('Progress notification support for tasks', () => { it('should use the same progressToken throughout task lifetime', async () => { const taskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + const protocol = new (createTestProtocolClass({ taskStore }))(); const transport = new MockTransport(); const sendSpy = vi.spyOn(transport, 'send'); @@ -2823,7 +2713,7 @@ describe('Progress notification support for tasks', () => { }); void protocol.request(request, resultSchema, { - task: { ttl: 60000 }, + task: { ttl: 60_000 }, onprogress: progressCallback }); @@ -2841,7 +2731,7 @@ describe('Progress notification support for tasks', () => { task: { taskId, status: 'working', - ttl: 60000, + ttl: 60_000, createdAt: new Date().toISOString() } } @@ -2895,7 +2785,7 @@ describe('Progress notification support for tasks', () => { void protocol.request(request, resultSchema, { task: { - ttl: 60000 + ttl: 60_000 }, onprogress: onProgressMock }); @@ -2920,7 +2810,7 @@ describe('Progress notification support for tasks', () => { void protocol.request(request, resultSchema, { task: { - ttl: 30000 + ttl: 30_000 }, onprogress: onProgressMock }); @@ -2970,7 +2860,7 @@ describe('Progress notification support for tasks', () => { void protocol.request(request, resultSchema, { task: { - ttl: 30000 + ttl: 30_000 }, onprogress: onProgressMock }); @@ -2987,7 +2877,7 @@ describe('Progress notification support for tasks', () => { task: { taskId: 'task-123', status: 'working', - ttl: 30000, + ttl: 30_000, createdAt: new Date().toISOString() } } @@ -3064,18 +2954,12 @@ describe('Message interception for task-related notifications', () => { it('should queue notifications with io.modelcontextprotocol/related-task metadata', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = new (createTestProtocolClass({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }))(); await server.connect(transport); // Create a task first - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 'test-request-1', { method: 'tools/call', params: {} }); // Send a notification with related task metadata await server.notification( @@ -3101,13 +2985,7 @@ describe('Message interception for task-related notifications', () => { it('should not queue notifications without related-task metadata', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = new (createTestProtocolClass({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }))(); await server.connect(transport); @@ -3128,18 +3006,16 @@ describe('Message interception for task-related notifications', () => { it('should propagate queue overflow errors without failing the task', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue(), maxTaskQueueSize: 100 }); + const server = new (createTestProtocolClass({ + taskStore, + taskMessageQueue: new InMemoryTaskMessageQueue(), + maxTaskQueueSize: 100 + }))(); await server.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 'test-request-1', { method: 'tools/call', params: {} }); // Fill the queue to max capacity (100 messages) for (let i = 0; i < 100; i++) { @@ -3175,13 +3051,7 @@ describe('Message interception for task-related notifications', () => { it('should extract task ID correctly from metadata', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = new (createTestProtocolClass({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }))(); await server.connect(transport); @@ -3208,18 +3078,12 @@ describe('Message interception for task-related notifications', () => { it('should preserve message order when queuing multiple notifications', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = new (createTestProtocolClass({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }))(); await server.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 'test-request-1', { method: 'tools/call', params: {} }); // Send multiple notifications for (let i = 0; i < 5; i++) { @@ -3250,18 +3114,12 @@ describe('Message interception for task-related requests', () => { it('should queue requests with io.modelcontextprotocol/related-task metadata', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = new (createTestProtocolClass({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }))(); await server.connect(transport); // Create a task first - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 'test-request-1', { method: 'tools/call', params: {} }); // Send a request with related task metadata (don't await - we're testing queuing) const requestPromise = server.request( @@ -3302,13 +3160,7 @@ describe('Message interception for task-related requests', () => { it('should not queue requests without related-task metadata', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = new (createTestProtocolClass({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }))(); await server.connect(transport); @@ -3341,18 +3193,12 @@ describe('Message interception for task-related requests', () => { it('should store request resolver for response routing', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = new (createTestProtocolClass({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }))(); await server.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 'test-request-1', { method: 'tools/call', params: {} }); // Send a request with related task metadata const requestPromise = server.request( @@ -3394,18 +3240,12 @@ describe('Message interception for task-related requests', () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); const queue = new InMemoryTaskMessageQueue(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: queue }); + const server = new (createTestProtocolClass({ taskStore, taskMessageQueue: queue }))(); await server.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 'test-request-1', { method: 'tools/call', params: {} }); // Send a request with related task metadata const requestPromise = server.request( @@ -3459,13 +3299,7 @@ describe('Message interception for task-related requests', () => { it('should log error when resolver is missing for side-channeled request', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = new (createTestProtocolClass({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }))(); const errors: Error[] = []; server.onerror = (error: Error) => { @@ -3475,7 +3309,7 @@ describe('Message interception for task-related requests', () => { await server.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 'test-request-1', { method: 'tools/call', params: {} }); // Send a request with related task metadata void server.request( @@ -3535,18 +3369,16 @@ describe('Message interception for task-related requests', () => { it('should propagate queue overflow errors for requests without failing the task', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue(), maxTaskQueueSize: 100 }); + const server = new (createTestProtocolClass({ + taskStore, + taskMessageQueue: new InMemoryTaskMessageQueue(), + maxTaskQueueSize: 100 + }))(); await server.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 'test-request-1', { method: 'tools/call', params: {} }); // Fill the queue to max capacity (100 messages) const promises: Promise[] = []; @@ -3596,13 +3428,7 @@ describe('Message Interception', () => { beforeEach(() => { transport = new MockTransport(); mockTaskStore = createMockTaskStore(); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + protocol = new (createTestProtocolClass({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }))(); }); describe('messages with relatedTask metadata are queued', () => { @@ -4114,13 +3940,7 @@ describe('Queue lifecycle management', () => { beforeEach(() => { transport = new MockTransport(); mockTaskStore = createMockTaskStore(); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + protocol = new (createTestProtocolClass({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }))(); }); describe('queue cleanup on task completion', () => { @@ -4238,7 +4058,7 @@ describe('Queue lifecycle management', () => { .request({ method: 'test/request', params: { data: 'test' } }, z.object({ result: z.string() }), { relatedTask: { taskId } }) - .catch(err => err); + .catch(error => error); // Verify request is queued const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; @@ -4311,7 +4131,7 @@ describe('Queue lifecycle management', () => { .request({ method: 'test/request', params: { data: 'test' } }, z.object({ result: z.string() }), { relatedTask: { taskId } }) - .catch(err => err); + .catch(error => error); // Verify request is queued const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; @@ -4344,19 +4164,19 @@ describe('Queue lifecycle management', () => { .request({ method: 'test/request1', params: { data: 'test1' } }, z.object({ result: z.string() }), { relatedTask: { taskId } }) - .catch(err => err); + .catch(error => error); const request2Promise = protocol .request({ method: 'test/request2', params: { data: 'test2' } }, z.object({ result: z.string() }), { relatedTask: { taskId } }) - .catch(err => err); + .catch(error => error); const request3Promise = protocol .request({ method: 'test/request3', params: { data: 'test3' } }, z.object({ result: z.string() }), { relatedTask: { taskId } }) - .catch(err => err); + .catch(error => error); // Verify requests are queued const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; @@ -4394,7 +4214,7 @@ describe('Queue lifecycle management', () => { .request({ method: 'test/request', params: { data: 'test' } }, z.object({ result: z.string() }), { relatedTask: { taskId } }) - .catch(err => err); + .catch(error => error); // Get the request ID that was sent const requestResolvers = (protocol as unknown as TestProtocol)._requestResolvers; @@ -4428,13 +4248,7 @@ describe('requestStream() method', () => { test('should yield result immediately for non-task requests', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = new (createTestProtocolClass())(); await protocol.connect(transport); // Start the request stream @@ -4471,13 +4285,7 @@ describe('requestStream() method', () => { test('should yield error message on request failure', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = new (createTestProtocolClass())(); await protocol.connect(transport); // Start the request stream @@ -4517,13 +4325,7 @@ describe('requestStream() method', () => { test('should handle cancellation via AbortSignal', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = new (createTestProtocolClass())(); await protocol.connect(transport); const abortController = new AbortController(); @@ -4555,13 +4357,7 @@ describe('requestStream() method', () => { describe('Error responses', () => { test('should yield error as terminal message for server error response', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = new (createTestProtocolClass())(); await protocol.connect(transport); const messagesPromise = toArrayAsync( @@ -4587,7 +4383,7 @@ describe('requestStream() method', () => { // Verify error is terminal and last message expect(messages.length).toBeGreaterThan(0); - const lastMessage = messages[messages.length - 1]; + const lastMessage = messages.at(-1); assertErrorResponse(lastMessage!); expect(lastMessage.error).toBeDefined(); expect(lastMessage.error.message).toContain('Server error'); @@ -4597,13 +4393,7 @@ describe('requestStream() method', () => { vi.useFakeTimers(); try { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = new (createTestProtocolClass())(); await protocol.connect(transport); const messagesPromise = toArrayAsync( @@ -4624,7 +4414,7 @@ describe('requestStream() method', () => { // Verify error is terminal and last message expect(messages.length).toBeGreaterThan(0); - const lastMessage = messages[messages.length - 1]; + const lastMessage = messages.at(-1); assertErrorResponse(lastMessage!); expect(lastMessage.error).toBeDefined(); expect(lastMessage.error.code).toBe(ErrorCode.RequestTimeout); @@ -4635,13 +4425,7 @@ describe('requestStream() method', () => { test('should yield error as terminal message for cancellation', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = new (createTestProtocolClass())(); await protocol.connect(transport); const abortController = new AbortController(); @@ -4660,7 +4444,7 @@ describe('requestStream() method', () => { // Verify error is terminal and last message expect(messages.length).toBeGreaterThan(0); - const lastMessage = messages[messages.length - 1]; + const lastMessage = messages.at(-1); assertErrorResponse(lastMessage!); expect(lastMessage.error).toBeDefined(); expect(lastMessage.error.message).toContain('cancelled'); @@ -4668,13 +4452,7 @@ describe('requestStream() method', () => { test('should not yield any messages after error message', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = new (createTestProtocolClass())(); await protocol.connect(transport); const messagesPromise = toArrayAsync( @@ -4720,13 +4498,7 @@ describe('requestStream() method', () => { test('should yield error as terminal message for task failure', async () => { const transport = new MockTransport(); const mockTaskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const protocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); await protocol.connect(transport); const messagesPromise = toArrayAsync( @@ -4773,20 +4545,14 @@ describe('requestStream() method', () => { // Verify error is terminal and last message expect(messages.length).toBeGreaterThan(0); - const lastMessage = messages[messages.length - 1]; + const lastMessage = messages.at(-1); assertErrorResponse(lastMessage!); expect(lastMessage.error).toBeDefined(); }); test('should yield error as terminal message for network error', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = new (createTestProtocolClass())(); await protocol.connect(transport); // Override send to simulate network error @@ -4801,20 +4567,14 @@ describe('requestStream() method', () => { // Verify error is terminal and last message expect(messages.length).toBeGreaterThan(0); - const lastMessage = messages[messages.length - 1]; + const lastMessage = messages.at(-1); assertErrorResponse(lastMessage!); expect(lastMessage.error).toBeDefined(); }); test('should ensure error is always the final message', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = new (createTestProtocolClass())(); await protocol.connect(transport); const messagesPromise = toArrayAsync( @@ -4840,7 +4600,7 @@ describe('requestStream() method', () => { // Verify error is the last message expect(messages.length).toBeGreaterThan(0); - const lastMessage = messages[messages.length - 1]; + const lastMessage = messages.at(-1); expect(lastMessage?.type).toBe('error'); // Verify all messages before the last are not terminal @@ -4864,17 +4624,11 @@ describe('Error handling for missing resolvers', () => { taskMessageQueue = new InMemoryTaskMessageQueue(); errorHandler = vi.fn(); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ + protocol = new (createTestProtocolClass({ taskStore, taskMessageQueue, defaultTaskPollInterval: 100 - }); + }))(); // @ts-expect-error deliberately overriding error handler with mock protocol.onerror = errorHandler; @@ -4886,7 +4640,7 @@ describe('Error handling for missing resolvers', () => { await protocol.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); // Enqueue a response message without a corresponding resolver await taskMessageQueue.enqueue(task.taskId, { @@ -4931,7 +4685,7 @@ describe('Error handling for missing resolvers', () => { await protocol.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); // Enqueue a response with missing resolver, then a valid notification await taskMessageQueue.enqueue(task.taskId, { @@ -4972,7 +4726,7 @@ describe('Error handling for missing resolvers', () => { await protocol.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); // Enqueue a request without storing a resolver await taskMessageQueue.enqueue(task.taskId, { @@ -5002,7 +4756,7 @@ describe('Error handling for missing resolvers', () => { await protocol.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); const requestId = 42; const resolverMock = vi.fn(); @@ -5042,7 +4796,7 @@ describe('Error handling for missing resolvers', () => { await protocol.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); const testProtocol = protocol as unknown as TestProtocol; @@ -5146,7 +4900,7 @@ describe('Error handling for missing resolvers', () => { it('should not throw when processing response with missing resolver', async () => { await protocol.connect(transport); - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); await taskMessageQueue.enqueue(task.taskId, { type: 'response', @@ -5178,7 +4932,7 @@ describe('Error handling for missing resolvers', () => { it('should not throw during task cleanup with missing resolvers', async () => { await protocol.connect(transport); - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); await taskMessageQueue.enqueue(task.taskId, { type: 'request', @@ -5202,7 +4956,7 @@ describe('Error handling for missing resolvers', () => { it('should route error messages to resolvers correctly', async () => { await protocol.connect(transport); - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); const requestId = 42; const resolverMock = vi.fn(); @@ -5255,7 +5009,7 @@ describe('Error handling for missing resolvers', () => { it('should log error for unknown request ID in error messages', async () => { await protocol.connect(transport); - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); // Enqueue an error message without a corresponding resolver await taskMessageQueue.enqueue(task.taskId, { @@ -5299,7 +5053,7 @@ describe('Error handling for missing resolvers', () => { it('should handle error messages with data field', async () => { await protocol.connect(transport); - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); const requestId = 42; const resolverMock = vi.fn(); @@ -5348,7 +5102,7 @@ describe('Error handling for missing resolvers', () => { it('should not throw when processing error with missing resolver', async () => { await protocol.connect(transport); - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); await taskMessageQueue.enqueue(task.taskId, { type: 'error', @@ -5385,7 +5139,7 @@ describe('Error handling for missing resolvers', () => { it('should handle mixed response and error messages in queue', async () => { await protocol.connect(transport); - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); const testProtocol = protocol as unknown as TestProtocol; // Set up resolvers for multiple requests @@ -5471,7 +5225,7 @@ describe('Error handling for missing resolvers', () => { it('should maintain FIFO order when processing responses and errors', async () => { await protocol.connect(transport); - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); const testProtocol = protocol as unknown as TestProtocol; const callOrder: number[] = []; @@ -5495,7 +5249,7 @@ describe('Error handling for missing resolvers', () => { message: { jsonrpc: '2.0', id: 2, - error: { code: -32600, message: 'Error' } + error: { code: -32_600, message: 'Error' } }, timestamp: 2000 }); diff --git a/packages/core/test/shared/protocolTransportHandling.test.ts b/packages/core/test/shared/protocolTransportHandling.test.ts index b76c9a6df..624ce8f12 100644 --- a/packages/core/test/shared/protocolTransportHandling.test.ts +++ b/packages/core/test/shared/protocolTransportHandling.test.ts @@ -1,8 +1,19 @@ import { beforeEach, describe, expect, test } from 'vitest'; +import type { TaskStore } from '../../src/experimental/tasks/interfaces.js'; +import type { BaseContext } from '../../src/shared/context.js'; import { Protocol } from '../../src/shared/protocol.js'; import type { Transport } from '../../src/shared/transport.js'; -import type { EmptyResult, JSONRPCMessage, Notification, Request, Result } from '../../src/types/types.js'; +import type { + EmptyResult, + JSONRPCMessage, + JSONRPCRequest, + Notification, + Request, + Result, + TaskCreationParams +} from '../../src/types/types.js'; +import type { MessageExtraInfo } from '../../src/types/utility.js'; // Mock Transport class class MockTransport implements Transport { @@ -27,19 +38,68 @@ class MockTransport implements Transport { } } +/** + * Creates a mock BaseContext for testing. + */ +function createMockContext(args: { + request: JSONRPCRequest; + abortController: AbortController; + sessionId?: string; +}): BaseContext { + return { + sessionId: args.sessionId, + mcpReq: { + id: args.request.id, + method: args.request.method, + _meta: args.request.params?._meta, + signal: args.abortController.signal, + send: async () => ({}) as never + }, + http: { + authInfo: undefined + }, + task: undefined, + notification: { + send: async () => {} + } + }; +} + +/** + * Creates a test Protocol class with all abstract methods implemented. + */ +function createTestProtocolClass() { + return class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + protected createRequestContext(args: { + request: JSONRPCRequest; + taskStore: TaskStore | undefined; + relatedTaskId: string | undefined; + taskCreationParams: TaskCreationParams | undefined; + abortController: AbortController; + capturedTransport: Transport | undefined; + extra?: MessageExtraInfo; + }): BaseContext { + return createMockContext({ + request: args.request, + abortController: args.abortController, + sessionId: args.capturedTransport?.sessionId + }); + } + }; +} + describe('Protocol transport handling bug', () => { let protocol: Protocol; let transportA: MockTransport; let transportB: MockTransport; beforeEach(() => { - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + protocol = new (createTestProtocolClass())(); transportA = new MockTransport('A'); transportB = new MockTransport('B'); @@ -86,8 +146,8 @@ describe('Protocol transport handling bug', () => { const results: { transport: string; response: JSONRPCMessage[] }[] = []; // Set up handler with variable delay based on request id - protocol.setRequestHandler('ping', async (_request, extra) => { - const delay = extra.requestId === 1 ? 50 : 10; + protocol.setRequestHandler('ping', async (_request, ctx) => { + const delay = ctx.mcpReq.id === 1 ? 50 : 10; await new Promise(resolve => setTimeout(resolve, delay)); return {}; }); diff --git a/packages/middleware/node/test/streamableHttp.test.ts b/packages/middleware/node/test/streamableHttp.test.ts index ca7728d88..ce6bb653e 100644 --- a/packages/middleware/node/test/streamableHttp.test.ts +++ b/packages/middleware/node/test/streamableHttp.test.ts @@ -12,7 +12,7 @@ import type { JSONRPCResultResponse, RequestId } from '@modelcontextprotocol/core'; -import type { EventId, EventStore, StreamId } from '@modelcontextprotocol/server'; +import type { EventId, EventStore, ServerContext, ServerNotification, ServerRequest, StreamId } from '@modelcontextprotocol/server'; import { McpServer } from '@modelcontextprotocol/server'; import type { ZodMatrixEntry } from '@modelcontextprotocol/test-helpers'; import { listenOnRandomPort, zodTestMatrix } from '@modelcontextprotocol/test-helpers'; @@ -214,8 +214,10 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { description: 'A user profile data tool', inputSchema: { active: z.boolean().describe('Profile status') } }, - async ({ active }, { authInfo }): Promise => { - return { content: [{ type: 'text', text: `${active ? 'Active' : 'Inactive'} profile from token: ${authInfo?.token}!` }] }; + async ({ active }, ctx): Promise => { + return { + content: [{ type: 'text', text: `${active ? 'Active' : 'Inactive'} profile from token: ${ctx.http?.authInfo?.token}!` }] + }; } ); @@ -404,11 +406,12 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { description: 'A simple test tool with request info', inputSchema: { name: z.string().describe('Name to greet') } }, - async ({ name }, { requestInfo }): Promise => { + async ({ name }, ctx): Promise => { // Convert Headers object to plain object for JSON serialization // Headers is a Web API class that doesn't serialize with JSON.stringify + const serverCtx = ctx as ServerContext; const serializedRequestInfo = { - headers: Object.fromEntries(requestInfo?.headers ?? new Headers()) + headers: Object.fromEntries(serverCtx.http?.req.headers ?? new Headers()) }; return { content: [ @@ -1851,9 +1854,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); // Register a tool that closes its own SSE stream via extra callback - mcpServer.registerTool('close-stream-tool', { description: 'Closes its own stream' }, async extra => { + mcpServer.registerTool('close-stream-tool', { description: 'Closes its own stream' }, async ctx => { // Close the SSE stream for this request - extra.closeSSEStream?.(); + ctx.http?.closeSSE?.(); streamCloseCalled = true; // Wait before returning so we can observe the stream closure @@ -1918,9 +1921,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { // Track whether closeSSEStream callback was provided let receivedCloseSSEStream: (() => void) | undefined; - // Register a tool that captures the extra.closeSSEStream callback - mcpServer.registerTool('test-callback-tool', { description: 'Test tool' }, async extra => { - receivedCloseSSEStream = extra.closeSSEStream; + // Register a tool that captures the ctx.http?.closeSSE callback + mcpServer.registerTool('test-callback-tool', { description: 'Test tool' }, async ctx => { + receivedCloseSSEStream = ctx.http?.closeSSE; return { content: [{ type: 'text', text: 'Done' }] }; }); @@ -1977,10 +1980,10 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { let receivedCloseSSEStream: (() => void) | undefined; let receivedCloseStandaloneSSEStream: (() => void) | undefined; - // Register a tool that captures the extra.closeSSEStream callback - mcpServer.registerTool('test-old-version-tool', { description: 'Test tool' }, async extra => { - receivedCloseSSEStream = extra.closeSSEStream; - receivedCloseStandaloneSSEStream = extra.closeStandaloneSSEStream; + // Register a tool that captures the ctx.http? callbacks + mcpServer.registerTool('test-old-version-tool', { description: 'Test tool' }, async ctx => { + receivedCloseSSEStream = ctx.http?.closeSSE; + receivedCloseStandaloneSSEStream = ctx.http?.closeStandaloneSSE; return { content: [{ type: 'text', text: 'Done' }] }; }); @@ -2036,9 +2039,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { // Track whether closeSSEStream callback was provided let receivedCloseSSEStream: (() => void) | undefined; - // Register a tool that captures the extra.closeSSEStream callback - mcpServer.registerTool('test-no-callback-tool', { description: 'Test tool' }, async extra => { - receivedCloseSSEStream = extra.closeSSEStream; + // Register a tool that captures the ctx.http?.closeSSE callback + mcpServer.registerTool('test-no-callback-tool', { description: 'Test tool' }, async ctx => { + receivedCloseSSEStream = ctx.http?.closeSSE; return { content: [{ type: 'text', text: 'Done' }] }; }); @@ -2093,9 +2096,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { // Track whether closeStandaloneSSEStream callback was provided let receivedCloseStandaloneSSEStream: (() => void) | undefined; - // Register a tool that captures the extra.closeStandaloneSSEStream callback - mcpServer.registerTool('test-standalone-callback-tool', { description: 'Test tool' }, async extra => { - receivedCloseStandaloneSSEStream = extra.closeStandaloneSSEStream; + // Register a tool that captures the ctx.http?.closeStandaloneSSE callback + mcpServer.registerTool('test-standalone-callback-tool', { description: 'Test tool' }, async ctx => { + receivedCloseStandaloneSSEStream = ctx.http?.closeStandaloneSSE; return { content: [{ type: 'text', text: 'Done' }] }; }); @@ -2148,9 +2151,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { baseUrl = result.baseUrl; mcpServer = result.mcpServer; - // Register a tool that closes the standalone SSE stream via extra callback - mcpServer.registerTool('close-standalone-stream-tool', { description: 'Closes standalone stream' }, async extra => { - extra.closeStandaloneSSEStream?.(); + // Register a tool that closes the standalone SSE stream via ctx callback + mcpServer.registerTool('close-standalone-stream-tool', { description: 'Closes standalone stream' }, async ctx => { + ctx.http?.closeStandaloneSSE?.(); return { content: [{ type: 'text', text: 'Stream closed' }] }; }); @@ -2230,8 +2233,8 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { mcpServer = result.mcpServer; // Register a tool that closes the standalone SSE stream - mcpServer.registerTool('close-standalone-for-reconnect', { description: 'Closes standalone stream' }, async extra => { - extra.closeStandaloneSSEStream?.(); + mcpServer.registerTool('close-standalone-for-reconnect', { description: 'Closes standalone stream' }, async ctx => { + ctx.http?.closeStandaloneSSE?.(); return { content: [{ type: 'text', text: 'Stream closed' }] }; }); diff --git a/packages/server/src/experimental/tasks/interfaces.ts b/packages/server/src/experimental/tasks/interfaces.ts index 0b32be213..fc7aad00c 100644 --- a/packages/server/src/experimental/tasks/interfaces.ts +++ b/packages/server/src/experimental/tasks/interfaces.ts @@ -6,14 +6,15 @@ import type { AnySchema, CallToolResult, - CreateTaskRequestHandlerExtra, CreateTaskResult, GetTaskResult, Result, - TaskRequestHandlerExtra, + ServerNotification, + ServerRequest, ZodRawShapeCompat } from '@modelcontextprotocol/core'; +import type { ServerContext } from '../../server/context.js'; import type { BaseToolCallback } from '../../server/mcp.js'; // ============================================================================ @@ -27,7 +28,7 @@ import type { BaseToolCallback } from '../../server/mcp.js'; export type CreateTaskRequestHandler< SendResultT extends Result, Args extends undefined | ZodRawShapeCompat | AnySchema = undefined -> = BaseToolCallback; +> = BaseToolCallback, Args>; /** * Handler for task operations (get, getResult). @@ -36,7 +37,7 @@ export type CreateTaskRequestHandler< export type TaskRequestHandler< SendResultT extends Result, Args extends undefined | ZodRawShapeCompat | AnySchema = undefined -> = BaseToolCallback; +> = BaseToolCallback, Args>; /** * Interface for task-based tool handlers. diff --git a/packages/server/src/experimental/tasks/mcpServer.ts b/packages/server/src/experimental/tasks/mcpServer.ts index 6fd5a6cc5..c159cce5b 100644 --- a/packages/server/src/experimental/tasks/mcpServer.ts +++ b/packages/server/src/experimental/tasks/mcpServer.ts @@ -55,16 +55,16 @@ export class ExperimentalMcpServerTasks { * inputSchema: { input: z.string() }, * execution: { taskSupport: 'required' } * }, { - * createTask: async (args, extra) => { - * const task = await extra.taskStore.createTask({ ttl: 300000 }); + * createTask: async (args, ctx) => { + * const task = await ctx.task!.store.createTask({ ttl: 300000 }); * startBackgroundWork(task.taskId, args); * return { task }; * }, - * getTask: async (args, extra) => { - * return extra.taskStore.getTask(extra.taskId); + * getTask: async (args, ctx) => { + * return ctx.task!.store.getTask(ctx.task!.id); * }, - * getTaskResult: async (args, extra) => { - * return extra.taskStore.getTaskResult(extra.taskId); + * getTaskResult: async (args, ctx) => { + * return ctx.task!.store.getTaskResult(ctx.task!.id); * } * }); * ``` diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts index 1a8dbf143..b99333e2f 100644 --- a/packages/server/src/index.ts +++ b/packages/server/src/index.ts @@ -1,4 +1,5 @@ export * from './server/completable.js'; +export * from './server/context.js'; export * from './server/mcp.js'; export * from './server/middleware/hostHeaderValidation.js'; export * from './server/server.js'; diff --git a/packages/server/src/server/context.ts b/packages/server/src/server/context.ts new file mode 100644 index 000000000..861c73483 --- /dev/null +++ b/packages/server/src/server/context.ts @@ -0,0 +1,86 @@ +import type { + BaseContext, + CreateMessageRequest, + CreateMessageResult, + ElicitRequest, + ElicitResult, + LoggingMessageNotification, + Notification, + Request, + RequestOptions, + ServerNotification, + ServerRequest +} from '@modelcontextprotocol/core'; + +/** + * Server-specific context type for request handlers. + * Extends BaseContext with server-specific methods for logging, elicitation, and sampling. + * + * @typeParam RequestT - Additional request types beyond ServerRequest + * @typeParam NotificationT - Additional notification types beyond ServerNotification + */ +export type ServerContext = Omit< + BaseContext, + 'mcpReq' | 'http' | 'notification' +> & { + /** + * MCP request context containing protocol-level information and server-specific methods. + */ + mcpReq: BaseContext['mcpReq'] & { + /** + * Sends an elicitation request to the client. + */ + elicitInput: (params: ElicitRequest['params'], options?: RequestOptions) => Promise; + /** + * Sends a sampling request to the client. + */ + requestSampling: (params: CreateMessageRequest['params'], options?: RequestOptions) => Promise; + }; + + /** + * HTTP request context with authentication, raw Request object, and SSE controls. + */ + http?: BaseContext['http'] & { + /** + * The raw Request object (fetch API Request). + * Provides access to url, headers, and other request properties. + */ + req: globalThis.Request; + /** + * Closes the SSE stream for this request, triggering client reconnection. + * Only available when using StreamableHTTPServerTransport with eventStore configured. + */ + closeSSE?: () => void; + /** + * Closes the standalone GET SSE stream, triggering client reconnection. + * Only available when using StreamableHTTPServerTransport with eventStore configured. + */ + closeStandaloneSSE?: () => void; + }; + + /** + * Notification context with send method and logging helpers. + */ + notification: BaseContext['notification'] & { + /** + * Sends a logging message to the client. + */ + log: (params: LoggingMessageNotification['params']) => Promise; + /** + * Sends a debug log message to the client. + */ + debug: (message: string, extraLogData?: Record) => Promise; + /** + * Sends an info log message to the client. + */ + info: (message: string, extraLogData?: Record) => Promise; + /** + * Sends a warning log message to the client. + */ + warning: (message: string, extraLogData?: Record) => Promise; + /** + * Sends an error log message to the client. + */ + error: (message: string, extraLogData?: Record) => Promise; + }; +}; diff --git a/packages/server/src/server/mcp.ts b/packages/server/src/server/mcp.ts index b2c17b71e..1fdf9315a 100644 --- a/packages/server/src/server/mcp.ts +++ b/packages/server/src/server/mcp.ts @@ -18,7 +18,6 @@ import type { PromptArgument, PromptReference, ReadResourceResult, - RequestHandlerExtra, Resource, ResourceTemplateReference, Result, @@ -54,6 +53,7 @@ import { ZodOptional } from 'zod'; import type { ToolTaskHandler } from '../experimental/tasks/interfaces.js'; import { ExperimentalMcpServerTasks } from '../experimental/tasks/mcpServer.js'; import { getCompleter, isCompletable } from './completable.js'; +import type { ServerContext } from './context.js'; import type { ServerOptions } from './server.js'; import { Server } from './server.js'; @@ -167,7 +167,7 @@ export class McpServer { }) ); - this.server.setRequestHandler('tools/call', async (request, extra): Promise => { + this.server.setRequestHandler('tools/call', async (request, ctx): Promise => { try { const tool = this._registeredTools[request.params.name]; if (!tool) { @@ -199,12 +199,12 @@ export class McpServer { // Handle taskSupport 'optional' without task augmentation - automatic polling if (taskSupport === 'optional' && !isTaskRequest && isTaskHandler) { - return await this.handleAutomaticTaskPolling(tool, request, extra); + return await this.handleAutomaticTaskPolling(tool, request, ctx); } // Normal execution path const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); - const result = await this.executeToolHandler(tool, args, extra); + const result = await this.executeToolHandler(tool, args, ctx); // Return CreateTaskResult immediately for task requests if (isTaskRequest) { @@ -315,36 +315,36 @@ export class McpServer { private async executeToolHandler( tool: RegisteredTool, args: unknown, - extra: RequestHandlerExtra + ctx: ServerContext ): Promise { const handler = tool.handler as AnyToolHandler; const isTaskHandler = 'createTask' in handler; if (isTaskHandler) { - if (!extra.taskStore) { + if (!ctx.task?.store) { throw new Error('No task store provided.'); } - const taskExtra = { ...extra, taskStore: extra.taskStore }; + const task = ctx; if (tool.inputSchema) { const typedHandler = handler as ToolTaskHandler; // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve(typedHandler.createTask(args as any, taskExtra)); + return await Promise.resolve(typedHandler.createTask(args as any, ctx)); } else { const typedHandler = handler as ToolTaskHandler; // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve((typedHandler.createTask as any)(taskExtra)); + return await Promise.resolve((typedHandler.createTask as any)(task)); } } if (tool.inputSchema) { const typedHandler = handler as ToolCallback; // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve(typedHandler(args as any, extra)); + return await Promise.resolve(typedHandler(args as any, ctx)); } else { const typedHandler = handler as ToolCallback; // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve((typedHandler as any)(extra)); + return await Promise.resolve((typedHandler as any)(ctx)); } } @@ -354,21 +354,20 @@ export class McpServer { private async handleAutomaticTaskPolling( tool: RegisteredTool, request: RequestT, - extra: RequestHandlerExtra + ctx: ServerContext ): Promise { - if (!extra.taskStore) { + if (!ctx.task?.store) { throw new Error('No task store provided for task-capable tool.'); } // Validate input and create task const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); const handler = tool.handler as ToolTaskHandler; - const taskExtra = { ...extra, taskStore: extra.taskStore }; const createTaskResult: CreateTaskResult = args // undefined only if tool.inputSchema is undefined - ? await Promise.resolve((handler as ToolTaskHandler).createTask(args, taskExtra)) + ? await Promise.resolve((handler as ToolTaskHandler).createTask(args, ctx)) : // eslint-disable-next-line @typescript-eslint/no-explicit-any - await Promise.resolve(((handler as ToolTaskHandler).createTask as any)(taskExtra)); + await Promise.resolve(((handler as ToolTaskHandler).createTask as any)(ctx)); // Poll until completion const taskId = createTaskResult.task.taskId; @@ -377,15 +376,12 @@ export class McpServer { while (task.status !== 'completed' && task.status !== 'failed' && task.status !== 'cancelled') { await new Promise(resolve => setTimeout(resolve, pollInterval)); - const updatedTask = await extra.taskStore.getTask(taskId); - if (!updatedTask) { - throw new McpError(ErrorCode.InternalError, `Task ${taskId} not found during polling`); - } + const updatedTask = await ctx.task!.store.getTask(taskId); task = updatedTask; } // Return the final result - return (await extra.taskStore.getTaskResult(taskId)) as CallToolResult; + return (await ctx.task!.store.getTaskResult(taskId)) as CallToolResult; } private _completionHandlerInitialized = false; @@ -491,7 +487,7 @@ export class McpServer { } }); - this.server.setRequestHandler('resources/list', async (_request, extra) => { + this.server.setRequestHandler('resources/list', async (_request, ctx) => { const resources = Object.entries(this._registeredResources) .filter(([_, resource]) => resource.enabled) .map(([uri, resource]) => ({ @@ -506,7 +502,7 @@ export class McpServer { continue; } - const result = await template.resourceTemplate.listCallback(extra); + const result = await template.resourceTemplate.listCallback(ctx); for (const resource of result.resources) { templateResources.push({ ...template.metadata, @@ -529,7 +525,7 @@ export class McpServer { return { resourceTemplates }; }); - this.server.setRequestHandler('resources/read', async (request, extra) => { + this.server.setRequestHandler('resources/read', async (request, ctx) => { const uri = new URL(request.params.uri); // First check for exact resource match @@ -538,14 +534,14 @@ export class McpServer { if (!resource.enabled) { throw new McpError(ErrorCode.InvalidParams, `Resource ${uri} disabled`); } - return resource.readCallback(uri, extra); + return resource.readCallback(uri, ctx); } // Then check templates for (const template of Object.values(this._registeredResourceTemplates)) { const variables = template.resourceTemplate.uriTemplate.match(uri.toString()); if (variables) { - return template.readCallback(uri, variables, extra); + return template.readCallback(uri, variables, ctx); } } @@ -587,7 +583,7 @@ export class McpServer { }) ); - this.server.setRequestHandler('prompts/get', async (request, extra): Promise => { + this.server.setRequestHandler('prompts/get', async (request, ctx): Promise => { const prompt = this._registeredPrompts[request.params.name]; if (!prompt) { throw new McpError(ErrorCode.InvalidParams, `Prompt ${request.params.name} not found`); @@ -608,11 +604,11 @@ export class McpServer { const args = parseResult.data; const cb = prompt.callback as PromptCallback; - return await Promise.resolve(cb(args, extra)); + return await Promise.resolve(cb(args, ctx)); } else { const cb = prompt.callback as PromptCallback; // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve((cb as any)(extra)); + return await Promise.resolve((cb as any)(ctx)); } }); @@ -881,7 +877,7 @@ export class McpServer { /** * Registers a prompt with a config object and callback. */ - registerPrompt( + registerPrompt( name: string, config: { title?: string; @@ -1016,13 +1012,13 @@ export class ResourceTemplate { export type BaseToolCallback< SendResultT extends Result, - Extra extends RequestHandlerExtra, + Extra extends ServerContext, Args extends undefined | ZodRawShapeCompat | AnySchema > = Args extends ZodRawShapeCompat - ? (args: ShapeOutput, extra: Extra) => SendResultT | Promise + ? (args: ShapeOutput, ctx: Extra) => SendResultT | Promise : Args extends AnySchema - ? (args: SchemaOutput, extra: Extra) => SendResultT | Promise - : (extra: Extra) => SendResultT | Promise; + ? (args: SchemaOutput, ctx: Extra) => SendResultT | Promise + : (ctx: Extra) => SendResultT | Promise; /** * Callback for a tool handler registered with Server.tool(). @@ -1036,7 +1032,7 @@ export type BaseToolCallback< */ export type ToolCallback = BaseToolCallback< CallToolResult, - RequestHandlerExtra, + ServerContext, Args >; @@ -1155,7 +1151,7 @@ export type ResourceMetadata = Omit; * Callback to list all resources matching a given template. */ export type ListResourcesCallback = ( - extra: RequestHandlerExtra + ctx: ServerContext ) => ListResourcesResult | Promise; /** @@ -1163,7 +1159,7 @@ export type ListResourcesCallback = ( */ export type ReadResourceCallback = ( uri: URL, - extra: RequestHandlerExtra + ctx: ServerContext ) => ReadResourceResult | Promise; export type RegisteredResource = { @@ -1191,7 +1187,7 @@ export type RegisteredResource = { export type ReadResourceTemplateCallback = ( uri: URL, variables: Variables, - extra: RequestHandlerExtra + ctx: ServerContext ) => ReadResourceResult | Promise; export type RegisteredResourceTemplate = { @@ -1216,8 +1212,8 @@ export type RegisteredResourceTemplate = { type PromptArgsRawShape = ZodRawShapeCompat; export type PromptCallback = Args extends PromptArgsRawShape - ? (args: ShapeOutput, extra: RequestHandlerExtra) => GetPromptResult | Promise - : (extra: RequestHandlerExtra) => GetPromptResult | Promise; + ? (args: ShapeOutput, ctx: ServerContext) => GetPromptResult | Promise + : (ctx: ServerContext) => GetPromptResult | Promise; export type RegisteredPrompt = { title?: string; diff --git a/packages/server/src/server/server.ts b/packages/server/src/server/server.ts index ca43272cf..14a666dbc 100644 --- a/packages/server/src/server/server.ts +++ b/packages/server/src/server/server.ts @@ -1,37 +1,46 @@ import type { + AnySchema, + BaseContext, ClientCapabilities, CreateMessageRequest, CreateMessageRequestParamsBase, CreateMessageRequestParamsWithTools, CreateMessageResult, CreateMessageResultWithTools, + ElicitRequest, ElicitRequestFormParams, ElicitRequestURLParams, ElicitResult, Implementation, InitializeRequest, InitializeResult, + JSONRPCRequest, JsonSchemaType, jsonSchemaValidator, ListRootsRequest, LoggingLevel, LoggingMessageNotification, + MessageExtraInfo, Notification, NotificationOptions, ProtocolOptions, Request, - RequestHandlerExtra, RequestMethod, RequestOptions, RequestTypeMap, ResourceUpdatedNotification, Result, + SchemaOutput, ServerCapabilities, ServerNotification, ServerRequest, ServerResult, + TaskContext, + TaskCreationParams, + TaskStore, ToolResultContent, - ToolUseContent + ToolUseContent, + Transport } from '@modelcontextprotocol/core'; import { AjvJsonSchemaValidator, @@ -56,6 +65,7 @@ import { } from '@modelcontextprotocol/core'; import { ExperimentalServerTasks } from '../experimental/tasks/server.js'; +import type { ServerContext } from './context.js'; export type ServerOptions = ProtocolOptions & { /** @@ -159,9 +169,9 @@ export class Server< this.setNotificationHandler('notifications/initialized', () => this.oninitialized?.()); if (this._capabilities.logging) { - this.setRequestHandler('logging/setLevel', async (request, extra) => { + this.setRequestHandler('logging/setLevel', async (request, ctx) => { const transportSessionId: string | undefined = - extra.sessionId || (extra.requestInfo?.headers.get('mcp-session-id') as string) || undefined; + ctx.sessionId || (ctx.http?.req.headers.get('mcp-session-id') as string) || undefined; const { level } = request.params; const parseResult = LoggingLevelSchema.safeParse(level); if (parseResult.success) { @@ -219,13 +229,13 @@ export class Server< method: M, handler: ( request: RequestTypeMap[M], - extra: RequestHandlerExtra + ctx: ServerContext ) => ServerResult | ResultT | Promise ): void { if (method === 'tools/call') { const wrappedHandler = async ( request: RequestTypeMap[M], - extra: RequestHandlerExtra + ctx: BaseContext ): Promise => { const validatedRequest = safeParse(CallToolRequestSchema, request); if (!validatedRequest.success) { @@ -236,7 +246,7 @@ export class Server< const { params } = validatedRequest.data; - const result = await Promise.resolve(handler(request, extra)); + const result = await Promise.resolve(handler(request, ctx as ServerContext)); // When task creation is requested, validate and return CreateTaskResult if (params.task) { @@ -263,11 +273,17 @@ export class Server< }; // Install the wrapped handler - return super.setRequestHandler(method, wrappedHandler as unknown as typeof handler); + return super.setRequestHandler(method, wrappedHandler); } - // Other handlers use default behavior - return super.setRequestHandler(method, handler); + // Other handlers use default behavior - cast is safe because Server.createRequestContext always builds a ServerContext + return super.setRequestHandler( + method, + handler as ( + request: RequestTypeMap[M], + ctx: BaseContext + ) => ServerResult | ResultT | Promise + ); } protected assertCapabilityForMethod(method: RequestT['method']): void { @@ -463,6 +479,104 @@ export class Server< return this._capabilities; } + protected createRequestContext(args: { + request: JSONRPCRequest; + taskStore: TaskStore | undefined; + relatedTaskId: string | undefined; + taskCreationParams: TaskCreationParams | undefined; + abortController: AbortController; + capturedTransport: Transport | undefined; + extra?: MessageExtraInfo; + }): BaseContext { + const { request, taskStore, relatedTaskId, taskCreationParams, abortController, capturedTransport, extra } = args; + const sessionId = capturedTransport?.sessionId; + + // Build the task context using the helper from Protocol + const task: TaskContext | undefined = this.buildTaskContext({ + taskStore, + request, + sessionId, + relatedTaskId, + taskCreationParams + }); + + // Closure helpers for sendRequest and sendNotification + const sendRequest = async ( + req: ServerRequest | RequestT, + resultSchema: U, + options?: RequestOptions + ): Promise> => { + const requestOptions: RequestOptions = { ...options, relatedRequestId: request.id }; + const taskId = task?.id; + if (taskId) { + requestOptions.relatedTask = { taskId }; + if (task?.store) { + await task.store.updateTaskStatus(taskId, 'input_required'); + } + } + return await this.request(req, resultSchema, requestOptions); + }; + + const sendNotification = async (notification: ServerNotification | NotificationT): Promise => { + const notificationOptions: NotificationOptions = { relatedRequestId: request.id }; + if (task && task.id) { + notificationOptions.relatedTask = { taskId: task.id }; + } + return this.notification(notification, notificationOptions); + }; + + // Build the ServerContext POJO + const ctx: ServerContext = { + sessionId, + mcpReq: { + id: request.id, + method: request.method, + _meta: request.params?._meta, + signal: abortController.signal, + send: sendRequest, + elicitInput: async (params: ElicitRequest['params'], options?: RequestOptions): Promise => { + const elicitRequest: ElicitRequest = { + method: 'elicitation/create', + params + }; + return await this.request(elicitRequest, ElicitResultSchema, { ...options, relatedRequestId: request.id }); + }, + requestSampling: (params: CreateMessageRequest['params'], options?: RequestOptions) => { + return this.createMessage(params, options); + } + }, + http: extra?.request + ? { + req: extra.request, + authInfo: extra?.authInfo, + closeSSE: extra?.closeSSEStream, + closeStandaloneSSE: extra?.closeStandaloneSSEStream + } + : undefined, + task, + notification: { + send: sendNotification, + log: async (params: LoggingMessageNotification['params']) => { + await this.sendLoggingMessage(params); + }, + debug: async (message: string, extraLogData?: Record) => { + await this.sendLoggingMessage({ level: 'debug', data: { ...extraLogData, message }, logger: 'server' }); + }, + info: async (message: string, extraLogData?: Record) => { + await this.sendLoggingMessage({ level: 'info', data: { ...extraLogData, message }, logger: 'server' }); + }, + warning: async (message: string, extraLogData?: Record) => { + await this.sendLoggingMessage({ level: 'warning', data: { ...extraLogData, message }, logger: 'server' }); + }, + error: async (message: string, extraLogData?: Record) => { + await this.sendLoggingMessage({ level: 'error', data: { ...extraLogData, message }, logger: 'server' }); + } + } + }; + + return ctx; + } + async ping() { return this.request({ method: 'ping' }, EmptyResultSchema); } diff --git a/packages/server/src/server/streamableHttp.ts b/packages/server/src/server/streamableHttp.ts index ae8bad97e..fe62fc9e5 100644 --- a/packages/server/src/server/streamableHttp.ts +++ b/packages/server/src/server/streamableHttp.ts @@ -9,7 +9,7 @@ import { TextEncoder } from 'node:util'; -import type { AuthInfo, JSONRPCMessage, MessageExtraInfo, RequestId, RequestInfo, Transport } from '@modelcontextprotocol/core'; +import type { AuthInfo, JSONRPCMessage, MessageExtraInfo, RequestId, Transport } from '@modelcontextprotocol/core'; import { DEFAULT_NEGOTIATED_PROTOCOL_VERSION, isInitializeRequest, @@ -598,11 +598,6 @@ export class WebStandardStreamableHTTPServerTransport implements Transport { return this.createJsonErrorResponse(415, -32_000, 'Unsupported Media Type: Content-Type must be application/json'); } - // Build request info from headers - const requestInfo: RequestInfo = { - headers: req.headers - }; - let rawMessage; if (options?.parsedBody === undefined) { try { @@ -667,7 +662,7 @@ export class WebStandardStreamableHTTPServerTransport implements Transport { if (!hasRequests) { // if it only contains notifications or responses, return 202 for (const message of messages) { - this.onmessage?.(message, { authInfo: options?.authInfo, requestInfo }); + this.onmessage?.(message, { authInfo: options?.authInfo, request: req }); } return new Response(null, { status: 202 }); } @@ -701,7 +696,7 @@ export class WebStandardStreamableHTTPServerTransport implements Transport { } for (const message of messages) { - this.onmessage?.(message, { authInfo: options?.authInfo, requestInfo }); + this.onmessage?.(message, { authInfo: options?.authInfo, request: req }); } }); } @@ -771,7 +766,7 @@ export class WebStandardStreamableHTTPServerTransport implements Transport { }; } - this.onmessage?.(message, { authInfo: options?.authInfo, requestInfo, closeSSEStream, closeStandaloneSSEStream }); + this.onmessage?.(message, { authInfo: options?.authInfo, request: req, closeSSEStream, closeStandaloneSSEStream }); } // The server SHOULD NOT close the SSE stream before sending all JSON-RPC responses // This will be handled by the send() method when responses are ready diff --git a/src/conformance/everything-client.ts b/src/conformance/everything-client.ts index d6436d1dd..19ba5d648 100644 --- a/src/conformance/everything-client.ts +++ b/src/conformance/everything-client.ts @@ -12,12 +12,7 @@ * consolidating all the individual test clients into one. */ -import { - Client, - StreamableHTTPClientTransport, - ClientCredentialsProvider, - PrivateKeyJwtProvider -} from '@modelcontextprotocol/client'; +import { Client, StreamableHTTPClientTransport, ClientCredentialsProvider, PrivateKeyJwtProvider } from '@modelcontextprotocol/client'; import { z } from 'zod'; import { withOAuthRetry, handle401 } from './helpers/withOAuthRetry.js'; import { logger } from './helpers/logger.js'; diff --git a/src/conformance/everything-server.ts b/src/conformance/everything-server.ts index 4be643aef..2ce39769b 100644 --- a/src/conformance/everything-server.ts +++ b/src/conformance/everything-server.ts @@ -10,12 +10,7 @@ import { randomUUID } from 'node:crypto'; import type { CallToolResult, GetPromptResult, ReadResourceResult, EventId, EventStore, StreamId } from '@modelcontextprotocol/server'; -import { - ElicitResultSchema, - isInitializeRequest, - McpServer, - ResourceTemplate -} from '@modelcontextprotocol/server'; +import { ElicitResultSchema, isInitializeRequest, McpServer, ResourceTemplate } from '@modelcontextprotocol/server'; import { NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/node'; import type { Request, Response } from 'express'; import cors from 'cors'; @@ -207,8 +202,8 @@ function createMcpServer(sessionId?: string) { description: 'Tests tool that emits log messages during execution', inputSchema: {} }, - async (_args, extra): Promise => { - await extra.sendNotification({ + async (_args, ctx): Promise => { + await ctx.notification.send({ method: 'notifications/message', params: { level: 'info', @@ -217,7 +212,7 @@ function createMcpServer(sessionId?: string) { }); await new Promise(resolve => setTimeout(resolve, 50)); - await extra.sendNotification({ + await ctx.notification.send({ method: 'notifications/message', params: { level: 'info', @@ -226,7 +221,7 @@ function createMcpServer(sessionId?: string) { }); await new Promise(resolve => setTimeout(resolve, 50)); - await extra.sendNotification({ + await ctx.notification.send({ method: 'notifications/message', params: { level: 'info', @@ -246,10 +241,10 @@ function createMcpServer(sessionId?: string) { description: 'Tests tool that reports progress notifications', inputSchema: {} }, - async (_args, extra): Promise => { - const progressToken = extra._meta?.progressToken ?? 0; + async (_args, ctx): Promise => { + const progressToken = ctx.mcpReq._meta?.progressToken ?? 0; console.log('Progress token:', progressToken); - await extra.sendNotification({ + await ctx.notification.send({ method: 'notifications/progress', params: { progressToken, @@ -260,7 +255,7 @@ function createMcpServer(sessionId?: string) { }); await new Promise(resolve => setTimeout(resolve, 50)); - await extra.sendNotification({ + await ctx.notification.send({ method: 'notifications/progress', params: { progressToken, @@ -271,7 +266,7 @@ function createMcpServer(sessionId?: string) { }); await new Promise(resolve => setTimeout(resolve, 50)); - await extra.sendNotification({ + await ctx.notification.send({ method: 'notifications/progress', params: { progressToken, @@ -306,23 +301,23 @@ function createMcpServer(sessionId?: string) { 'Tests SSE stream disconnection and client reconnection (SEP-1699). Server will close the stream mid-call and send the result after client reconnects.', inputSchema: {} }, - async (_args, extra): Promise => { + async (_args, ctx): Promise => { const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); - console.log(`[${extra.sessionId}] Starting test_reconnection tool...`); + console.log(`[${ctx.sessionId}] Starting test_reconnection tool...`); // Get the transport for this session - const transport = extra.sessionId ? transports[extra.sessionId] : undefined; - if (transport && extra.requestId) { + const transport = ctx.sessionId ? transports[ctx.sessionId] : undefined; + if (transport && ctx.mcpReq.id) { // Close the SSE stream to trigger client reconnection - console.log(`[${extra.sessionId}] Closing SSE stream to trigger client polling...`); - transport.closeSSEStream(extra.requestId); + console.log(`[${ctx.sessionId}] Closing SSE stream to trigger client polling...`); + transport.closeSSEStream(ctx.mcpReq.id); } // Wait for client to reconnect (should respect retry field) await sleep(100); - console.log(`[${extra.sessionId}] test_reconnection tool complete`); + console.log(`[${ctx.sessionId}] test_reconnection tool complete`); return { content: [ @@ -344,10 +339,10 @@ function createMcpServer(sessionId?: string) { prompt: z.string().describe('The prompt to send to the LLM') } }, - async (args: { prompt: string }, extra): Promise => { + async (args: { prompt: string }, ctx): Promise => { try { // Request sampling from client - const result = (await extra.sendRequest( + const result = (await ctx.mcpReq.send( { method: 'sampling/createMessage', params: { @@ -398,10 +393,10 @@ function createMcpServer(sessionId?: string) { message: z.string().describe('The message to show the user') } }, - async (args: { message: string }, extra): Promise => { + async (args: { message: string }, ctx): Promise => { try { // Request user input from client - const result = await extra.sendRequest( + const result = await ctx.mcpReq.send( { method: 'elicitation/create', params: { @@ -450,10 +445,10 @@ function createMcpServer(sessionId?: string) { description: 'Tests elicitation with default values per SEP-1034', inputSchema: {} }, - async (_args, extra): Promise => { + async (_args, ctx): Promise => { try { // Request user input with default values for all primitive types - const result = await extra.sendRequest( + const result = await ctx.mcpReq.send( { method: 'elicitation/create', params: { @@ -524,10 +519,10 @@ function createMcpServer(sessionId?: string) { description: 'Tests elicitation with enum schema improvements per SEP-1330', inputSchema: {} }, - async (_args, extra): Promise => { + async (_args, ctx): Promise => { try { // Request user input with all 5 enum schema variants - const result = await extra.sendRequest( + const result = await ctx.mcpReq.send( { method: 'elicitation/create', params: { diff --git a/test/integration/test/client/client.test.ts b/test/integration/test/client/client.test.ts index a4d6943c1..d021a3c97 100644 --- a/test/integration/test/client/client.test.ts +++ b/test/integration/test/client/client.test.ts @@ -1095,10 +1095,10 @@ test('should handle request timeout', async () => { ); // Set up server with a delayed response - server.setRequestHandler('resources/list', async (_request, extra) => { + server.setRequestHandler('resources/list', async (_request, ctx) => { const timer = new Promise(resolve => { const timeout = setTimeout(resolve, 100); - extra.signal.addEventListener('abort', () => clearTimeout(timeout)); + ctx.mcpReq.signal.addEventListener('abort', () => clearTimeout(timeout)); }); await timer; @@ -2268,27 +2268,27 @@ describe('Task-based execution', () => { inputSchema: {} }, { - async createTask(_args, extra) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + async createTask(_args, ctx) { + const task = await ctx.task!.store.createTask({ + ttl: ctx.task!.requestedTtl }); const result = { content: [{ type: 'text', text: 'Tool executed successfully!' }] }; - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', result); return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.task!.store.getTask(ctx.task!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.task!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.task!.store.getTaskResult(ctx.task!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -2344,27 +2344,27 @@ describe('Task-based execution', () => { inputSchema: {} }, { - async createTask(_args, extra) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + async createTask(_args, ctx) { + const task = await ctx.task!.store.createTask({ + ttl: ctx.task!.requestedTtl }); const result = { content: [{ type: 'text', text: 'Success!' }] }; - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', result); return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.task!.store.getTask(ctx.task!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.task!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.task!.store.getTaskResult(ctx.task!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -2421,27 +2421,27 @@ describe('Task-based execution', () => { inputSchema: {} }, { - async createTask(_args, extra) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + async createTask(_args, ctx) { + const task = await ctx.task!.store.createTask({ + ttl: ctx.task!.requestedTtl }); const result = { content: [{ type: 'text', text: 'Result data!' }] }; - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', result); return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.task!.store.getTask(ctx.task!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.task!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.task!.store.getTaskResult(ctx.task!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -2502,27 +2502,27 @@ describe('Task-based execution', () => { inputSchema: {} }, { - async createTask(_args, extra) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + async createTask(_args, ctx) { + const task = await ctx.task!.store.createTask({ + ttl: ctx.task!.requestedTtl }); const result = { content: [{ type: 'text', text: 'Success!' }] }; - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', result); return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.task!.store.getTask(ctx.task!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.task!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.task!.store.getTaskResult(ctx.task!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -2599,18 +2599,18 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler('elicitation/create', async (request, extra) => { + client.setRequestHandler('elicitation/create', async (request, ctx) => { const result = { action: 'accept', content: { username: 'list-user' } }; // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + if (request.params.task && ctx.task!.store) { + const task = await ctx.task!.store.createTask({ + ttl: ctx.task!.requestedTtl }); - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', result); // Return CreateTaskResult when task creation is requested return { task }; } @@ -2692,18 +2692,18 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler('elicitation/create', async (request, extra) => { + client.setRequestHandler('elicitation/create', async (request, ctx) => { const result = { action: 'accept', content: { username: 'list-user' } }; // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + if (request.params.task && ctx.task!.store) { + const task = await ctx.task!.store.createTask({ + ttl: ctx.task!.requestedTtl }); - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', result); // Return CreateTaskResult when task creation is requested return { task }; } @@ -2784,18 +2784,18 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler('elicitation/create', async (request, extra) => { + client.setRequestHandler('elicitation/create', async (request, ctx) => { const result = { action: 'accept', content: { username: 'result-user' } }; // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + if (request.params.task && ctx.task!.store) { + const task = await ctx.task!.store.createTask({ + ttl: ctx.task!.requestedTtl }); - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', result); // Return CreateTaskResult when task creation is requested return { task }; } @@ -2875,18 +2875,18 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler('elicitation/create', async (request, extra) => { + client.setRequestHandler('elicitation/create', async (request, ctx) => { const result = { action: 'accept', content: { username: 'list-user' } }; // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + if (request.params.task && ctx.task!.store) { + const task = await ctx.task!.store.createTask({ + ttl: ctx.task!.requestedTtl }); - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', result); // Return CreateTaskResult when task creation is requested return { task }; } @@ -2987,27 +2987,27 @@ describe('Task-based execution', () => { } }, { - async createTask({ id }, extra) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + async createTask({ id }, ctx) { + const task = await ctx.task!.store.createTask({ + ttl: ctx.task!.requestedTtl }); const result = { content: [{ type: 'text', text: `Result for ${id || 'unknown'}` }] }; - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', result); return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.task!.store.getTask(ctx.task!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.task!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.task!.store.getTaskResult(ctx.task!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -3255,27 +3255,27 @@ test('should respect server task capabilities', async () => { inputSchema: {} }, { - async createTask(_args, extra) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + async createTask(_args, ctx) { + const task = await ctx.task!.store.createTask({ + ttl: ctx.task!.requestedTtl }); const result = { content: [{ type: 'text', text: 'Success!' }] }; - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', result); return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.task!.store.getTask(ctx.task!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.task!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.task!.store.getTaskResult(ctx.task!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } diff --git a/test/integration/test/server.test.ts b/test/integration/test/server.test.ts index 52da4fc34..d32a70386 100644 --- a/test/integration/test/server.test.ts +++ b/test/integration/test/server.test.ts @@ -1384,12 +1384,12 @@ test('should handle request timeout', async () => { } ); - client.setRequestHandler('sampling/createMessage', async (_request, extra) => { + client.setRequestHandler('sampling/createMessage', async (_request, ctx) => { await new Promise((resolve, reject) => { const timeout = setTimeout(resolve, 100); - extra.signal.addEventListener('abort', () => { + ctx.mcpReq.signal.addEventListener('abort', () => { clearTimeout(timeout); - reject(extra.signal.reason); + reject(ctx.mcpReq.signal.reason); }); }); @@ -2144,9 +2144,9 @@ describe('Task-based execution', () => { inputSchema: {} }, { - async createTask(_args, extra) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + async createTask(_args, ctx) { + const task = await ctx.task!.store.createTask({ + ttl: ctx.task!.requestedTtl }); // Simulate some async work @@ -2155,20 +2155,20 @@ describe('Task-based execution', () => { const result = { content: [{ type: 'text', text: 'Tool executed successfully!' }] }; - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', result); })(); return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.task!.store.getTask(ctx.task!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.task!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.task!.store.getTaskResult(ctx.task!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -2336,13 +2336,13 @@ describe('Task-based execution', () => { let capturedElicitRequest: z4.infer | null = null; // Set up client elicitation handler - client.setRequestHandler('elicitation/create', async (request, extra) => { + client.setRequestHandler('elicitation/create', async (request, ctx) => { let taskId: string | undefined; // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + if (request.params.task && ctx.task!.store) { + const createdTask = await ctx.task!.store.createTask({ + ttl: ctx.task!.requestedTtl }); taskId = createdTask.taskId; } @@ -2366,15 +2366,15 @@ describe('Task-based execution', () => { inputSchema: {} }, { - async createTask(_args, extra) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + async createTask(_args, ctx) { + const task = await ctx.task!.store.createTask({ + ttl: ctx.task!.requestedTtl }); // Perform async work that makes a nested request (async () => { - // During tool execution, make a nested request to the client using extra.sendRequest - const elicitResult = await extra.sendRequest( + // During tool execution, make a nested request to the client using ctx.sendRequest + const elicitResult = await ctx.mcpReq.send( { method: 'elicitation/create', params: { @@ -2400,20 +2400,20 @@ describe('Task-based execution', () => { } ] }; - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', result); })(); return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.task!.store.getTask(ctx.task!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.task!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.task!.store.getTaskResult(ctx.task!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -2490,18 +2490,18 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler('elicitation/create', async (request, extra) => { + client.setRequestHandler('elicitation/create', async (request, ctx) => { const result = { action: 'accept', content: { username: 'server-test-user', confirmed: true } }; // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + if (request.params.task && ctx.task!.store) { + const task = await ctx.task!.store.createTask({ + ttl: ctx.task!.requestedTtl }); - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', result); // Return CreateTaskResult when task creation is requested return { task }; } @@ -2571,18 +2571,18 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler('elicitation/create', async (request, extra) => { + client.setRequestHandler('elicitation/create', async (request, ctx) => { const result = { action: 'accept', content: { username: 'list-user' } }; // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + if (request.params.task && ctx.task!.store) { + const task = await ctx.task!.store.createTask({ + ttl: ctx.task!.requestedTtl }); - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', result); // Return CreateTaskResult when task creation is requested return { task }; } @@ -2650,18 +2650,18 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler('elicitation/create', async (request, extra) => { + client.setRequestHandler('elicitation/create', async (request, ctx) => { const result = { action: 'accept', content: { username: 'result-user', confirmed: true } }; // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + if (request.params.task && ctx.task!.store) { + const task = await ctx.task!.store.createTask({ + ttl: ctx.task!.requestedTtl }); - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', result); // Return CreateTaskResult when task creation is requested return { task }; } @@ -2731,18 +2731,18 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler('elicitation/create', async (request, extra) => { + client.setRequestHandler('elicitation/create', async (request, ctx) => { const result = { action: 'accept', content: { username: 'list-user' } }; // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + if (request.params.task && ctx.task!.store) { + const task = await ctx.task!.store.createTask({ + ttl: ctx.task!.requestedTtl }); - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', result); // Return CreateTaskResult when task creation is requested return { task }; } @@ -2845,9 +2845,9 @@ describe('Task-based execution', () => { } }, { - async createTask({ delay, taskNum }, extra) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + async createTask({ delay, taskNum }, ctx) { + const task = await ctx.task!.store.createTask({ + ttl: ctx.task!.requestedTtl }); // Simulate async work @@ -2856,20 +2856,20 @@ describe('Task-based execution', () => { const result = { content: [{ type: 'text', text: `Completed task ${taskNum || 'unknown'}` }] }; - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', result); })(); return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.task!.store.getTask(ctx.task!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.task!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.task!.store.getTaskResult(ctx.task!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -3074,18 +3074,18 @@ test('should respect client task capabilities', async () => { } ); - client.setRequestHandler('elicitation/create', async (request, extra) => { + client.setRequestHandler('elicitation/create', async (request, ctx) => { const result = { action: 'accept', content: { username: 'test-user' } }; // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + if (request.params.task && ctx.task!.store) { + const task = await ctx.task!.store.createTask({ + ttl: ctx.task!.requestedTtl }); - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', result); // Return CreateTaskResult when task creation is requested return { task }; } diff --git a/test/integration/test/server/context.test.ts b/test/integration/test/server/context.test.ts new file mode 100644 index 000000000..4692d2cfd --- /dev/null +++ b/test/integration/test/server/context.test.ts @@ -0,0 +1,265 @@ +import { Client } from '@modelcontextprotocol/client'; +import type { BaseContext, ServerNotification, ServerRequest } from '@modelcontextprotocol/core'; +import { + CallToolResultSchema, + GetPromptResultSchema, + InMemoryTransport, + ListResourcesResultSchema, + ReadResourceResultSchema +} from '@modelcontextprotocol/core'; +import type { ServerContext } from '@modelcontextprotocol/server'; +import { McpServer, ResourceTemplate } from '@modelcontextprotocol/server'; +import { z } from 'zod/v4'; + +describe('ServerContext', () => { + /*** + * Test: `ctx` provided to callbacks is ServerContext (parameterized) + */ + type Seen = { isContext: boolean; hasRequestId: boolean }; + const contextCases: Array<[string, (mcpServer: McpServer, seen: Seen) => void | Promise, (client: Client) => Promise]> = + [ + [ + 'tool', + (mcpServer, seen) => { + mcpServer.registerTool( + 'ctx-tool', + { + inputSchema: z.object({ name: z.string() }) + }, + (_args: { name: string }, ctx) => { + seen.isContext = 'elicitInput' in ctx.mcpReq; + seen.hasRequestId = !!ctx.mcpReq.id; + return { content: [{ type: 'text', text: 'ok' }] }; + } + ); + }, + client => + client.request( + { + method: 'tools/call', + params: { + name: 'ctx-tool', + arguments: { + name: 'ctx-tool-name' + } + } + }, + CallToolResultSchema + ) + ], + [ + 'resource', + (mcpServer, seen) => { + mcpServer.registerResource('ctx-resource', 'test://res/1', { title: 'ctx-resource' }, async (_uri, ctx) => { + seen.isContext = 'elicitInput' in ctx.mcpReq; + seen.hasRequestId = !!ctx.mcpReq.id; + return { contents: [{ uri: 'test://res/1', mimeType: 'text/plain', text: 'hello' }] }; + }); + }, + client => client.request({ method: 'resources/read', params: { uri: 'test://res/1' } }, ReadResourceResultSchema) + ], + [ + 'resource template list', + (mcpServer, seen) => { + const template = new ResourceTemplate('test://items/{id}', { + list: async ctx => { + seen.isContext = 'elicitInput' in ctx.mcpReq; + seen.hasRequestId = !!ctx.mcpReq.id; + return { resources: [] }; + } + }); + mcpServer.registerResource('ctx-template', template, { title: 'ctx-template' }, async (_uri, _vars, _ctx) => ({ + contents: [] + })); + }, + client => client.request({ method: 'resources/list', params: {} }, ListResourcesResultSchema) + ], + [ + 'prompt', + (mcpServer, seen) => { + mcpServer.registerPrompt('ctx-prompt', {}, async ctx => { + seen.isContext = 'elicitInput' in ctx.mcpReq; + seen.hasRequestId = !!ctx.mcpReq.id; + return { messages: [] }; + }); + }, + client => client.request({ method: 'prompts/get', params: { name: 'ctx-prompt', arguments: {} } }, GetPromptResultSchema) + ] + ]; + + test.each(contextCases)('should pass ServerContext as ctx to %s callbacks', async (_kind, register, trigger) => { + const mcpServer = new McpServer({ name: 'ctx-test', version: '1.0' }); + const client = new Client({ name: 'ctx-client', version: '1.0' }); + + const seen: Seen = { isContext: false, hasRequestId: false }; + + await register(mcpServer, seen); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await trigger(client); + + expect(seen.isContext).toBe(true); + expect(seen.hasRequestId).toBe(true); + }); + + const logLevelsThroughContext = ['debug', 'info', 'warning', 'error'] as const; + + //it.each for each log level, test that logging message is sent to client + it.each(logLevelsThroughContext)('should send logging message to client for %s level from ServerContext', async level => { + const mcpServer = new McpServer( + { name: 'ctx-test', version: '1.0' }, + { + capabilities: { + logging: {} + } + } + ); + const client = new Client( + { name: 'ctx-client', version: '1.0' }, + { + capabilities: {} + } + ); + + let seen = 0; + + client.setNotificationHandler('notifications/message', notification => { + seen++; + expect(notification.params.level).toBe(level); + return; + }); + + mcpServer.registerTool('ctx-log-test', { inputSchema: z.object({ name: z.string() }) }, async (_args: { name: string }, ctx) => { + const serverCtx = ctx as ServerContext; + // Use the new notification API (no sessionId parameter) + await serverCtx.notification[level]('Test message', { test: 'test' }); + await serverCtx.notification.log({ + level, + data: 'Test message', + logger: 'test-logger-namespace' + }); + return { content: [{ type: 'text', text: 'ok' }] }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'tools/call', + params: { name: 'ctx-log-test', arguments: { name: 'ctx-log-test-name' } } + }, + CallToolResultSchema + ); + + // two messages should have been sent - one from the .log method and one from the .debug/info/warning/error method + expect(seen).toBe(2); + + expect(result.content).toHaveLength(1); + expect(result.content[0]).toMatchObject({ + type: 'text', + text: 'ok' + }); + }); + describe('BaseContext compatibility', () => { + const contextCases: Array< + [string, (mcpServer: McpServer, seen: Seen) => void | Promise, (client: Client) => Promise] + > = [ + [ + 'tool', + (mcpServer, seen) => { + mcpServer.registerTool( + 'ctx-tool', + { + inputSchema: z.object({ name: z.string() }) + }, + // The test is to ensure that the ctx is compatible with the BaseContext type + (_args: { name: string }, ctx: BaseContext) => { + seen.isContext = 'elicitInput' in ctx.mcpReq; + seen.hasRequestId = !!ctx.mcpReq.id; + return { content: [{ type: 'text', text: 'ok' }] }; + } + ); + }, + client => + client.request( + { + method: 'tools/call', + params: { + name: 'ctx-tool', + arguments: { + name: 'ctx-tool-name' + } + } + }, + CallToolResultSchema + ) + ], + [ + 'resource', + (mcpServer, seen) => { + // The test is to ensure that the ctx is compatible with the BaseContext type + mcpServer.registerResource( + 'ctx-resource', + 'test://res/1', + { title: 'ctx-resource' }, + async (_uri, ctx: BaseContext) => { + seen.isContext = 'elicitInput' in ctx.mcpReq; + seen.hasRequestId = !!ctx.mcpReq.id; + return { contents: [{ uri: 'test://res/1', mimeType: 'text/plain', text: 'hello' }] }; + } + ); + }, + client => client.request({ method: 'resources/read', params: { uri: 'test://res/1' } }, ReadResourceResultSchema) + ], + [ + 'resource template list', + (mcpServer, seen) => { + // The test is to ensure that the ctx is compatible with the BaseContext type + const template = new ResourceTemplate('test://items/{id}', { + list: async (ctx: BaseContext) => { + seen.isContext = 'elicitInput' in ctx.mcpReq; + seen.hasRequestId = !!ctx.mcpReq.id; + return { resources: [] }; + } + }); + mcpServer.registerResource('ctx-template', template, { title: 'ctx-template' }, async (_uri, _vars, _ctx) => ({ + contents: [] + })); + }, + client => client.request({ method: 'resources/list', params: {} }, ListResourcesResultSchema) + ], + [ + 'prompt', + (mcpServer, seen) => { + // The test is to ensure that the ctx is compatible with the BaseContext type + mcpServer.registerPrompt('ctx-prompt', {}, async (ctx: BaseContext) => { + seen.isContext = 'elicitInput' in ctx.mcpReq; + seen.hasRequestId = !!ctx.mcpReq.id; + return { messages: [] }; + }); + }, + client => client.request({ method: 'prompts/get', params: { name: 'ctx-prompt', arguments: {} } }, GetPromptResultSchema) + ] + ]; + + test.each(contextCases)('should pass ServerContext as ctx to %s callbacks', async (_kind, register, trigger) => { + const mcpServer = new McpServer({ name: 'ctx-test', version: '1.0' }); + const client = new Client({ name: 'ctx-client', version: '1.0' }); + + const seen: Seen = { isContext: false, hasRequestId: false }; + + await register(mcpServer, seen); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await trigger(client); + + expect(seen.isContext).toBe(true); + expect(seen.hasRequestId).toBe(true); + }); + }); +}); diff --git a/test/integration/test/server/mcp.test.ts b/test/integration/test/server/mcp.test.ts index eae10f087..73f1b9685 100644 --- a/test/integration/test/server/mcp.test.ts +++ b/test/integration/test/server/mcp.test.ts @@ -1,5 +1,5 @@ import { Client } from '@modelcontextprotocol/client'; -import type { CallToolResult, Notification, TextContent } from '@modelcontextprotocol/core'; +import type { BaseContext, CallToolResult, Notification, ServerNotification, ServerRequest, TextContent } from '@modelcontextprotocol/core'; import { CallToolResultSchema, CompleteResultSchema, @@ -16,6 +16,7 @@ import { UriTemplate, UrlElicitationRequiredError } from '@modelcontextprotocol/core'; +import type { ServerContext } from '@modelcontextprotocol/server'; import { completable, McpServer, ResourceTemplate } from '@modelcontextprotocol/server'; import type { ZodMatrixEntry } from '@modelcontextprotocol/test-helpers'; import { zodTestMatrix } from '@modelcontextprotocol/test-helpers'; @@ -115,13 +116,13 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { steps: z.number().min(1).describe('Number of steps to perform') } }, - async ({ steps }, { sendNotification, _meta }) => { - const progressToken = _meta?.progressToken; + async ({ steps }, ctx) => { + const progressToken = ctx.mcpReq._meta?.progressToken; if (progressToken) { // Send progress notification for each step for (let i = 1; i <= steps; i++) { - await sendNotification({ + await ctx.notification.send({ method: 'notifications/progress', params: { progressToken, @@ -242,7 +243,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { sendNotification: () => { throw new Error('Not implemented'); } - }); + } as unknown as ServerContext); expect(result?.resources).toHaveLength(1); expect(list).toHaveBeenCalled(); }); @@ -1323,7 +1324,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { /*** * Test: Pass Session ID to Tool Callback */ - test('should pass sessionId to tool callback via RequestHandlerExtra', async () => { + test('should pass sessionId to tool callback via context', async () => { const mcpServer = new McpServer({ name: 'test server', version: '1.0' @@ -1335,8 +1336,8 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); let receivedSessionId: string | undefined; - mcpServer.registerTool('test-tool', {}, async extra => { - receivedSessionId = extra.sessionId; + mcpServer.registerTool('test-tool', {}, async ctx => { + receivedSessionId = ctx.sessionId; return { content: [ { @@ -1369,7 +1370,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { /*** * Test: Pass Request ID to Tool Callback */ - test('should pass requestId to tool callback via RequestHandlerExtra', async () => { + test('should pass requestId to tool callback via context', async () => { const mcpServer = new McpServer({ name: 'test server', version: '1.0' @@ -1381,13 +1382,13 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); let receivedRequestId: string | number | undefined; - mcpServer.registerTool('request-id-test', {}, async extra => { - receivedRequestId = extra.requestId; + mcpServer.registerTool('request-id-test', {}, async ctx => { + receivedRequestId = ctx.mcpReq.id; return { content: [ { type: 'text', - text: `Received request ID: ${extra.requestId}` + text: `Received request ID: ${ctx.mcpReq.id}` } ] }; @@ -1444,7 +1445,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); mcpServer.registerTool('test-tool', {}, async extra => { - await extra.sendNotification({ + await extra.notification.send({ method: 'notifications/message', params: { level: 'debug', data: loggingMessage } }); @@ -1790,17 +1791,19 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, { - createTask: async (_args, extra) => { - const task = await extra.taskStore.createTask({ ttl: 60_000 }); + createTask: async (_args, ctx) => { + if (!ctx.task?.store) throw new Error('Task store not found'); + const task = await ctx.task.store.createTask({ ttl: 60_000 }); return { task }; }, - getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); - if (!task) throw new Error('Task not found'); + getTask: async (_args, ctx) => { + if (!ctx.task?.store) throw new Error('Task store not found'); + const task = await ctx.task.store.getTask(ctx.task.id!); return task; }, - getTaskResult: async (_args, extra) => { - return (await extra.taskStore.getTaskResult(extra.taskId)) as CallToolResult; + getTaskResult: async (_args, ctx) => { + if (!ctx.task?.store) throw new Error('Task store not found'); + return (await ctx.task.store.getTaskResult(ctx.task.id!)) as CallToolResult; } } ); @@ -1859,17 +1862,18 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, { - createTask: async (_args, extra) => { - const task = await extra.taskStore.createTask({ ttl: 60_000 }); + createTask: async (_args, ctx) => { + if (!ctx.task?.store) throw new Error('Task store not found'); + const task = await ctx.task.store.createTask({ ttl: 60_000 }); return { task }; }, - getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); + getTask: async (_args, ctx) => { + const task = await ctx.task?.store?.getTask(ctx.task.id!); if (!task) throw new Error('Task not found'); return task; }, - getTaskResult: async (_args, extra) => { - return (await extra.taskStore.getTaskResult(extra.taskId)) as CallToolResult; + getTaskResult: async (_args, ctx) => { + return (await ctx.task?.store?.getTaskResult(ctx.task.id!)) as CallToolResult; } } ); @@ -2844,7 +2848,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { /*** * Test: Pass Request ID to Resource Callback */ - test('should pass requestId to resource callback via RequestHandlerExtra', async () => { + test('should pass requestId to resource callback via context', async () => { const mcpServer = new McpServer({ name: 'test server', version: '1.0' @@ -2856,13 +2860,13 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); let receivedRequestId: string | number | undefined; - mcpServer.registerResource('request-id-test', 'test://resource', {}, async (_uri, extra) => { - receivedRequestId = extra.requestId; + mcpServer.registerResource('request-id-test', 'test://resource', {}, async (_uri, ctx) => { + receivedRequestId = ctx.mcpReq.id; return { contents: [ { uri: 'test://resource', - text: `Received request ID: ${extra.requestId}` + text: `Received request ID: ${ctx.mcpReq.id}` } ] }; @@ -3782,7 +3786,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { /*** * Test: Pass Request ID to Prompt Callback */ - test('should pass requestId to prompt callback via RequestHandlerExtra', async () => { + test('should pass requestId to prompt callback via context', async () => { const mcpServer = new McpServer({ name: 'test server', version: '1.0' @@ -3794,15 +3798,15 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); let receivedRequestId: string | number | undefined; - mcpServer.registerPrompt('request-id-test', {}, async extra => { - receivedRequestId = extra.requestId; + mcpServer.registerPrompt('request-id-test', {}, async ctx => { + receivedRequestId = ctx.mcpReq.id; return { messages: [ { role: 'assistant', content: { type: 'text', - text: `Received request ID: ${extra.requestId}` + text: `Received request ID: ${ctx.mcpReq.id}` } } ] @@ -4309,17 +4313,20 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }) } }, - async ({ department, name }) => ({ - messages: [ - { - role: 'assistant', - content: { - type: 'text', - text: `Hello ${name}, welcome to the ${department} team!` + async (args, ctx: BaseContext) => { + expect('elicitInput' in ctx.mcpReq).toBe(true); + return { + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `Hello ${args.name}, welcome to the ${args.department} team!` + } } - } - ] - }) + ] + }; + } ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -6222,11 +6229,12 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, { - createTask: async ({ input }, extra) => { - const task = await extra.taskStore.createTask({ ttl: 60_000, pollInterval: 100 }); + createTask: async ({ input }, ctx) => { + if (!ctx.task?.store) throw new Error('Task store not found'); + const task = await ctx.task.store.createTask({ ttl: 60_000, pollInterval: 100 }); // Capture taskStore for use in setTimeout - const store = extra.taskStore; + const store = ctx.task.store; // Simulate async work setTimeout(async () => { @@ -6237,15 +6245,14 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { return { task }; }, - getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); - if (!task) { - throw new Error('Task not found'); - } + getTask: async (_args, ctx) => { + if (!ctx.task?.store) throw new Error('Task store not found'); + const task = await ctx.task.store.getTask(ctx.task.id!); return task; }, - getTaskResult: async (_input, extra) => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + getTaskResult: async (_input, ctx) => { + if (!ctx.task?.store) throw new Error('Task store not found'); + const result = await ctx.task.store.getTaskResult(ctx.task.id!); return result as CallToolResult; } } @@ -6327,11 +6334,12 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, { - createTask: async ({ value }, extra) => { - const task = await extra.taskStore.createTask({ ttl: 60_000, pollInterval: 100 }); + createTask: async ({ value }, ctx) => { + if (!ctx.task?.store) throw new Error('Task store not found'); + const task = await ctx.task.store.createTask({ ttl: 60_000, pollInterval: 100 }); // Capture taskStore for use in setTimeout - const store = extra.taskStore; + const store = ctx.task.store; // Simulate async work setTimeout(async () => { @@ -6343,15 +6351,14 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { return { task }; }, - getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); - if (!task) { - throw new Error('Task not found'); - } + getTask: async (_args, ctx) => { + if (!ctx.task?.store) throw new Error('Task store not found'); + const task = await ctx.task.store.getTask(ctx.task.id!); return task; }, - getTaskResult: async (_value, extra) => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + getTaskResult: async (_value, ctx) => { + if (!ctx.task?.store) throw new Error('Task store not found'); + const result = await ctx.task.store.getTaskResult(ctx.task.id!); return result as CallToolResult; } } @@ -6435,14 +6442,16 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, { - createTask: async ({ data }, extra) => { - const task = await extra.taskStore.createTask({ ttl: 60_000, pollInterval: 100 }); + createTask: async ({ data }, ctx) => { + if (!ctx.task?.store) throw new Error('Task store not found'); + const task = await ctx.task.store.createTask({ ttl: 60_000, pollInterval: 100 }); // Capture taskStore for use in setTimeout - const store = extra.taskStore; + const store = ctx.task.store; // Simulate async work setTimeout(async () => { + if (!store) throw new Error('Task store not found'); await store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text' as const, text: `Completed: ${data}` }] }); @@ -6451,15 +6460,14 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { return { task }; }, - getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); - if (!task) { - throw new Error('Task not found'); - } + getTask: async (_args, ctx) => { + if (!ctx.task?.store) throw new Error('Task store not found'); + const task = await ctx.task.store.getTask(ctx.task.id!); return task; }, - getTaskResult: async (_data, extra) => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + getTaskResult: async (_data, ctx) => { + if (!ctx.task?.store) throw new Error('Task store not found'); + const result = await ctx.task.store.getTaskResult(ctx.task.id!); return result as CallToolResult; } } @@ -6552,11 +6560,12 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, { - createTask: async extra => { - const task = await extra.taskStore.createTask({ ttl: 60_000, pollInterval: 100 }); + createTask: async ctx => { + if (!ctx.task) throw new Error('Task store not found'); + const task = await ctx.task.store.createTask({ ttl: 60_000, pollInterval: 100 }); // Capture taskStore for use in setTimeout - const store = extra.taskStore; + const store = ctx.task.store; // Simulate async failure setTimeout(async () => { @@ -6569,15 +6578,13 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { return { task }; }, - getTask: async extra => { - const task = await extra.taskStore.getTask(extra.taskId); - if (!task) { - throw new Error('Task not found'); - } + getTask: async ctx => { + if (!ctx.task?.store) throw new Error('Task store not found'); + const task = await ctx.task.store.getTask(ctx.task.id!); return task; }, - getTaskResult: async extra => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + getTaskResult: async ctx => { + const result = await ctx.task!.store.getTaskResult(ctx.task!.id!); return result as CallToolResult; } } @@ -6658,11 +6665,12 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, { - createTask: async extra => { - const task = await extra.taskStore.createTask({ ttl: 60_000, pollInterval: 100 }); + createTask: async ctx => { + if (!ctx.task?.store) throw new Error('Task store not found'); + const task = await ctx.task.store.createTask({ ttl: 60_000, pollInterval: 100 }); // Capture taskStore for use in setTimeout - const store = extra.taskStore; + const store = ctx.task.store; // Simulate async cancellation setTimeout(async () => { @@ -6672,15 +6680,14 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { return { task }; }, - getTask: async extra => { - const task = await extra.taskStore.getTask(extra.taskId); - if (!task) { - throw new Error('Task not found'); - } + getTask: async ctx => { + if (!ctx.task?.store) throw new Error('Task store not found'); + const task = await ctx.task.store.getTask(ctx.task.id!); return task; }, - getTaskResult: async extra => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + getTaskResult: async ctx => { + if (!ctx.task?.store) throw new Error('Task store not found'); + const result = await ctx.task.store.getTaskResult(ctx.task.id!); return result as CallToolResult; } } @@ -6745,19 +6752,19 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, { - createTask: async (_args, extra) => { - const task = await extra.taskStore.createTask({ ttl: 60_000, pollInterval: 100 }); + createTask: async (_args, ctx) => { + if (!ctx.task?.store) throw new Error('Task store not found'); + const task = await ctx.task.store.createTask({ ttl: 60_000, pollInterval: 100 }); return { task }; }, - getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); - if (!task) { - throw new Error('Task not found'); - } + getTask: async (_args, ctx) => { + if (!ctx.task?.store) throw new Error('Task store not found'); + const task = await ctx.task.store.getTask(ctx.task.id!); return task; }, - getTaskResult: async (_args, extra) => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + getTaskResult: async (_args, ctx) => { + if (!ctx.task?.store) throw new Error('Task store not found'); + const result = await ctx.task.store.getTaskResult(ctx.task.id!); return result as CallToolResult; } } diff --git a/test/integration/test/taskLifecycle.test.ts b/test/integration/test/taskLifecycle.test.ts index efb8ad6db..a8a8046dd 100644 --- a/test/integration/test/taskLifecycle.test.ts +++ b/test/integration/test/taskLifecycle.test.ts @@ -63,8 +63,8 @@ describe('Task Lifecycle Integration Tests', () => { } }, { - async createTask({ duration, shouldFail }, extra) { - const task = await extra.taskStore.createTask({ + async createTask({ duration, shouldFail }, ctx) { + const task = await ctx.task!.store.createTask({ ttl: 60_000, pollInterval: 100 }); @@ -75,11 +75,11 @@ describe('Task Lifecycle Integration Tests', () => { try { await (shouldFail - ? extra.taskStore.storeTaskResult(task.taskId, 'failed', { + ? ctx.task!.store.storeTaskResult(task.taskId, 'failed', { content: [{ type: 'text', text: 'Task failed as requested' }], isError: true }) - : extra.taskStore.storeTaskResult(task.taskId, 'completed', { + : ctx.task!.store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text', text: `Completed after ${duration}ms` }] })); } catch { @@ -89,15 +89,15 @@ describe('Task Lifecycle Integration Tests', () => { return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.task!.store.getTask(ctx.task!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.task!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.task!.store.getTaskResult(ctx.task!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -114,8 +114,8 @@ describe('Task Lifecycle Integration Tests', () => { } }, { - async createTask({ userName }, extra) { - const task = await extra.taskStore.createTask({ + async createTask({ userName }, ctx) { + const task = await ctx.task!.store.createTask({ ttl: 60_000, pollInterval: 100 }); @@ -128,14 +128,14 @@ describe('Task Lifecycle Integration Tests', () => { if (userName) { // Complete immediately if userName was provided try { - await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text', text: `Hello, ${userName}!` }] }); } catch { // Task may have been cleaned up if test ended } } else { - const elicitationResult = await extra.sendRequest( + const elicitationResult = await ctx.mcpReq.send( { method: 'elicitation/create', params: { @@ -160,7 +160,7 @@ describe('Task Lifecycle Integration Tests', () => { ? elicitationResult.content.userName : 'Unknown'; try { - await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text', text: `Hello, ${name}!` }] }); } catch { @@ -171,15 +171,15 @@ describe('Task Lifecycle Integration Tests', () => { return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.task!.store.getTask(ctx.task!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.task!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.task!.store.getTaskResult(ctx.task!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -418,8 +418,8 @@ describe('Task Lifecycle Integration Tests', () => { } }, { - async createTask({ requestCount }, extra) { - const task = await extra.taskStore.createTask({ + async createTask({ requestCount }, ctx) { + const task = await ctx.task!.store.createTask({ ttl: 60_000, pollInterval: 100 }); @@ -432,7 +432,7 @@ describe('Task Lifecycle Integration Tests', () => { // Send multiple elicitation requests for (let i = 0; i < requestCount; i++) { - const elicitationResult = await extra.sendRequest( + const elicitationResult = await ctx.mcpReq.send( { method: 'elicitation/create', params: { @@ -458,7 +458,7 @@ describe('Task Lifecycle Integration Tests', () => { // Complete with all responses try { - await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text', text: `Received responses: ${responses.join(', ')}` }] }); } catch { @@ -468,15 +468,15 @@ describe('Task Lifecycle Integration Tests', () => { return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.task!.store.getTask(ctx.task!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.task!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.task!.store.getTaskResult(ctx.task!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -907,8 +907,8 @@ describe('Task Lifecycle Integration Tests', () => { } }, { - async createTask({ messageCount }, extra) { - const task = await extra.taskStore.createTask({ + async createTask({ messageCount }, ctx) { + const task = await ctx.task!.store.createTask({ ttl: 60_000, pollInterval: 100 }); @@ -921,8 +921,8 @@ describe('Task Lifecycle Integration Tests', () => { // Queue multiple elicitation requests for (let i = 0; i < messageCount; i++) { // Send request but don't await - let it queue - extra - .sendRequest( + ctx.mcpReq + .send( { method: 'elicitation/create', params: { @@ -957,15 +957,15 @@ describe('Task Lifecycle Integration Tests', () => { return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.task!.store.getTask(ctx.task!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.task!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.task!.store.getTaskResult(ctx.task!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -1105,8 +1105,8 @@ describe('Task Lifecycle Integration Tests', () => { } }, { - async createTask({ messageCount, delayBetweenMessages }, extra) { - const task = await extra.taskStore.createTask({ + async createTask({ messageCount, delayBetweenMessages }, ctx) { + const task = await ctx.task!.store.createTask({ ttl: 60_000, pollInterval: 100 }); @@ -1121,7 +1121,7 @@ describe('Task Lifecycle Integration Tests', () => { // Send messages with delays between them for (let i = 0; i < messageCount; i++) { - const elicitationResult = await extra.sendRequest( + const elicitationResult = await ctx.mcpReq.send( { method: 'elicitation/create', params: { @@ -1152,7 +1152,7 @@ describe('Task Lifecycle Integration Tests', () => { // Complete with all responses try { - await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text', text: `Received all responses: ${responses.join(', ')}` }] }); } catch { @@ -1161,7 +1161,7 @@ describe('Task Lifecycle Integration Tests', () => { } catch (error) { // Handle errors try { - await extra.taskStore.storeTaskResult(task.taskId, 'failed', { + await ctx.task!.store.storeTaskResult(task.taskId, 'failed', { content: [{ type: 'text', text: `Error: ${error}` }], isError: true }); @@ -1173,15 +1173,15 @@ describe('Task Lifecycle Integration Tests', () => { return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.task!.store.getTask(ctx.task!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.task!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.task!.store.getTaskResult(ctx.task!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -1321,8 +1321,8 @@ describe('Task Lifecycle Integration Tests', () => { } }, { - async createTask({ messageCount }, extra) { - const task = await extra.taskStore.createTask({ + async createTask({ messageCount }, ctx) { + const task = await ctx.task!.store.createTask({ ttl: 60_000, pollInterval: 100 }); @@ -1335,8 +1335,8 @@ describe('Task Lifecycle Integration Tests', () => { for (let i = 0; i < messageCount; i++) { // Start the request but don't wait for response // The request gets queued when sendRequest is called - extra - .sendRequest( + ctx.mcpReq + .send( { method: 'elicitation/create', params: { @@ -1361,7 +1361,7 @@ describe('Task Lifecycle Integration Tests', () => { // Complete the task after all messages are queued try { - await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + await ctx.task!.store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text', text: 'Task completed quickly' }] }); } catch { @@ -1370,7 +1370,7 @@ describe('Task Lifecycle Integration Tests', () => { } catch (error) { // Handle errors try { - await extra.taskStore.storeTaskResult(task.taskId, 'failed', { + await ctx.task!.store.storeTaskResult(task.taskId, 'failed', { content: [{ type: 'text', text: `Error: ${error}` }], isError: true }); @@ -1382,15 +1382,15 @@ describe('Task Lifecycle Integration Tests', () => { return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.task!.store.getTask(ctx.task!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.task!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.task!.store.getTaskResult(ctx.task!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } diff --git a/test/integration/test/taskResumability.test.ts b/test/integration/test/taskResumability.test.ts index 6b6849aa9..c7d1ec142 100644 --- a/test/integration/test/taskResumability.test.ts +++ b/test/integration/test/taskResumability.test.ts @@ -68,9 +68,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { message: z.string().describe('Message to send').default('Test notification') } }, - async ({ message }, { sendNotification }) => { + async ({ message }, ctx) => { // Send notification immediately - await sendNotification({ + await ctx.notification.send({ method: 'notifications/message', params: { level: 'info', @@ -94,10 +94,10 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { interval: z.number().describe('Interval between notifications in ms').default(50) } }, - async ({ count, interval }, { sendNotification }) => { + async ({ count, interval }, ctx) => { // Send notifications at specified intervals for (let i = 0; i < count; i++) { - await sendNotification({ + await ctx.notification.send({ method: 'notifications/message', params: { level: 'info',