diff --git a/.changeset/calm-houses-listen.md b/.changeset/calm-houses-listen.md new file mode 100644 index 000000000..3a0fa18da --- /dev/null +++ b/.changeset/calm-houses-listen.md @@ -0,0 +1,9 @@ +--- +"@sei-js/mcp-server": patch +--- + +Fix session binding and response isolation in the HTTP SSE transport. + +- POST handler now validates `sessionId` on every request — rejects missing session IDs (400) and unknown session IDs (404) +- Each POST is routed to the transport instance that owns the matching session ID, preventing cross-client request injection +- Session IDs now use the MCP SDK's `transport.sessionId` rather than `Date.now()` \ No newline at end of file diff --git a/.gitmodules b/.gitmodules index 029e34bfd..656a42c4d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,6 @@ [submodule "packages/registry/community-assetlist"] path = packages/registry/community-assetlist - url = https://github.com/Sei-Public-Goods/sei-assetlist.git + url = https://github.com/Seitrace/sei-assetlist.git [submodule "packages/registry/chain-registry"] path = packages/registry/chain-registry url = https://github.com/sei-protocol/chain-registry.git diff --git a/packages/mcp-server/src/server/transport/http-sse.ts b/packages/mcp-server/src/server/transport/http-sse.ts index 31c78c5ff..16042c509 100644 --- a/packages/mcp-server/src/server/transport/http-sse.ts +++ b/packages/mcp-server/src/server/transport/http-sse.ts @@ -40,12 +40,12 @@ export class HttpSseTransport implements McpTransport { this.app.get(this.path, (req: Request, res: Response) => { console.error(`SSE connection from ${req.ip}`); - + // Create SSE transport - it will handle headers automatically const transport = new SSEServerTransport(`${this.path}/message`, res); - const sessionId = Date.now().toString(); + const sessionId = transport.sessionId this.connections.set(sessionId, transport); - + // Connect transport to MCP server if (this.mcpServer) { this.mcpServer.connect(transport); @@ -61,10 +61,15 @@ export class HttpSseTransport implements McpTransport { // Message endpoint for SSE transport this.app.post(`${this.path}/message`, async (req: Request, res: Response) => { try { - // Find the first available transport (simple approach for now) - const transport = Array.from(this.connections.values())[0]; + const sessionId = typeof req.query.sessionId === 'string' ? req.query.sessionId : undefined; + if (!sessionId) { + res.status(400).json({ error: 'Missing sessionId' }); + return; + } + + const transport = this.connections.get(sessionId); if (!transport) { - res.status(404).json({ error: 'No active SSE connection' }); + res.status(404).json({ error: 'Session not found' }); return; } diff --git a/packages/mcp-server/src/tests/server/transport/http-sse.test.ts b/packages/mcp-server/src/tests/server/transport/http-sse.test.ts index 243e22bca..4861b58d8 100644 --- a/packages/mcp-server/src/tests/server/transport/http-sse.test.ts +++ b/packages/mcp-server/src/tests/server/transport/http-sse.test.ts @@ -12,7 +12,7 @@ jest.mock('express', () => { listen: jest.fn() }; const express = jest.fn(() => mockApp); - express.json = jest.fn(); + (express as any).json = jest.fn(); return express; }); @@ -65,7 +65,8 @@ describe('HttpSseTransport', () => { }; mockTransport = { - handleMessage: jest.fn() + handleMessage: jest.fn(), + sessionId: 'mock-session-id' }; mockMcpServer = { @@ -74,7 +75,7 @@ describe('HttpSseTransport', () => { // Configure mocks mockExpress.mockReturnValue(mockApp); - mockExpress.json = jest.fn().mockReturnValue('json-middleware'); + (mockExpress as any).json = jest.fn().mockReturnValue('json-middleware'); mockCreateCorsMiddleware.mockReturnValue('cors-middleware'); mockSSEServerTransport.mockImplementation(() => mockTransport); @@ -129,20 +130,20 @@ describe('HttpSseTransport', () => { describe('SSE endpoint', () => { it('should create SSE transport and connect to MCP server', () => { const transport = new HttpSseTransport(3000, 'localhost', '/sse'); - + // Mock MCP server (transport as any).mcpServer = mockMcpServer; - - const mockReq = { + + const mockReq = { ip: '127.0.0.1', on: jest.fn() }; const mockRes = {}; - + // Get the SSE endpoint handler const sseHandler = mockApp.get.mock.calls.find(call => call[0] === '/sse')[1]; sseHandler(mockReq, mockRes); - + expect(consoleErrorSpy).toHaveBeenCalledWith('SSE connection from 127.0.0.1'); expect(mockMcpServer.connect).toHaveBeenCalledWith(mockTransport); expect(mockReq.on).toHaveBeenCalledWith('close', expect.any(Function)); @@ -150,99 +151,197 @@ describe('HttpSseTransport', () => { it('should handle connection without MCP server', () => { new HttpSseTransport(3000, 'localhost', '/sse'); - - const mockReq = { + + const mockReq = { ip: '127.0.0.1', on: jest.fn() }; const mockRes = {}; - + // Get the SSE endpoint handler const sseHandler = mockApp.get.mock.calls.find(call => call[0] === '/sse')[1]; sseHandler(mockReq, mockRes); - + expect(consoleErrorSpy).toHaveBeenCalledWith('SSE connection from 127.0.0.1'); expect(mockMcpServer.connect).not.toHaveBeenCalled(); }); it('should clean up connection on close', () => { new HttpSseTransport(3000, 'localhost', '/sse'); - - const mockReq = { + + const mockReq = { ip: '127.0.0.1', on: jest.fn() }; const mockRes = {}; - + // Get the SSE endpoint handler const sseHandler = mockApp.get.mock.calls.find(call => call[0] === '/sse')[1]; sseHandler(mockReq, mockRes); - + // Get the close handler const closeHandler = mockReq.on.mock.calls.find(call => call[0] === 'close')[1]; closeHandler(); - - expect(consoleErrorSpy).toHaveBeenCalledWith(expect.stringMatching(/SSE connection closed for session \d+/)); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + `SSE connection closed for session ${mockTransport.sessionId}` + ); + }); + + it('should use the transport sessionId as the connections key', () => { + const transport = new HttpSseTransport(3000, 'localhost', '/sse'); + + const mockReq = { ip: '127.0.0.1', on: jest.fn() }; + const mockRes = {}; + + const sseHandler = mockApp.get.mock.calls.find(call => call[0] === '/sse')[1]; + sseHandler(mockReq, mockRes); + + expect(mockSSEServerTransport).toHaveBeenCalledWith('/sse/message', mockRes); + expect((transport as any).connections.has(mockTransport.sessionId)).toBe(true); + }); + + it('should assign unique session IDs to concurrent connections', () => { + const mockTransport1 = { handleMessage: jest.fn(), sessionId: 'session-id-1' }; + const mockTransport2 = { handleMessage: jest.fn(), sessionId: 'session-id-2' }; + mockSSEServerTransport + .mockImplementationOnce(() => mockTransport1) + .mockImplementationOnce(() => mockTransport2); + + const transport = new HttpSseTransport(3000, 'localhost', '/sse'); + (transport as any).mcpServer = mockMcpServer; + + const mockReq1 = { ip: '127.0.0.1', on: jest.fn() }; + const mockReq2 = { ip: '127.0.0.2', on: jest.fn() }; + + const sseHandler = mockApp.get.mock.calls.find(call => call[0] === '/sse')[1]; + sseHandler(mockReq1, {}); + sseHandler(mockReq2, {}); + + expect((transport as any).connections.has('session-id-1')).toBe(true); + expect((transport as any).connections.has('session-id-2')).toBe(true); + expect((transport as any).connections.size).toBe(2); }); }); describe('Message endpoint', () => { it('should handle message with active transport', async () => { const transport = new HttpSseTransport(3000, 'localhost', '/sse'); - + // Add a connection to the transport (transport as any).connections.set('test-session', mockTransport); - - const mockReq = { body: { test: 'message' } }; - const mockRes = { + + const mockReq = { body: { test: 'message' }, query: { sessionId: 'test-session' } }; + const mockRes = { status: jest.fn().mockReturnThis(), end: jest.fn(), json: jest.fn() }; - + // Get the message endpoint handler const messageHandler = mockApp.post.mock.calls.find(call => call[0] === '/sse/message')[1]; await messageHandler(mockReq, mockRes); - + expect(mockTransport.handleMessage).toHaveBeenCalledWith({ test: 'message' }); expect(mockRes.status).toHaveBeenCalledWith(200); expect(mockRes.end).toHaveBeenCalled(); }); - it('should return 404 when no active transport', async () => { + it('should return 400 when sessionId query param is missing', async () => { new HttpSseTransport(3000, 'localhost', '/sse'); - - const mockReq = { body: { test: 'message' } }; - const mockRes = { + + const mockReq = { body: { test: 'message' }, query: {} }; + const mockRes = { status: jest.fn().mockReturnThis(), json: jest.fn() }; - + // Get the message endpoint handler const messageHandler = mockApp.post.mock.calls.find(call => call[0] === '/sse/message')[1]; await messageHandler(mockReq, mockRes); - + + expect(mockRes.status).toHaveBeenCalledWith(400); + expect(mockRes.json).toHaveBeenCalledWith({ error: 'Missing sessionId' }); + }); + + it('should return 404 when sessionId does not match any active session', async () => { + const transport = new HttpSseTransport(3000, 'localhost', '/sse'); + (transport as any).connections.set('real-session-id', mockTransport); + + const mockReq = { body: { test: 'message' }, query: { sessionId: 'bogus-session-id' } }; + const mockRes = { + status: jest.fn().mockReturnThis(), + json: jest.fn() + }; + + // Get the message endpoint handler + const messageHandler = mockApp.post.mock.calls.find(call => call[0] === '/sse/message')[1]; + await messageHandler(mockReq, mockRes); + expect(mockRes.status).toHaveBeenCalledWith(404); - expect(mockRes.json).toHaveBeenCalledWith({ error: 'No active SSE connection' }); + expect(mockRes.json).toHaveBeenCalledWith({ error: 'Session not found' }); + expect(mockTransport.handleMessage).not.toHaveBeenCalled(); + }); + + it('should route message only to the transport matching the sessionId', async () => { + const transport = new HttpSseTransport(3000, 'localhost', '/sse'); + + const mockTransportA = { handleMessage: jest.fn() }; + const mockTransportB = { handleMessage: jest.fn() }; + (transport as any).connections.set('session-a', mockTransportA); + (transport as any).connections.set('session-b', mockTransportB); + + const mockReq = { body: { jsonrpc: '2.0', method: 'ping', id: 1 }, query: { sessionId: 'session-b' } }; + const mockRes = { + status: jest.fn().mockReturnThis(), + end: jest.fn(), + json: jest.fn() + }; + + // Get the message endpoint handler + const messageHandler = mockApp.post.mock.calls.find(call => call[0] === '/sse/message')[1]; + await messageHandler(mockReq, mockRes); + + expect(mockTransportB.handleMessage).toHaveBeenCalledWith({ jsonrpc: '2.0', method: 'ping', id: 1 }); + expect(mockTransportA.handleMessage).not.toHaveBeenCalled(); + expect(mockRes.status).toHaveBeenCalledWith(200); + }); + + it('should reject a request even when another valid session exists', async () => { + const transport = new HttpSseTransport(3000, 'localhost', '/sse'); + (transport as any).connections.set('legitimate-session', mockTransport); + + // Attacker sends request with no sessionId + const mockReq = { body: { jsonrpc: '2.0', method: 'tools/call', id: 2 }, query: {} }; + const mockRes = { + status: jest.fn().mockReturnThis(), + json: jest.fn() + }; + + const messageHandler = mockApp.post.mock.calls.find(call => call[0] === '/sse/message')[1]; + await messageHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(400); + expect(mockTransport.handleMessage).not.toHaveBeenCalled(); }); it('should handle transport errors', async () => { const transport = new HttpSseTransport(3000, 'localhost', '/sse'); - + // Add a connection that will throw an error mockTransport.handleMessage.mockRejectedValue(new Error('Transport error')); (transport as any).connections.set('test-session', mockTransport); - - const mockReq = { body: { test: 'message' } }; - const mockRes = { + + const mockReq = { body: { test: 'message' }, query: { sessionId: 'test-session' } }; + const mockRes = { status: jest.fn().mockReturnThis(), json: jest.fn() }; - + // Get the message endpoint handler const messageHandler = mockApp.post.mock.calls.find(call => call[0] === '/sse/message')[1]; await messageHandler(mockReq, mockRes); - + expect(consoleErrorSpy).toHaveBeenCalledWith('Error handling message:', expect.any(Error)); expect(mockRes.status).toHaveBeenCalledWith(500); expect(mockRes.json).toHaveBeenCalledWith({ error: 'Internal server error' }); @@ -418,37 +517,56 @@ describe('HttpSseTransport', () => { }); it('should handle multiple connections and cleanup', () => { + const mockTransport1 = { handleMessage: jest.fn(), sessionId: 'session-id-1' }; + const mockTransport2 = { handleMessage: jest.fn(), sessionId: 'session-id-2' }; + mockSSEServerTransport + .mockImplementationOnce(() => mockTransport1) + .mockImplementationOnce(() => mockTransport2); + const transport = new HttpSseTransport(3000, 'localhost', '/sse'); (transport as any).mcpServer = mockMcpServer; - - // Mock Date.now to return different values for different connections - const originalDateNow = Date.now; - let callCount = 0; - Date.now = jest.fn(() => { - callCount++; - return 1000 + callCount; // Return different timestamps - }); - + // Simulate multiple SSE connections const mockReq1 = { ip: '127.0.0.1', on: jest.fn() }; const mockReq2 = { ip: '127.0.0.2', on: jest.fn() }; - const mockRes1 = {}; - const mockRes2 = {}; - + const sseHandler = mockApp.get.mock.calls.find(call => call[0] === '/sse')[1]; - sseHandler(mockReq1, mockRes1); - sseHandler(mockReq2, mockRes2); - + sseHandler(mockReq1, {}); + sseHandler(mockReq2, {}); + expect((transport as any).connections.size).toBe(2); - + // Close first connection const closeHandler1 = mockReq1.on.mock.calls.find(call => call[0] === 'close')[1]; closeHandler1(); - + expect((transport as any).connections.size).toBe(1); - - // Restore Date.now - Date.now = originalDateNow; + expect((transport as any).connections.has('session-id-2')).toBe(true); + }); + + it('should not route a message to the first connected client when the request targets a second client', async () => { + const mockTransport1 = { handleMessage: jest.fn(), sessionId: 'session-id-1' }; + const mockTransport2 = { handleMessage: jest.fn(), sessionId: 'session-id-2' }; + mockSSEServerTransport + .mockImplementationOnce(() => mockTransport1) + .mockImplementationOnce(() => mockTransport2); + + const transport = new HttpSseTransport(3000, 'localhost', '/sse'); + (transport as any).mcpServer = mockMcpServer; + + const sseHandler = mockApp.get.mock.calls.find(call => call[0] === '/sse')[1]; + sseHandler({ ip: '127.0.0.1', on: jest.fn() }, {}); + sseHandler({ ip: '127.0.0.2', on: jest.fn() }, {}); + + // POST targeting client 2 — client 1's transport must not be invoked + const mockReq = { body: { jsonrpc: '2.0', method: 'tools/call', id: 99 }, query: { sessionId: 'session-id-2' } }; + const mockRes = { status: jest.fn().mockReturnThis(), end: jest.fn(), json: jest.fn() }; + + const messageHandler = mockApp.post.mock.calls.find(call => call[0] === '/sse/message')[1]; + await messageHandler(mockReq, mockRes); + + expect(mockTransport2.handleMessage).toHaveBeenCalledWith({ jsonrpc: '2.0', method: 'tools/call', id: 99 }); + expect(mockTransport1.handleMessage).not.toHaveBeenCalled(); }); }); }); diff --git a/packages/registry/src/tokens/index.ts b/packages/registry/src/tokens/index.ts index 6bb9c8e57..580040e39 100644 --- a/packages/registry/src/tokens/index.ts +++ b/packages/registry/src/tokens/index.ts @@ -56,7 +56,7 @@ type SeiTokens = { }; /** - * A constant that maps each Sei networks to its respective tokens, imported from the community ran [assetlist](https://github.com/Sei-Public-Goods/sei-assetlist). + * A constant that maps each Sei networks to its respective tokens, imported from the community ran [assetlist](https://github.com/Seitrace/sei-assetlist). * * @remarks * **Important**: This token list is community-driven and subject to change.