diff --git a/src/bedrock_agentcore/runtime/__init__.py b/src/bedrock_agentcore/runtime/__init__.py index b86c8aaa..3a979e21 100644 --- a/src/bedrock_agentcore/runtime/__init__.py +++ b/src/bedrock_agentcore/runtime/__init__.py @@ -1,19 +1,48 @@ """BedrockAgentCore Runtime Package. This package contains the core runtime components for Bedrock AgentCore applications: -- BedrockAgentCoreApp: Main application class +- BedrockAgentCoreApp: Main application class for HTTP protocol +- BedrockAgentCoreA2AApp: Application class for A2A (Agent-to-Agent) protocol - RequestContext: HTTP request context - BedrockAgentCoreContext: Agent identity context +- AgentCard, AgentSkill: A2A protocol metadata models """ +from .a2a_app import BedrockAgentCoreA2AApp +from .a2a_models import ( + A2A_DEFAULT_PORT, + A2AArtifact, + A2AMessage, + A2AMessagePart, + AgentCard, + AgentSkill, + JsonRpcErrorCode, + JsonRpcRequest, + JsonRpcResponse, + build_runtime_url, +) from .agent_core_runtime_client import AgentCoreRuntimeClient from .app import BedrockAgentCoreApp from .context import BedrockAgentCoreContext, RequestContext from .models import PingStatus __all__ = [ + # HTTP Protocol "AgentCoreRuntimeClient", "BedrockAgentCoreApp", + # A2A Protocol + "BedrockAgentCoreA2AApp", + "AgentCard", + "AgentSkill", + "A2AMessage", + "A2AMessagePart", + "A2AArtifact", + "JsonRpcRequest", + "JsonRpcResponse", + "JsonRpcErrorCode", + "A2A_DEFAULT_PORT", + "build_runtime_url", + # Common "RequestContext", "BedrockAgentCoreContext", "PingStatus", diff --git a/src/bedrock_agentcore/runtime/a2a_app.py b/src/bedrock_agentcore/runtime/a2a_app.py new file mode 100644 index 00000000..1ffb8036 --- /dev/null +++ b/src/bedrock_agentcore/runtime/a2a_app.py @@ -0,0 +1,311 @@ +"""Bedrock AgentCore A2A application implementation. + +Provides a Starlette-based web server for A2A (Agent-to-Agent) protocol communication. +""" + +import inspect +import json +import os +import time +from collections.abc import Sequence +from typing import Any, Callable, Dict, Optional + +from starlette.middleware import Middleware +from starlette.responses import JSONResponse, StreamingResponse +from starlette.routing import Route +from starlette.types import Lifespan + +from .a2a_models import ( + A2A_DEFAULT_PORT, + AgentCard, + JsonRpcErrorCode, + JsonRpcRequest, + JsonRpcResponse, +) +from .base_app import _BaseAgentCoreApp, _BaseRequestContextFormatter +from .utils import convert_complex_objects + + +class A2ARequestContextFormatter(_BaseRequestContextFormatter): + """Formatter including request and session IDs for A2A applications.""" + + extra_fields = {"protocol": "A2A"} + + +class BedrockAgentCoreA2AApp(_BaseAgentCoreApp): + """Bedrock AgentCore A2A application class for agent-to-agent communication. + + This class implements the A2A protocol contract for AgentCore Runtime, + supporting JSON-RPC 2.0 messaging and agent discovery via Agent Cards. + + Example: + ```python + from bedrock_agentcore.runtime import BedrockAgentCoreA2AApp, AgentCard, AgentSkill + + agent_card = AgentCard( + name="Calculator Agent", + description="A calculator agent", + skills=[AgentSkill(id="calc", name="Calculator", description="Math ops")] + ) + + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + + @app.entrypoint + def handle_message(request, context): + # Process JSON-RPC request + message = request.params["message"] + user_text = message["parts"][0]["text"] + + # Return result (will be wrapped in JSON-RPC response) + return { + "artifacts": [{ + "artifactId": str(uuid.uuid4()), + "name": "response", + "parts": [{"kind": "text", "text": f"Result: {user_text}"}] + }] + } + + app.run() # Runs on port 9000 + ``` + """ + + _default_port = A2A_DEFAULT_PORT + + def __init__( + self, + agent_card: AgentCard, + debug: bool = False, + lifespan: Optional[Lifespan] = None, + middleware: Sequence[Middleware] | None = None, + ): + """Initialize Bedrock AgentCore A2A application. + + Args: + agent_card: AgentCard containing agent metadata for discovery + debug: Enable debug mode for verbose logging (default: False) + lifespan: Optional lifespan context manager for startup/shutdown + middleware: Optional sequence of Starlette Middleware objects + """ + self.agent_card = agent_card + + routes = [ + Route("/", self._handle_jsonrpc, methods=["POST"]), + Route("/.well-known/agent-card.json", self._handle_agent_card, methods=["GET"]), + Route("/ping", self._handle_ping, methods=["GET"]), + ] + super().__init__( + routes=routes, + debug=debug, + lifespan=lifespan, + middleware=middleware, + logger_name="bedrock_agentcore.a2a_app", + log_formatter=A2ARequestContextFormatter(), + ) + + def _get_runtime_url(self, request=None) -> Optional[str]: + """Get the runtime URL from environment or current request. + + Returns: + The runtime URL if set, None otherwise. + """ + runtime_url = os.environ.get("AGENTCORE_RUNTIME_URL") + if runtime_url: + return runtime_url + + if request is not None and getattr(request, "base_url", None): + return str(request.base_url) + + return None + + async def _handle_jsonrpc(self, request): + """Handle JSON-RPC 2.0 requests at root endpoint.""" + request_context = self._build_request_context(request) + start_time = time.time() + body = None + + try: + body = await request.json() + if not isinstance(body, dict): + return self._jsonrpc_error_response( + None, + JsonRpcErrorCode.INVALID_REQUEST, + "Invalid request object", + ) + + self.logger.debug("Processing JSON-RPC request: %s", body.get("method", "unknown")) + + # Validate JSON-RPC format + if body.get("jsonrpc") != "2.0": + return self._jsonrpc_error_response( + body.get("id"), + JsonRpcErrorCode.INVALID_REQUEST, + "Invalid JSON-RPC version", + ) + + method = body.get("method") + if not method: + return self._jsonrpc_error_response( + body.get("id"), + JsonRpcErrorCode.INVALID_REQUEST, + "Missing method", + ) + + jsonrpc_request = JsonRpcRequest.from_dict(body) + + handler = self.handlers.get("main") + if not handler: + self.logger.error("No entrypoint defined") + return self._jsonrpc_error_response( + jsonrpc_request.id, + JsonRpcErrorCode.INTERNAL_ERROR, + "No entrypoint defined", + ) + + takes_context = self._takes_context(handler) + + self.logger.debug("Invoking handler for method: %s", method) + result = await self._invoke_handler(handler, request_context, takes_context, jsonrpc_request) + + duration = time.time() - start_time + + # Handle streaming responses + if inspect.isasyncgen(result): + self.logger.info("Returning streaming response (%.3fs)", duration) + return StreamingResponse( + self._stream_jsonrpc_response(result, jsonrpc_request.id), + media_type="text/event-stream", + ) + elif inspect.isgenerator(result): + self.logger.info("Returning streaming response (sync generator) (%.3fs)", duration) + return StreamingResponse( + self._sync_stream_jsonrpc_response(result, jsonrpc_request.id), + media_type="text/event-stream", + ) + + # Non-streaming response + self.logger.info("Request completed successfully (%.3fs)", duration) + response = JsonRpcResponse.success(jsonrpc_request.id, self._convert_to_serializable(result)) + return JSONResponse(response.to_dict()) + + except json.JSONDecodeError as e: + duration = time.time() - start_time + self.logger.warning("Invalid JSON in request (%.3fs): %s", duration, e) + return self._jsonrpc_error_response( + None, + JsonRpcErrorCode.PARSE_ERROR, + f"Parse error: {str(e)}", + ) + except Exception as e: + duration = time.time() - start_time + self.logger.exception("Request failed (%.3fs)", duration) + return self._jsonrpc_error_response( + body.get("id") if body is not None else None, + JsonRpcErrorCode.INTERNAL_ERROR, + "Internal error", + ) + + def _jsonrpc_error_response( + self, + request_id: Optional[str], + code: int, + message: str, + data: Optional[Any] = None, + ) -> JSONResponse: + """Create a JSON-RPC error response.""" + response = JsonRpcResponse.error_response(request_id, code, message, data) + return JSONResponse(response.to_dict()) + + async def _stream_jsonrpc_response(self, generator, request_id): + """Wrap async generator for SSE streaming with JSON-RPC format.""" + try: + async for value in generator: + # Wrap each chunk in JSON-RPC format + chunk_response = { + "jsonrpc": "2.0", + "id": request_id, + "result": value, + } + yield self._to_sse(chunk_response) + except Exception as e: + self.logger.exception("Error in async streaming") + error_response = JsonRpcResponse.error_response( + request_id, + JsonRpcErrorCode.INTERNAL_ERROR, + "Internal error", + ) + yield self._to_sse(error_response.to_dict()) + + def _sync_stream_jsonrpc_response(self, generator, request_id): + """Wrap sync generator for SSE streaming with JSON-RPC format.""" + try: + for value in generator: + chunk_response = { + "jsonrpc": "2.0", + "id": request_id, + "result": value, + } + yield self._to_sse(chunk_response) + except Exception as e: + self.logger.exception("Error in sync streaming") + error_response = JsonRpcResponse.error_response( + request_id, + JsonRpcErrorCode.INTERNAL_ERROR, + "Internal error", + ) + yield self._to_sse(error_response.to_dict()) + + def _to_sse(self, data: Any) -> bytes: + """Convert data to SSE format.""" + json_string = self._safe_serialize_to_json_string(data) + return f"data: {json_string}\n\n".encode("utf-8") + + def _convert_to_serializable(self, obj: Any) -> Any: + """Convert A2A helper models and common Python objects to JSON-safe payloads.""" + if hasattr(obj, "to_dict") and callable(obj.to_dict): + return self._convert_to_serializable(obj.to_dict()) + + if isinstance(obj, dict): + return {key: self._convert_to_serializable(value) for key, value in obj.items()} + + if isinstance(obj, (list, tuple)): + return [self._convert_to_serializable(value) for value in obj] + + if isinstance(obj, set): + return [self._convert_to_serializable(value) for value in obj] + + return convert_complex_objects(obj) + + def _safe_serialize_to_json_string(self, obj: Any) -> str: + """Safely serialize streaming payloads to JSON, with A2A model support.""" + try: + return json.dumps(obj, ensure_ascii=False) + except (TypeError, ValueError, UnicodeEncodeError): + try: + return json.dumps(self._convert_to_serializable(obj), ensure_ascii=False) + except Exception: + return json.dumps(str(obj), ensure_ascii=False) + + def _handle_agent_card(self, request): + """Handle GET /.well-known/agent-card.json endpoint.""" + try: + runtime_url = self._get_runtime_url(request) + card_dict = self.agent_card.to_dict(url=runtime_url) + + self.logger.debug("Serving Agent Card: %s", self.agent_card.name) + return JSONResponse(card_dict) + except Exception as e: + self.logger.exception("Failed to serve Agent Card") + return JSONResponse({"error": "Internal error"}, status_code=500) + + def run(self, port: Optional[int] = None, host: Optional[str] = None, **kwargs): + """Start the Bedrock AgentCore A2A server. + + Args: + port: Port to serve on, defaults to 9000 (A2A standard) + host: Host to bind to, auto-detected if None + **kwargs: Additional arguments passed to uvicorn.run() + """ + if port is None: + port = self._default_port + self.logger.info("Starting A2A server on port %d", port) + super().run(port=port, host=host, **kwargs) diff --git a/src/bedrock_agentcore/runtime/a2a_models.py b/src/bedrock_agentcore/runtime/a2a_models.py new file mode 100644 index 00000000..37f22566 --- /dev/null +++ b/src/bedrock_agentcore/runtime/a2a_models.py @@ -0,0 +1,295 @@ +"""Models for Bedrock AgentCore A2A runtime. + +Contains data models for A2A protocol including Agent Card, JSON-RPC 2.0 messages, +and related types. +""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional, Union +from urllib.parse import quote + + +class JsonRpcErrorCode(int, Enum): + """Standard JSON-RPC 2.0 error codes and A2A-specific error codes.""" + + # Standard JSON-RPC 2.0 errors + PARSE_ERROR = -32700 + INVALID_REQUEST = -32600 + METHOD_NOT_FOUND = -32601 + INVALID_PARAMS = -32602 + INTERNAL_ERROR = -32603 + + # AgentCore-specific error codes (per AWS Bedrock AgentCore documentation) + RESOURCE_NOT_FOUND = -32501 + VALIDATION_ERROR = -32052 + THROTTLING = -32053 + RESOURCE_CONFLICT = -32054 + RUNTIME_CLIENT_ERROR = -32055 + + +@dataclass +class AgentSkill: + """A2A Agent Skill definition. + + Skills describe specific capabilities that the agent can perform. + """ + + id: str + name: str + description: str + tags: List[str] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return { + "id": self.id, + "name": self.name, + "description": self.description, + "tags": self.tags, + } + + +@dataclass +class AgentCard: + """A2A Agent Card metadata. + + Agent Cards describe an agent's identity, capabilities, and how to communicate with it. + This metadata is served at /.well-known/agent-card.json endpoint. + """ + + name: str + description: str + version: str = "1.0.0" + protocol_version: str = "0.3.0" + preferred_transport: str = "JSONRPC" + capabilities: Dict[str, Any] = field(default_factory=lambda: {"streaming": True}) + default_input_modes: List[str] = field(default_factory=lambda: ["text"]) + default_output_modes: List[str] = field(default_factory=lambda: ["text"]) + skills: List[AgentSkill] = field(default_factory=list) + + def to_dict(self, url: Optional[str] = None) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization. + + Args: + url: The URL where this agent is accessible. If not provided, + the 'url' field will be omitted from the output. + + Returns: + Dictionary representation of the Agent Card. + """ + result = { + "name": self.name, + "description": self.description, + "version": self.version, + "protocolVersion": self.protocol_version, + "preferredTransport": self.preferred_transport, + "capabilities": self.capabilities, + "defaultInputModes": self.default_input_modes, + "defaultOutputModes": self.default_output_modes, + "skills": [skill.to_dict() for skill in self.skills], + } + if url: + result["url"] = url + return result + + +@dataclass +class JsonRpcRequest: + """JSON-RPC 2.0 Request object.""" + + method: str + id: Optional[Union[str, int]] = None + params: Optional[Dict[str, Any]] = None + jsonrpc: str = "2.0" + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "JsonRpcRequest": + """Create from dictionary.""" + return cls( + jsonrpc=data.get("jsonrpc", "2.0"), + id=data.get("id"), + method=data.get("method", ""), + params=data.get("params"), + ) + + +@dataclass +class JsonRpcError: + """JSON-RPC 2.0 Error object.""" + + code: int + message: str + data: Optional[Any] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + result = {"code": self.code, "message": self.message} + if self.data is not None: + result["data"] = self.data + return result + + +@dataclass +class JsonRpcResponse: + """JSON-RPC 2.0 Response object.""" + + id: Optional[Union[str, int]] + result: Optional[Any] = None + error: Optional[JsonRpcError] = None + jsonrpc: str = "2.0" + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + response = {"jsonrpc": self.jsonrpc, "id": self.id} + if self.error is not None: + response["error"] = self.error.to_dict() + else: + response["result"] = self.result + return response + + @classmethod + def success(cls, id: Optional[Union[str, int]], result: Any) -> "JsonRpcResponse": + """Create a success response.""" + return cls(id=id, result=result) + + @classmethod + def error_response( + cls, + id: Optional[Union[str, int]], + code: int, + message: str, + data: Optional[Any] = None, + ) -> "JsonRpcResponse": + """Create an error response.""" + return cls(id=id, error=JsonRpcError(code=code, message=message, data=data)) + + +@dataclass +class A2AMessagePart: + """A2A message part (text, file, data, etc.).""" + + kind: str # "text", "file", "data", etc. + text: Optional[str] = None + file: Optional[Dict[str, Any]] = None + data: Optional[Dict[str, Any]] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + result: Dict[str, Any] = {"kind": self.kind} + if self.text is not None: + result["text"] = self.text + if self.file is not None: + result["file"] = self.file + if self.data is not None: + result["data"] = self.data + return result + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "A2AMessagePart": + """Create from dictionary.""" + return cls( + kind=data.get("kind", "text"), + text=data.get("text"), + file=data.get("file"), + data=data.get("data"), + ) + + +@dataclass +class A2AMessage: + """A2A protocol message.""" + + role: str # "user", "agent" + parts: List[A2AMessagePart] + message_id: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + result: Dict[str, Any] = { + "role": self.role, + "parts": [part.to_dict() for part in self.parts], + } + if self.message_id: + result["messageId"] = self.message_id + return result + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "A2AMessage": + """Create from dictionary.""" + parts = [A2AMessagePart.from_dict(p) for p in data.get("parts", [])] + return cls( + role=data.get("role", "user"), + parts=parts, + message_id=data.get("messageId"), + ) + + def get_text(self) -> str: + """Extract text content from message parts.""" + texts = [] + for part in self.parts: + if part.kind == "text" and part.text: + texts.append(part.text) + return "\n".join(texts) + + +@dataclass +class A2AArtifact: + """A2A protocol artifact (response content).""" + + artifact_id: str + name: str + parts: List[A2AMessagePart] + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return { + "artifactId": self.artifact_id, + "name": self.name, + "parts": [part.to_dict() for part in self.parts], + } + + @classmethod + def from_text(cls, artifact_id: str, name: str, text: str) -> "A2AArtifact": + """Create a text artifact.""" + return cls( + artifact_id=artifact_id, + name=name, + parts=[A2AMessagePart(kind="text", text=text)], + ) + + +def build_runtime_url(agent_arn: str) -> str: + """Build the AgentCore Runtime URL from an agent ARN. + + The region is automatically extracted from the ARN. + + Args: + agent_arn: The ARN of the agent runtime + (e.g., arn:aws:bedrock:us-west-2:123456789012:agent-runtime/xxx) + + Returns: + The full runtime URL with properly encoded ARN + + Raises: + ValueError: If the ARN format is invalid and region cannot be parsed + """ + # Parse region from ARN (format: arn:partition:service:region:account:resource) + arn_parts = agent_arn.split(":") + if len(arn_parts) < 4 or not arn_parts[3]: + raise ValueError(f"Cannot parse region from ARN: {agent_arn}") + region = arn_parts[3] + + # URL encode the ARN (safe='' means encode all special characters) + escaped_arn = quote(agent_arn, safe="") + return f"https://bedrock-agentcore.{region}.amazonaws.com/runtimes/{escaped_arn}/invocations/" + + +# A2A Protocol Methods +A2A_METHOD_MESSAGE_SEND = "message/send" +A2A_METHOD_MESSAGE_STREAM = "message/stream" +A2A_METHOD_TASKS_GET = "tasks/get" +A2A_METHOD_TASKS_CANCEL = "tasks/cancel" + +# Default A2A port for AgentCore Runtime +A2A_DEFAULT_PORT = 9000 diff --git a/src/bedrock_agentcore/runtime/app.py b/src/bedrock_agentcore/runtime/app.py index d5754267..bc9857d3 100644 --- a/src/bedrock_agentcore/runtime/app.py +++ b/src/bedrock_agentcore/runtime/app.py @@ -1,34 +1,22 @@ -"""Bedrock AgentCore base implementation. +"""Bedrock AgentCore HTTP implementation. Provides a Starlette-based web server that wraps user functions as HTTP endpoints. """ -import asyncio -import contextvars import inspect import json -import logging -import threading import time -import uuid from collections.abc import Sequence from typing import Any, Callable, Dict, Optional -from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.responses import JSONResponse, Response, StreamingResponse from starlette.routing import Route, WebSocketRoute from starlette.types import Lifespan from starlette.websockets import WebSocket, WebSocketDisconnect -from .context import BedrockAgentCoreContext, RequestContext +from .base_app import _BaseAgentCoreApp, _BaseRequestContextFormatter from .models import ( - ACCESS_TOKEN_HEADER, - AUTHORIZATION_HEADER, - CUSTOM_HEADER_PREFIX, - OAUTH2_CALLBACK_URL_HEADER, - REQUEST_ID_HEADER, - SESSION_HEADER, TASK_ACTION_CLEAR_FORCED_STATUS, TASK_ACTION_FORCE_BUSY, TASK_ACTION_FORCE_HEALTHY, @@ -36,45 +24,18 @@ TASK_ACTION_PING_STATUS, PingStatus, ) -from .utils import convert_complex_objects -class RequestContextFormatter(logging.Formatter): - """Formatter including request and session IDs.""" +class RequestContextFormatter(_BaseRequestContextFormatter): + """Formatter including request and session IDs for HTTP applications.""" - def format(self, record): - """Format log record as AWS Lambda JSON.""" - import json - from datetime import datetime + pass - log_entry = { - "timestamp": datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z", - "level": record.levelname, - "message": record.getMessage(), - "logger": record.name, - } - request_id = BedrockAgentCoreContext.get_request_id() - if request_id: - log_entry["requestId"] = request_id +class BedrockAgentCoreApp(_BaseAgentCoreApp): + """Bedrock AgentCore application class for HTTP protocol deployment.""" - session_id = BedrockAgentCoreContext.get_session_id() - if session_id: - log_entry["sessionId"] = session_id - - if record.exc_info: - import traceback - - log_entry["errorType"] = record.exc_info[0].__name__ - log_entry["errorMessage"] = str(record.exc_info[1]) - log_entry["stackTrace"] = traceback.format_exception(*record.exc_info) - log_entry["location"] = f"{record.pathname}:{record.funcName}:{record.lineno}" - - return json.dumps(log_entry, ensure_ascii=False) - - -class BedrockAgentCoreApp(Starlette): - """Bedrock AgentCore application class that extends Starlette for AI agent deployment.""" + _default_port = 8080 def __init__( self, @@ -89,54 +50,21 @@ def __init__( lifespan: Optional lifespan context manager for startup/shutdown middleware: Optional sequence of Starlette Middleware objects (or Middleware(...) entries) """ - self.handlers: Dict[str, Callable] = {} - self._ping_handler: Optional[Callable] = None self._websocket_handler: Optional[Callable] = None - self._active_tasks: Dict[int, Dict[str, Any]] = {} - self._task_counter_lock: threading.Lock = threading.Lock() - self._forced_ping_status: Optional[PingStatus] = None - self._last_status_update_time: float = time.time() routes = [ Route("/invocations", self._handle_invocation, methods=["POST"]), Route("/ping", self._handle_ping, methods=["GET"]), WebSocketRoute("/ws", self._handle_websocket), ] - super().__init__(routes=routes, lifespan=lifespan, middleware=middleware) - self.debug = debug # Set after super().__init__ to avoid override - - self.logger = logging.getLogger("bedrock_agentcore.app") - if not self.logger.handlers: - handler = logging.StreamHandler() - formatter = RequestContextFormatter() - handler.setFormatter(formatter) - self.logger.addHandler(handler) - self.logger.setLevel(logging.DEBUG if self.debug else logging.INFO) - - def entrypoint(self, func: Callable) -> Callable: - """Decorator to register a function as the main entrypoint. - - Args: - func: The function to register as entrypoint - - Returns: - The decorated function with added serve method - """ - self.handlers["main"] = func - func.run = lambda port=8080, host=None: self.run(port, host) - return func - - def ping(self, func: Callable) -> Callable: - """Decorator to register a custom ping status handler. - - Args: - func: The function to register as ping status handler - - Returns: - The decorated function - """ - self._ping_handler = func - return func + super().__init__( + routes=routes, + debug=debug, + lifespan=lifespan, + middleware=middleware, + logger_name="bedrock_agentcore.app", + log_formatter=RequestContextFormatter(), + ) def websocket(self, func: Callable) -> Callable: """Decorator to register a WebSocket handler at /ws endpoint. @@ -156,203 +84,6 @@ async def handler(websocket, context): self._websocket_handler = func return func - def async_task(self, func: Callable) -> Callable: - """Decorator to track async tasks for ping status. - - When a function is decorated with @async_task, it will: - - Set ping status to HEALTHY_BUSY while running - - Revert to HEALTHY when complete - """ - if not asyncio.iscoroutinefunction(func): - raise ValueError("@async_task can only be applied to async functions") - - async def wrapper(*args, **kwargs): - task_id = self.add_async_task(func.__name__) - - try: - self.logger.debug("Starting async task: %s", func.__name__) - start_time = time.time() - result = await func(*args, **kwargs) - duration = time.time() - start_time - self.logger.info("Async task completed: %s (%.3fs)", func.__name__, duration) - return result - except Exception: - duration = time.time() - start_time - self.logger.exception("Async task failed: %s (%.3fs)", func.__name__, duration) - raise - finally: - self.complete_async_task(task_id) - - wrapper.__name__ = func.__name__ - return wrapper - - def get_current_ping_status(self) -> PingStatus: - """Get current ping status (forced > custom > automatic).""" - current_status = None - - if self._forced_ping_status is not None: - current_status = self._forced_ping_status - elif self._ping_handler: - try: - result = self._ping_handler() - if isinstance(result, str): - current_status = PingStatus(result) - else: - current_status = result - except Exception as e: - self.logger.warning( - "Custom ping handler failed, falling back to automatic: %s: %s", type(e).__name__, e - ) - - if current_status is None: - current_status = PingStatus.HEALTHY_BUSY if self._active_tasks else PingStatus.HEALTHY - if not hasattr(self, "_last_known_status") or self._last_known_status != current_status: - self._last_known_status = current_status - self._last_status_update_time = time.time() - - return current_status - - def force_ping_status(self, status: PingStatus): - """Force ping status to a specific value.""" - self._forced_ping_status = status - - def clear_forced_ping_status(self): - """Clear forced status and resume automatic.""" - self._forced_ping_status = None - - def get_async_task_info(self) -> Dict[str, Any]: - """Get info about running async tasks.""" - running_jobs = [] - for t in self._active_tasks.values(): - try: - running_jobs.append( - {"name": t.get("name", "unknown"), "duration": time.time() - t.get("start_time", time.time())} - ) - except Exception as e: - self.logger.warning("Caught exception, continuing...: %s", e) - continue - - return {"active_count": len(self._active_tasks), "running_jobs": running_jobs} - - def add_async_task(self, name: str, metadata: Optional[Dict] = None) -> int: - """Register an async task for interactive health tracking. - - This method provides granular control over async task lifecycle, - allowing developers to interactively start tracking tasks for health monitoring. - Use this when you need precise control over when tasks begin and end. - - Args: - name: Human-readable task name for monitoring - metadata: Optional additional task metadata - - Returns: - Task ID for tracking and completion - - Example: - task_id = app.add_async_task("file_processing", {"file": "data.csv"}) - # ... do background work ... - app.complete_async_task(task_id) - """ - with self._task_counter_lock: - task_id = hash(str(uuid.uuid4())) # Generate truly unique hash-based ID - - # Register task start with same structure as @async_task decorator - task_info = {"name": name, "start_time": time.time()} - if metadata: - task_info["metadata"] = metadata - - self._active_tasks[task_id] = task_info - - self.logger.info("Async task started: %s (ID: %s)", name, task_id) - return task_id - - def complete_async_task(self, task_id: int) -> bool: - """Mark an async task as complete for interactive health tracking. - - This method provides granular control over async task lifecycle, - allowing developers to interactively complete tasks for health monitoring. - Call this when your background work finishes. - - Args: - task_id: Task ID returned from add_async_task - - Returns: - True if task was found and completed, False otherwise - - Example: - task_id = app.add_async_task("file_processing") - # ... do background work ... - completed = app.complete_async_task(task_id) - """ - with self._task_counter_lock: - task_info = self._active_tasks.pop(task_id, None) - if task_info: - task_name = task_info.get("name", "unknown") - duration = time.time() - task_info.get("start_time", time.time()) - - self.logger.info("Async task completed: %s (ID: %s, Duration: %.2fs)", task_name, task_id, duration) - return True - else: - self.logger.warning("Attempted to complete unknown task ID: %s", task_id) - return False - - def _build_request_context(self, request) -> RequestContext: - """Build request context and setup all context variables.""" - try: - headers = request.headers - request_id = headers.get(REQUEST_ID_HEADER) - if not request_id: - request_id = str(uuid.uuid4()) - - session_id = headers.get(SESSION_HEADER) - BedrockAgentCoreContext.set_request_context(request_id, session_id) - - agent_identity_token = headers.get(ACCESS_TOKEN_HEADER) - if agent_identity_token: - BedrockAgentCoreContext.set_workload_access_token(agent_identity_token) - - oauth2_callback_url = headers.get(OAUTH2_CALLBACK_URL_HEADER) - if oauth2_callback_url: - BedrockAgentCoreContext.set_oauth2_callback_url(oauth2_callback_url) - - # Collect relevant request headers (Authorization + Custom headers) - request_headers = {} - - # Add Authorization header if present - authorization_header = headers.get(AUTHORIZATION_HEADER) - if authorization_header is not None: - request_headers[AUTHORIZATION_HEADER] = authorization_header - - # Add custom headers with the specified prefix - for header_name, header_value in headers.items(): - if header_name.lower().startswith(CUSTOM_HEADER_PREFIX.lower()): - request_headers[header_name] = header_value - - # Set in context if any headers were found - if request_headers: - BedrockAgentCoreContext.set_request_headers(request_headers) - - # Get the headers from context to pass to RequestContext - req_headers = BedrockAgentCoreContext.get_request_headers() - - return RequestContext( - session_id=session_id, - request_headers=req_headers, - request=request, # Pass through the Starlette request object - ) - except Exception as e: - self.logger.warning("Failed to build request context: %s: %s", type(e).__name__, e) - request_id = str(uuid.uuid4()) - BedrockAgentCoreContext.set_request_context(request_id, None) - return RequestContext(session_id=None, request=None) - - def _takes_context(self, handler: Callable) -> bool: - try: - params = list(inspect.signature(handler).parameters.keys()) - return len(params) >= 2 and params[1] == "context" - except Exception: - return False - async def _handle_invocation(self, request): request_context = self._build_request_context(request) @@ -402,15 +133,6 @@ async def _handle_invocation(self, request): self.logger.exception("Invocation failed (%.3fs)", duration) return JSONResponse({"error": str(e)}, status_code=500) - def _handle_ping(self, request): - try: - status = self.get_current_ping_status() - self.logger.debug("Ping request - status: %s", status.value) - return JSONResponse({"status": status.value, "time_of_last_update": int(self._last_status_update_time)}) - except Exception: - self.logger.exception("Ping endpoint failed") - return JSONResponse({"status": PingStatus.HEALTHY.value, "time_of_last_update": int(time.time())}) - async def _handle_websocket(self, websocket: WebSocket): """Handle WebSocket connections.""" request_context = self._build_request_context(websocket) @@ -434,50 +156,6 @@ async def _handle_websocket(self, websocket: WebSocket): except Exception: pass - def run(self, port: int = 8080, host: Optional[str] = None, **kwargs): - """Start the Bedrock AgentCore server. - - Args: - port: Port to serve on, defaults to 8080 - host: Host to bind to, auto-detected if None - **kwargs: Additional arguments passed to uvicorn.run() - """ - import os - - import uvicorn - - if host is None: - if os.path.exists("/.dockerenv") or os.environ.get("DOCKER_CONTAINER"): - host = "0.0.0.0" # nosec B104 - Docker needs this to expose the port - else: - host = "127.0.0.1" - - # Set default uvicorn parameters, allow kwargs to override - uvicorn_params = { - "host": host, - "port": port, - "access_log": self.debug, - "log_level": "info" if self.debug else "warning", - } - uvicorn_params.update(kwargs) - - uvicorn.run(self, **uvicorn_params) - - async def _invoke_handler(self, handler, request_context, takes_context, payload): - try: - args = (payload, request_context) if takes_context else (payload,) - - if asyncio.iscoroutinefunction(handler): - return await handler(*args) - else: - loop = asyncio.get_event_loop() - ctx = contextvars.copy_context() - return await loop.run_in_executor(None, ctx.run, handler, *args) - except Exception: - handler_name = getattr(handler, "__name__", "unknown") - self.logger.debug("Handler '%s' execution failed", handler_name) - raise - def _handle_task_action(self, payload: dict) -> Optional[JSONResponse]: """Handle task management actions if present in payload.""" action = payload.get("_agent_core_app_action") @@ -538,47 +216,6 @@ async def _stream_with_error_handling(self, generator): } yield self._convert_to_sse(error_event) - def _safe_serialize_to_json_string(self, obj): - """Safely serialize object directly to JSON string with progressive fallback handling. - - This method eliminates double JSON encoding by returning the JSON string directly, - avoiding the test-then-encode pattern that leads to redundant json.dumps() calls. - Used by both streaming and non-streaming responses for consistent behavior. - - Returns: - str: JSON string representation of the object - """ - try: - # First attempt: direct JSON serialization with Unicode support - return json.dumps(obj, ensure_ascii=False) - except (TypeError, ValueError, UnicodeEncodeError): - try: - # Second attempt: convert to serializable dictionaries, then JSON encode the dictionaries - converted_obj = convert_complex_objects(obj) - return json.dumps(converted_obj, ensure_ascii=False) - except Exception: - try: - # Third attempt: convert to string, then JSON encode the string - return json.dumps(str(obj), ensure_ascii=False) - except Exception as e: - # Final fallback: JSON encode error object with ASCII fallback for problematic Unicode - self.logger.warning("Failed to serialize object: %s: %s", type(e).__name__, e) - error_obj = {"error": "Serialization failed", "original_type": type(obj).__name__} - return json.dumps(error_obj, ensure_ascii=False) - - def _convert_to_sse(self, obj) -> bytes: - """Convert object to Server-Sent Events format using safe serialization. - - Args: - obj: Object to convert to SSE format - - Returns: - bytes: SSE-formatted data ready for streaming - """ - json_string = self._safe_serialize_to_json_string(obj) - sse_data = f"data: {json_string}\n\n" - return sse_data.encode("utf-8") - def _sync_stream_with_error_handling(self, generator): """Wrap sync generator to handle errors and convert to SSE format.""" try: diff --git a/src/bedrock_agentcore/runtime/base_app.py b/src/bedrock_agentcore/runtime/base_app.py new file mode 100644 index 00000000..69d4deed --- /dev/null +++ b/src/bedrock_agentcore/runtime/base_app.py @@ -0,0 +1,405 @@ +"""Base application class for Bedrock AgentCore runtime. + +Provides common functionality shared between HTTP and A2A protocol implementations. +""" + +import asyncio +import contextvars +import inspect +import json +import logging +import os +import threading +import time +import uuid +from collections.abc import Sequence +from datetime import datetime, timezone +from typing import Any, Callable, Dict, Optional + +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.responses import JSONResponse +from starlette.types import Lifespan + +from .context import BedrockAgentCoreContext, RequestContext +from .models import ( + ACCESS_TOKEN_HEADER, + AUTHORIZATION_HEADER, + CUSTOM_HEADER_PREFIX, + OAUTH2_CALLBACK_URL_HEADER, + REQUEST_ID_HEADER, + SESSION_HEADER, + PingStatus, +) +from .utils import convert_complex_objects + + +class _BaseRequestContextFormatter(logging.Formatter): + """Base log formatter including request and session IDs. + + Subclasses can provide extra_fields to include in every log entry + (e.g., {"protocol": "A2A"}). + """ + + extra_fields: Dict[str, Any] = {} + + def format(self, record): + """Format log record as JSON.""" + log_entry = { + "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z", + "level": record.levelname, + "message": record.getMessage(), + "logger": record.name, + } + + log_entry.update(self.extra_fields) + + request_id = BedrockAgentCoreContext.get_request_id() + if request_id: + log_entry["requestId"] = request_id + + session_id = BedrockAgentCoreContext.get_session_id() + if session_id: + log_entry["sessionId"] = session_id + + if record.exc_info: + import traceback + + log_entry["errorType"] = record.exc_info[0].__name__ + log_entry["errorMessage"] = str(record.exc_info[1]) + log_entry["stackTrace"] = traceback.format_exception(*record.exc_info) + log_entry["location"] = f"{record.pathname}:{record.funcName}:{record.lineno}" + + return json.dumps(log_entry, ensure_ascii=False) + + +class _BaseAgentCoreApp(Starlette): + """Base class for Bedrock AgentCore applications. + + Provides shared functionality for HTTP and A2A protocol implementations: + - Handler registration (entrypoint, ping, async_task decorators) + - Ping/health check management + - Async task lifecycle tracking + - Request context building + - Safe JSON serialization + - Server startup (uvicorn) + """ + + _default_port: int = 8080 + + def __init__( + self, + routes: list, + debug: bool = False, + lifespan: Optional[Lifespan] = None, + middleware: Sequence[Middleware] | None = None, + logger_name: str = "bedrock_agentcore.base_app", + log_formatter: Optional[logging.Formatter] = None, + ): + self.handlers: Dict[str, Callable] = {} + self._ping_handler: Optional[Callable] = None + self._active_tasks: Dict[int, Dict[str, Any]] = {} + self._task_counter_lock: threading.Lock = threading.Lock() + self._forced_ping_status: Optional[PingStatus] = None + self._last_status_update_time: float = time.time() + + super().__init__(routes=routes, lifespan=lifespan, middleware=middleware) + self.debug = debug + + self.logger = logging.getLogger(logger_name) + if not self.logger.handlers: + handler = logging.StreamHandler() + if log_formatter: + handler.setFormatter(log_formatter) + self.logger.addHandler(handler) + self.logger.setLevel(logging.DEBUG if self.debug else logging.INFO) + + def entrypoint(self, func: Callable) -> Callable: + """Decorator to register a function as the main entrypoint. + + Args: + func: The function to register as entrypoint + + Returns: + The decorated function with added run method + """ + self.handlers["main"] = func + default_port = self._default_port + func.run = lambda port=default_port, host=None: self.run(port, host) + return func + + def ping(self, func: Callable) -> Callable: + """Decorator to register a custom ping status handler. + + Args: + func: The function to register as ping status handler + + Returns: + The decorated function + """ + self._ping_handler = func + return func + + def async_task(self, func: Callable) -> Callable: + """Decorator to track async tasks for ping status. + + When a function is decorated with @async_task, it will: + - Set ping status to HEALTHY_BUSY while running + - Revert to HEALTHY when complete + """ + if not asyncio.iscoroutinefunction(func): + raise ValueError("@async_task can only be applied to async functions") + + async def wrapper(*args, **kwargs): + task_id = self.add_async_task(func.__name__) + + try: + self.logger.debug("Starting async task: %s", func.__name__) + start_time = time.time() + result = await func(*args, **kwargs) + duration = time.time() - start_time + self.logger.info("Async task completed: %s (%.3fs)", func.__name__, duration) + return result + except Exception: + duration = time.time() - start_time + self.logger.exception("Async task failed: %s (%.3fs)", func.__name__, duration) + raise + finally: + self.complete_async_task(task_id) + + wrapper.__name__ = func.__name__ + return wrapper + + def get_current_ping_status(self) -> PingStatus: + """Get current ping status (forced > custom > automatic).""" + current_status = None + + if self._forced_ping_status is not None: + current_status = self._forced_ping_status + elif self._ping_handler: + try: + result = self._ping_handler() + if isinstance(result, str): + current_status = PingStatus(result) + else: + current_status = result + except Exception as e: + self.logger.warning( + "Custom ping handler failed, falling back to automatic: %s: %s", type(e).__name__, e + ) + + if current_status is None: + current_status = PingStatus.HEALTHY_BUSY if self._active_tasks else PingStatus.HEALTHY + + if not hasattr(self, "_last_known_status") or self._last_known_status != current_status: + self._last_known_status = current_status + self._last_status_update_time = time.time() + + return current_status + + def force_ping_status(self, status: PingStatus): + """Force ping status to a specific value.""" + self._forced_ping_status = status + + def clear_forced_ping_status(self): + """Clear forced status and resume automatic.""" + self._forced_ping_status = None + + def get_async_task_info(self) -> Dict[str, Any]: + """Get info about running async tasks.""" + running_jobs = [] + for t in self._active_tasks.values(): + try: + running_jobs.append( + {"name": t.get("name", "unknown"), "duration": time.time() - t.get("start_time", time.time())} + ) + except Exception as e: + self.logger.warning("Caught exception, continuing...: %s", e) + continue + + return {"active_count": len(self._active_tasks), "running_jobs": running_jobs} + + def add_async_task(self, name: str, metadata: Optional[Dict] = None) -> int: + """Register an async task for interactive health tracking. + + Args: + name: Human-readable task name for monitoring + metadata: Optional additional task metadata + + Returns: + Task ID for tracking and completion + """ + with self._task_counter_lock: + task_id = hash(str(uuid.uuid4())) + + task_info = {"name": name, "start_time": time.time()} + if metadata: + task_info["metadata"] = metadata + + self._active_tasks[task_id] = task_info + + self.logger.info("Async task started: %s (ID: %s)", name, task_id) + return task_id + + def complete_async_task(self, task_id: int) -> bool: + """Mark an async task as complete for interactive health tracking. + + Args: + task_id: Task ID returned from add_async_task + + Returns: + True if task was found and completed, False otherwise + """ + with self._task_counter_lock: + task_info = self._active_tasks.pop(task_id, None) + if task_info: + task_name = task_info.get("name", "unknown") + duration = time.time() - task_info.get("start_time", time.time()) + + self.logger.info("Async task completed: %s (ID: %s, Duration: %.2fs)", task_name, task_id, duration) + return True + else: + self.logger.warning("Attempted to complete unknown task ID: %s", task_id) + return False + + def _build_request_context(self, request) -> RequestContext: + """Build request context and setup all context variables.""" + try: + headers = request.headers + request_id = headers.get(REQUEST_ID_HEADER) + if not request_id: + request_id = str(uuid.uuid4()) + + session_id = headers.get(SESSION_HEADER) + BedrockAgentCoreContext.set_request_context(request_id, session_id) + + agent_identity_token = headers.get(ACCESS_TOKEN_HEADER) + if agent_identity_token: + BedrockAgentCoreContext.set_workload_access_token(agent_identity_token) + + oauth2_callback_url = headers.get(OAUTH2_CALLBACK_URL_HEADER) + if oauth2_callback_url: + BedrockAgentCoreContext.set_oauth2_callback_url(oauth2_callback_url) + + # Collect relevant request headers + request_headers = {} + + authorization_header = headers.get(AUTHORIZATION_HEADER) + if authorization_header is not None: + request_headers[AUTHORIZATION_HEADER] = authorization_header + + for header_name, header_value in headers.items(): + if header_name.lower().startswith(CUSTOM_HEADER_PREFIX.lower()): + request_headers[header_name] = header_value + + if request_headers: + BedrockAgentCoreContext.set_request_headers(request_headers) + + req_headers = BedrockAgentCoreContext.get_request_headers() + + return RequestContext( + session_id=session_id, + request_headers=req_headers, + request=request, + ) + except Exception as e: + self.logger.warning("Failed to build request context: %s: %s", type(e).__name__, e) + request_id = str(uuid.uuid4()) + BedrockAgentCoreContext.set_request_context(request_id, None) + return RequestContext(session_id=None, request=None) + + def _takes_context(self, handler: Callable) -> bool: + """Check if handler accepts context parameter.""" + try: + params = list(inspect.signature(handler).parameters.keys()) + return len(params) >= 2 and params[1] == "context" + except Exception: + return False + + async def _invoke_handler(self, handler, request_context, takes_context, payload): + """Invoke the handler with appropriate arguments.""" + try: + args = (payload, request_context) if takes_context else (payload,) + + if asyncio.iscoroutinefunction(handler): + return await handler(*args) + else: + loop = asyncio.get_running_loop() + ctx = contextvars.copy_context() + return await loop.run_in_executor(None, ctx.run, handler, *args) + except Exception: + handler_name = getattr(handler, "__name__", "unknown") + self.logger.debug("Handler '%s' execution failed", handler_name) + raise + + def _handle_ping(self, request): + """Handle GET /ping health check endpoint.""" + try: + status = self.get_current_ping_status() + self.logger.debug("Ping request - status: %s", status.value) + return JSONResponse({"status": status.value, "time_of_last_update": int(self._last_status_update_time)}) + except Exception: + self.logger.exception("Ping endpoint failed") + return JSONResponse({"status": PingStatus.HEALTHY.value, "time_of_last_update": int(time.time())}) + + def _safe_serialize_to_json_string(self, obj): + """Safely serialize object to JSON string with progressive fallback. + + Returns: + str: JSON string representation of the object + """ + try: + return json.dumps(obj, ensure_ascii=False) + except (TypeError, ValueError, UnicodeEncodeError): + try: + converted_obj = convert_complex_objects(obj) + return json.dumps(converted_obj, ensure_ascii=False) + except Exception: + try: + return json.dumps(str(obj), ensure_ascii=False) + except Exception as e: + self.logger.warning("Failed to serialize object: %s: %s", type(e).__name__, e) + error_obj = {"error": "Serialization failed", "original_type": type(obj).__name__} + return json.dumps(error_obj, ensure_ascii=False) + + def _convert_to_sse(self, obj) -> bytes: + """Convert object to Server-Sent Events format. + + Args: + obj: Object to convert to SSE format + + Returns: + bytes: SSE-formatted data ready for streaming + """ + json_string = self._safe_serialize_to_json_string(obj) + return f"data: {json_string}\n\n".encode("utf-8") + + def run(self, port: Optional[int] = None, host: Optional[str] = None, **kwargs): + """Start the Bedrock AgentCore server. + + Args: + port: Port to serve on (defaults to protocol-specific port) + host: Host to bind to, auto-detected if None + **kwargs: Additional arguments passed to uvicorn.run() + """ + import uvicorn + + if port is None: + port = self._default_port + + if host is None: + if os.path.exists("/.dockerenv") or os.environ.get("DOCKER_CONTAINER"): + host = "0.0.0.0" # nosec B104 - Docker needs this to expose the port + else: + host = "127.0.0.1" + + uvicorn_params = { + "host": host, + "port": port, + "access_log": self.debug, + "log_level": "info" if self.debug else "warning", + } + uvicorn_params.update(kwargs) + + self.logger.info("Starting server on %s:%d", host, port) + uvicorn.run(self, **uvicorn_params) diff --git a/tests/bedrock_agentcore/runtime/test_a2a_app.py b/tests/bedrock_agentcore/runtime/test_a2a_app.py new file mode 100644 index 00000000..6d531981 --- /dev/null +++ b/tests/bedrock_agentcore/runtime/test_a2a_app.py @@ -0,0 +1,1101 @@ +"""Tests for BedrockAgentCoreA2AApp.""" + +import asyncio +import contextlib +import json +import logging +import os +import uuid +from dataclasses import dataclass +from unittest.mock import MagicMock, patch + +import pytest +from starlette.testclient import TestClient + +from bedrock_agentcore.runtime import ( + A2AArtifact, + AgentCard, + AgentSkill, + BedrockAgentCoreA2AApp, +) +from bedrock_agentcore.runtime.a2a_app import A2ARequestContextFormatter +from bedrock_agentcore.runtime.models import PingStatus + + +@pytest.fixture +def agent_card(): + """Create a test AgentCard.""" + return AgentCard( + name="Test Agent", + description="A test agent for unit testing", + skills=[ + AgentSkill(id="test", name="Test Skill", description="A test skill"), + ], + ) + + +@pytest.fixture +def app(agent_card): + """Create a test A2A app.""" + return BedrockAgentCoreA2AApp(agent_card=agent_card) + + +class TestBedrockAgentCoreA2AAppInitialization: + def test_basic_initialization(self, agent_card): + """Test basic app initialization.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + assert app.agent_card == agent_card + assert app.handlers == {} + assert app.debug is False + + def test_initialization_with_debug(self, agent_card): + """Test app initialization with debug mode.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card, debug=True) + assert app.debug is True + + def test_routes_registered(self, agent_card): + """Test that required routes are registered.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + route_paths = [route.path for route in app.routes] + assert "/" in route_paths + assert "/.well-known/agent-card.json" in route_paths + assert "/ping" in route_paths + + +class TestAgentCardEndpoint: + def test_agent_card_endpoint(self, app, agent_card): + """Test GET /.well-known/agent-card.json returns agent card.""" + client = TestClient(app) + response = client.get("/.well-known/agent-card.json") + + assert response.status_code == 200 + data = response.json() + assert data["name"] == agent_card.name + assert data["description"] == agent_card.description + assert data["protocolVersion"] == agent_card.protocol_version + assert data["url"] == "http://testserver/" + assert len(data["skills"]) == 1 + assert data["skills"][0]["id"] == "test" + + def test_agent_card_with_runtime_url(self, agent_card): + """Test agent card includes URL when AGENTCORE_RUNTIME_URL is set.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + + with patch.dict(os.environ, {"AGENTCORE_RUNTIME_URL": "https://example.com/agent"}): + client = TestClient(app) + response = client.get("/.well-known/agent-card.json") + + assert response.status_code == 200 + data = response.json() + assert data["url"] == "https://example.com/agent" + + +class TestPingEndpoint: + def test_ping_endpoint(self, app): + """Test GET /ping returns healthy status.""" + client = TestClient(app) + response = client.get("/ping") + + assert response.status_code == 200 + data = response.json() + assert data["status"] in ["Healthy", "HEALTHY"] + assert "time_of_last_update" in data + + def test_custom_ping_handler(self, agent_card): + """Test custom ping handler.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + + @app.ping + def custom_ping(): + return "HealthyBusy" + + client = TestClient(app) + response = client.get("/ping") + + assert response.status_code == 200 + data = response.json() + assert data["status"] in ["HealthyBusy", "HEALTHY_BUSY"] + + +class TestEntrypointDecorator: + def test_entrypoint_decorator(self, app): + """Test @app.entrypoint registers handler.""" + + @app.entrypoint + def handler(request, context): + return {"result": "success"} + + assert "main" in app.handlers + assert app.handlers["main"] == handler + assert hasattr(handler, "run") + + def test_entrypoint_without_context(self, app): + """Test entrypoint handler without context parameter.""" + + @app.entrypoint + def handler(request): + return {"result": request.method} + + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/send", + "params": {}, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["jsonrpc"] == "2.0" + assert data["id"] == "req-001" + assert data["result"]["result"] == "message/send" + + +class TestJsonRpcHandling: + def test_valid_jsonrpc_request(self, app): + """Test valid JSON-RPC request.""" + + @app.entrypoint + def handler(request, context): + return {"artifacts": [{"artifactId": "art-1", "name": "response", "parts": []}]} + + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [{"kind": "text", "text": "Hello"}], + } + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["jsonrpc"] == "2.0" + assert data["id"] == "req-001" + assert "result" in data + assert "artifacts" in data["result"] + + def test_dataclass_response_is_serialized(self, app): + """Test A2A helper models are serialized in JSON-RPC responses.""" + + @app.entrypoint + def handler(request, context): + return {"artifacts": [A2AArtifact.from_text("art-1", "response", "Hello")]} + + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/send", + "params": {}, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["result"]["artifacts"][0]["artifactId"] == "art-1" + assert data["result"]["artifacts"][0]["parts"][0]["text"] == "Hello" + + def test_invalid_jsonrpc_version(self, app): + """Test invalid JSON-RPC version returns error.""" + + @app.entrypoint + def handler(request, context): + return {} + + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "1.0", # Invalid version + "id": "req-001", + "method": "test", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == -32600 # Invalid request + + def test_missing_method(self, app): + """Test missing method returns error.""" + + @app.entrypoint + def handler(request, context): + return {} + + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + # Missing method + }, + ) + + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == -32600 # Invalid request + + def test_non_object_request_returns_invalid_request(self, app): + """Test non-object JSON-RPC payload returns invalid request.""" + + @app.entrypoint + def handler(request, context): + return {} + + client = TestClient(app) + response = client.post( + "/", + json=[], + ) + + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == -32600 # Invalid request + assert data["id"] is None + + def test_no_entrypoint_defined(self, app): + """Test error when no entrypoint is defined.""" + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/send", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == -32603 # Internal error + + def test_invalid_json(self, app): + """Test invalid JSON returns parse error.""" + + @app.entrypoint + def handler(request, context): + return {} + + client = TestClient(app) + response = client.post( + "/", + content="not valid json", + headers={"Content-Type": "application/json"}, + ) + + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == -32700 # Parse error + + def test_handler_exception(self, app): + """Test handler exception returns internal error.""" + + @app.entrypoint + def handler(request, context): + raise ValueError("Test error") + + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/send", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == -32603 # Internal error + assert data["error"]["message"] == "Internal error" + + +class TestAsyncHandler: + def test_async_handler(self, app): + """Test async handler.""" + + @app.entrypoint + async def handler(request, context): + await asyncio.sleep(0.01) + return {"result": "async success"} + + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/send", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["result"]["result"] == "async success" + + +class TestStreamingResponse: + def test_async_generator_response(self, app): + """Test async generator for streaming response.""" + + @app.entrypoint + async def handler(request, context): + async def generate(): + yield {"chunk": 1} + yield {"chunk": 2} + yield {"chunk": 3} + + return generate() + + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/stream", + }, + ) + + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + + # Parse SSE events + events = response.text.split("\n\n") + events = [e for e in events if e.strip()] + + assert len(events) == 3 + for i, event in enumerate(events, 1): + assert event.startswith("data: ") + data = json.loads(event[6:]) + assert data["jsonrpc"] == "2.0" + assert data["id"] == "req-001" + assert data["result"]["chunk"] == i + + def test_sync_generator_response(self, app): + """Test sync generator for streaming response.""" + + @app.entrypoint + def handler(request, context): + def generate(): + yield {"part": "A"} + yield {"part": "B"} + + return generate() + + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/send", + }, + ) + + assert response.status_code == 200 + events = response.text.split("\n\n") + events = [e for e in events if e.strip()] + assert len(events) == 2 + + def test_streaming_dataclass_response_is_serialized(self, app): + """Test streaming payloads serialize A2A helper models.""" + + @app.entrypoint + def handler(request, context): + def generate(): + yield {"artifacts": [A2AArtifact.from_text("art-1", "response", "Hello")]} + + return generate() + + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/stream", + }, + ) + + assert response.status_code == 200 + events = response.text.split("\n\n") + events = [e for e in events if e.strip()] + assert len(events) == 1 + + data = json.loads(events[0][6:]) + assert data["result"]["artifacts"][0]["artifactId"] == "art-1" + assert data["result"]["artifacts"][0]["parts"][0]["text"] == "Hello" + + +class TestSessionHeader: + def test_session_id_from_header(self, app): + """Test session ID is extracted from header.""" + captured_session_id = None + + @app.entrypoint + def handler(request, context): + nonlocal captured_session_id + captured_session_id = context.session_id + return {"session": context.session_id} + + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/send", + }, + headers={"X-Amzn-Bedrock-AgentCore-Runtime-Session-Id": "test-session-123"}, + ) + + assert response.status_code == 200 + assert captured_session_id == "test-session-123" + + +class TestRunMethod: + @patch("uvicorn.run") + def test_run_default_port(self, mock_uvicorn, app): + """Test run uses default A2A port 9000.""" + app.run() + + mock_uvicorn.assert_called_once() + call_kwargs = mock_uvicorn.call_args[1] + assert call_kwargs["port"] == 9000 + assert call_kwargs["host"] == "127.0.0.1" + + @patch("uvicorn.run") + def test_run_custom_port(self, mock_uvicorn, app): + """Test run with custom port.""" + app.run(port=8080) + + mock_uvicorn.assert_called_once() + call_kwargs = mock_uvicorn.call_args[1] + assert call_kwargs["port"] == 8080 + + @patch.dict(os.environ, {"DOCKER_CONTAINER": "true"}) + @patch("uvicorn.run") + def test_run_in_docker(self, mock_uvicorn, agent_card): + """Test run in Docker environment.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + app.run() + + mock_uvicorn.assert_called_once() + call_kwargs = mock_uvicorn.call_args[1] + assert call_kwargs["host"] == "0.0.0.0" + + +class TestLifespan: + def test_lifespan_startup_and_shutdown(self, agent_card): + """Test lifespan startup and shutdown.""" + startup_called = False + shutdown_called = False + + @contextlib.asynccontextmanager + async def lifespan(app): + nonlocal startup_called, shutdown_called + startup_called = True + yield + shutdown_called = True + + app = BedrockAgentCoreA2AApp(agent_card=agent_card, lifespan=lifespan) + + with TestClient(app): + assert startup_called is True + assert shutdown_called is True + + +class TestIntegrationScenario: + def test_full_message_flow(self, app): + """Test complete message flow with A2A protocol.""" + + @app.entrypoint + def handler(request, context): + # Extract message from params + params = request.params or {} + message = params.get("message", {}) + parts = message.get("parts", []) + user_text = "" + for part in parts: + if part.get("kind") == "text": + user_text = part.get("text", "") + break + + # Return A2A formatted response + return { + "artifacts": [ + { + "artifactId": str(uuid.uuid4()), + "name": "agent_response", + "parts": [{"kind": "text", "text": f"Received: {user_text}"}], + } + ] + } + + client = TestClient(app) + + # Send message + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [{"kind": "text", "text": "What is 2 + 2?"}], + "messageId": "msg-001", + } + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["jsonrpc"] == "2.0" + assert data["id"] == "req-001" + assert "result" in data + assert "artifacts" in data["result"] + assert len(data["result"]["artifacts"]) == 1 + assert "Received: What is 2 + 2?" in data["result"]["artifacts"][0]["parts"][0]["text"] + + +class TestA2ARequestContextFormatter: + def test_format_basic_record(self): + """Test basic log record formatting.""" + formatter = A2ARequestContextFormatter() + record = logging.LogRecord( + name="test", level=logging.INFO, pathname="test.py", + lineno=1, msg="Test message", args=(), exc_info=None, + ) + result = json.loads(formatter.format(record)) + assert result["level"] == "INFO" + assert result["message"] == "Test message" + assert result["protocol"] == "A2A" + + def test_format_with_exc_info(self): + """Test log record formatting with exception info.""" + formatter = A2ARequestContextFormatter() + try: + raise ValueError("test error") + except ValueError: + import sys + exc_info = sys.exc_info() + + record = logging.LogRecord( + name="test", level=logging.ERROR, pathname="test.py", + lineno=1, msg="Error occurred", args=(), exc_info=exc_info, + ) + result = json.loads(formatter.format(record)) + assert result["errorType"] == "ValueError" + assert result["errorMessage"] == "test error" + assert "stackTrace" in result + assert "location" in result + + def test_format_with_request_and_session_ids(self, app): + """Test log record includes request and session IDs from context.""" + from bedrock_agentcore.runtime.context import BedrockAgentCoreContext + + BedrockAgentCoreContext.set_request_context("req-123", "sess-456") + formatter = A2ARequestContextFormatter() + record = logging.LogRecord( + name="test", level=logging.INFO, pathname="test.py", + lineno=1, msg="Test", args=(), exc_info=None, + ) + result = json.loads(formatter.format(record)) + assert result["requestId"] == "req-123" + assert result["sessionId"] == "sess-456" + + +class TestPingStatusAdvanced: + def test_ping_with_forced_status(self, agent_card): + """Test forced ping status overrides everything.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + app._forced_ping_status = PingStatus.HEALTHY_BUSY + + status = app.get_current_ping_status() + assert status == PingStatus.HEALTHY_BUSY + + def test_ping_with_active_tasks(self, agent_card): + """Test automatic HEALTHY_BUSY when tasks are active.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + app._active_tasks = {1: {"name": "task1"}} + + status = app.get_current_ping_status() + assert status == PingStatus.HEALTHY_BUSY + + def test_ping_custom_handler_returns_ping_status(self, agent_card): + """Test custom ping handler returning PingStatus enum directly.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + + @app.ping + def custom_ping(): + return PingStatus.HEALTHY_BUSY + + status = app.get_current_ping_status() + assert status == PingStatus.HEALTHY_BUSY + + def test_ping_custom_handler_returns_string(self, agent_card): + """Test custom ping handler returning string value.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + + @app.ping + def custom_ping(): + return "HealthyBusy" + + status = app.get_current_ping_status() + assert status == PingStatus.HEALTHY_BUSY + + def test_ping_custom_handler_exception_falls_back(self, agent_card): + """Test custom ping handler exception falls back to automatic.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + + @app.ping + def broken_ping(): + raise RuntimeError("ping failed") + + status = app.get_current_ping_status() + assert status == PingStatus.HEALTHY + + def test_ping_status_unchanged_does_not_update_timestamp(self, agent_card): + """Test timestamp is not updated when status doesn't change.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + app.get_current_ping_status() + first_time = app._last_status_update_time + + app.get_current_ping_status() + assert app._last_status_update_time == first_time + + def test_ping_endpoint_exception(self, agent_card): + """Test ping endpoint handles exception gracefully.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + + # Force get_current_ping_status to throw + original = app.get_current_ping_status + def broken(): + raise RuntimeError("broken") + app.get_current_ping_status = broken + + client = TestClient(app) + response = client.get("/ping") + assert response.status_code == 200 + data = response.json() + assert data["status"] in ["Healthy", "HEALTHY"] + + +class TestBuildRequestContext: + def test_context_with_all_headers(self, app): + """Test context building with access token, oauth, auth, and custom headers.""" + captured_context = None + + @app.entrypoint + def handler(request, context): + nonlocal captured_context + captured_context = context + return {"ok": True} + + client = TestClient(app) + client.post( + "/", + json={"jsonrpc": "2.0", "id": "1", "method": "test"}, + headers={ + "X-Amzn-Bedrock-AgentCore-Runtime-Session-Id": "sess-1", + "X-Amzn-Bedrock-AgentCore-Runtime-Request-Id": "req-1", + "WorkloadAccessToken": "token-abc", + "OAuth2CallbackUrl": "https://callback.example.com", + "Authorization": "Bearer xyz", + "X-Amzn-Bedrock-AgentCore-Runtime-Custom-MyHeader": "custom-val", + }, + ) + + assert captured_context is not None + assert captured_context.session_id == "sess-1" + + def test_context_build_exception_fallback(self, agent_card): + """Test context building falls back gracefully on exception.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + + @app.entrypoint + def handler(request, context): + return {"session": context.session_id} + + # Patch set_request_context: first call raises, second (in except) succeeds + call_count = 0 + original_set = __import__( + "bedrock_agentcore.runtime.context", fromlist=["BedrockAgentCoreContext"] + ).BedrockAgentCoreContext.set_request_context + + def failing_then_ok(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("context failure") + return original_set(*args, **kwargs) + + with patch( + "bedrock_agentcore.runtime.base_app.BedrockAgentCoreContext.set_request_context", + side_effect=failing_then_ok, + ): + client = TestClient(app) + response = client.post( + "/", + json={"jsonrpc": "2.0", "id": "1", "method": "test"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["result"]["session"] is None + + +class TestTakesContext: + def test_takes_context_with_exception(self, app): + """Test _takes_context returns False on exception.""" + # Use a mock that raises when inspecting signature + mock_handler = MagicMock(spec=[]) + mock_handler.__name__ = "mock" + with patch("inspect.signature", side_effect=ValueError("bad")): + assert app._takes_context(mock_handler) is False + + +class TestStreamingErrorHandling: + def test_async_streaming_error(self, app): + """Test error during async streaming yields error SSE event.""" + + @app.entrypoint + async def handler(request, context): + async def generate(): + yield {"chunk": 1} + raise ValueError("stream error") + + return generate() + + client = TestClient(app) + response = client.post( + "/", + json={"jsonrpc": "2.0", "id": "req-001", "method": "message/stream"}, + ) + + assert response.status_code == 200 + events = [e for e in response.text.split("\n\n") if e.strip()] + assert len(events) >= 2 + # Last event should be an error + last_data = json.loads(events[-1].replace("data: ", "")) + assert "error" in last_data + + def test_sync_streaming_error(self, app): + """Test error during sync streaming yields error SSE event.""" + + @app.entrypoint + def handler(request, context): + def generate(): + yield {"chunk": 1} + raise ValueError("sync stream error") + + return generate() + + client = TestClient(app) + response = client.post( + "/", + json={"jsonrpc": "2.0", "id": "req-001", "method": "message/stream"}, + ) + + assert response.status_code == 200 + events = [e for e in response.text.split("\n\n") if e.strip()] + assert len(events) >= 2 + last_data = json.loads(events[-1].replace("data: ", "")) + assert "error" in last_data + + +class TestConvertToSerializable: + def test_set_converted_to_list(self, app): + """Test set objects are converted to lists.""" + result = app._convert_to_serializable({"tags", "a2a"}) + assert isinstance(result, list) + assert set(result) == {"tags", "a2a"} + + def test_tuple_converted(self, app): + """Test tuple objects are converted to lists.""" + result = app._convert_to_serializable((1, 2, 3)) + assert result == [1, 2, 3] + + def test_object_with_to_dict(self, app): + """Test objects with to_dict method are serialized.""" + + @dataclass + class FakeModel: + name: str + def to_dict(self): + return {"name": self.name} + + result = app._convert_to_serializable(FakeModel(name="test")) + assert result == {"name": "test"} + + def test_nested_dict_serialization(self, app): + """Test nested dict with mixed types.""" + result = app._convert_to_serializable({ + "items": [A2AArtifact.from_text("1", "resp", "hello")], + "tags": {"a", "b"}, + }) + assert isinstance(result["items"], list) + assert result["items"][0]["artifactId"] == "1" + assert isinstance(result["tags"], list) + + +class TestSafeSerialize: + def test_normal_json(self, app): + """Test normal JSON serialization.""" + result = app._safe_serialize_to_json_string({"key": "value"}) + assert json.loads(result) == {"key": "value"} + + def test_non_serializable_fallback(self, app): + """Test non-serializable object uses convert then str fallback.""" + result = app._safe_serialize_to_json_string({"data": {1, 2, 3}}) + parsed = json.loads(result) + assert isinstance(parsed["data"], list) + + def test_totally_unserializable_uses_str(self, app): + """Test totally unserializable object falls back to str().""" + + class Unserializable: + def __repr__(self): + return "" + + # Patch _convert_to_serializable to also fail + original = app._convert_to_serializable + def broken(obj): + raise TypeError("cannot convert") + app._convert_to_serializable = broken + + result = app._safe_serialize_to_json_string(Unserializable()) + assert "" in result + + app._convert_to_serializable = original + + +class TestAgentCardEndpointError: + def test_agent_card_exception(self, agent_card): + """Test agent card endpoint handles exception.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + + # Make to_dict raise + original = agent_card.to_dict + agent_card.to_dict = MagicMock(side_effect=RuntimeError("card error")) + + client = TestClient(app) + response = client.get("/.well-known/agent-card.json") + assert response.status_code == 500 + assert "error" in response.json() + + agent_card.to_dict = original + + +class TestInvokeHandlerException: + def test_sync_handler_exception_propagates(self, app): + """Test exception from sync handler propagates through _invoke_handler.""" + + @app.entrypoint + def handler(request, context): + raise RuntimeError("handler failed") + + client = TestClient(app) + response = client.post( + "/", + json={"jsonrpc": "2.0", "id": "1", "method": "test"}, + ) + data = response.json() + assert "error" in data + assert data["error"]["code"] == -32603 + + def test_async_handler_exception_propagates(self, app): + """Test exception from async handler propagates through _invoke_handler.""" + + @app.entrypoint + async def handler(request, context): + raise RuntimeError("async handler failed") + + client = TestClient(app) + response = client.post( + "/", + json={"jsonrpc": "2.0", "id": "1", "method": "test"}, + ) + data = response.json() + assert "error" in data + assert data["error"]["code"] == -32603 + + +class TestRunDockerDetection: + @patch("os.path.exists", return_value=True) + @patch("uvicorn.run") + def test_run_detects_dockerenv_file(self, mock_uvicorn, mock_exists, agent_card): + """Test run detects /.dockerenv file for Docker environment.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + app.run() + + mock_uvicorn.assert_called_once() + call_kwargs = mock_uvicorn.call_args[1] + assert call_kwargs["host"] == "0.0.0.0" + + @patch("uvicorn.run") + def test_run_with_custom_host(self, mock_uvicorn, agent_card): + """Test run with explicitly specified host.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + app.run(host="192.168.1.1") + + call_kwargs = mock_uvicorn.call_args[1] + assert call_kwargs["host"] == "192.168.1.1" + + @patch("uvicorn.run") + def test_run_passes_extra_kwargs(self, mock_uvicorn, agent_card): + """Test run passes extra kwargs to uvicorn.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + app.run(workers=4) + + call_kwargs = mock_uvicorn.call_args[1] + assert call_kwargs["workers"] == 4 + + +class TestGetRuntimeUrl: + def test_runtime_url_from_env(self, app): + """Test runtime URL from environment variable.""" + with patch.dict(os.environ, {"AGENTCORE_RUNTIME_URL": "https://runtime.example.com"}): + url = app._get_runtime_url() + assert url == "https://runtime.example.com" + + def test_runtime_url_from_request_base_url(self, app): + """Test runtime URL fallback to request.base_url.""" + mock_request = MagicMock() + mock_request.base_url = "http://localhost:9000/" + with patch.dict(os.environ, {}, clear=True): + # Remove AGENTCORE_RUNTIME_URL if set + os.environ.pop("AGENTCORE_RUNTIME_URL", None) + url = app._get_runtime_url(request=mock_request) + assert url == "http://localhost:9000/" + + def test_runtime_url_none_when_nothing_available(self, app): + """Test runtime URL returns None when nothing is available.""" + with patch.dict(os.environ, {}, clear=True): + os.environ.pop("AGENTCORE_RUNTIME_URL", None) + url = app._get_runtime_url() + assert url is None + + +class TestAsyncTaskDecorator: + def test_async_task_decorator(self, agent_card): + """Test @async_task decorator tracks task and returns result.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + + @app.async_task + async def my_task(): + return "done" + + assert my_task.__name__ == "my_task" + result = asyncio.get_event_loop().run_until_complete(my_task()) + assert result == "done" + assert len(app._active_tasks) == 0 # task completed + + def test_async_task_rejects_sync(self, agent_card): + """Test @async_task raises ValueError for sync functions.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + + with pytest.raises(ValueError, match="async"): + @app.async_task + def sync_func(): + pass + + def test_async_task_exception_cleanup(self, agent_card): + """Test @async_task cleans up task on exception.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + + @app.async_task + async def failing_task(): + raise RuntimeError("fail") + + with pytest.raises(RuntimeError, match="fail"): + asyncio.get_event_loop().run_until_complete(failing_task()) + assert len(app._active_tasks) == 0 + + +class TestAsyncTaskManagement: + def test_add_and_complete_task(self, agent_card): + """Test add_async_task and complete_async_task.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + task_id = app.add_async_task("test-task", metadata={"key": "val"}) + assert len(app._active_tasks) == 1 + assert app._active_tasks[task_id]["name"] == "test-task" + assert app._active_tasks[task_id]["metadata"] == {"key": "val"} + + result = app.complete_async_task(task_id) + assert result is True + assert len(app._active_tasks) == 0 + + def test_complete_unknown_task(self, agent_card): + """Test completing a non-existent task returns False.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + result = app.complete_async_task(99999) + assert result is False + + def test_get_async_task_info(self, agent_card): + """Test get_async_task_info returns correct data.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + app.add_async_task("job-1") + app.add_async_task("job-2") + + info = app.get_async_task_info() + assert info["active_count"] == 2 + assert len(info["running_jobs"]) == 2 + names = {j["name"] for j in info["running_jobs"]} + assert names == {"job-1", "job-2"} + + def test_force_and_clear_ping_status(self, agent_card): + """Test force_ping_status and clear_forced_ping_status.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + app.force_ping_status(PingStatus.HEALTHY_BUSY) + assert app.get_current_ping_status() == PingStatus.HEALTHY_BUSY + + app.clear_forced_ping_status() + assert app.get_current_ping_status() == PingStatus.HEALTHY + + +class TestContextVarsPropagation: + def test_sync_handler_preserves_context_vars(self, agent_card): + """Test sync handler receives context variables via contextvars.copy_context().""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + captured_request_id = None + + @app.entrypoint + def handler(request, context): + nonlocal captured_request_id + from bedrock_agentcore.runtime.context import BedrockAgentCoreContext + captured_request_id = BedrockAgentCoreContext.get_request_id() + return {"ok": True} + + client = TestClient(app) + client.post( + "/", + json={"jsonrpc": "2.0", "id": "1", "method": "test"}, + headers={"X-Amzn-Bedrock-AgentCore-Runtime-Request-Id": "ctx-req-123"}, + ) + + assert captured_request_id == "ctx-req-123" diff --git a/tests/bedrock_agentcore/runtime/test_a2a_models.py b/tests/bedrock_agentcore/runtime/test_a2a_models.py new file mode 100644 index 00000000..4c59cd6b --- /dev/null +++ b/tests/bedrock_agentcore/runtime/test_a2a_models.py @@ -0,0 +1,463 @@ +"""Tests for A2A models.""" + +from bedrock_agentcore.runtime.a2a_models import ( + A2A_DEFAULT_PORT, + A2A_METHOD_MESSAGE_SEND, + A2AArtifact, + A2AMessage, + A2AMessagePart, + AgentCard, + AgentSkill, + JsonRpcErrorCode, + JsonRpcRequest, + JsonRpcResponse, + build_runtime_url, +) + + +class TestAgentSkill: + def test_basic_creation(self): + """Test creating a basic AgentSkill.""" + skill = AgentSkill( + id="calc", + name="Calculator", + description="Perform arithmetic calculations", + ) + assert skill.id == "calc" + assert skill.name == "Calculator" + assert skill.description == "Perform arithmetic calculations" + assert skill.tags == [] + + def test_creation_with_tags(self): + """Test creating AgentSkill with tags.""" + skill = AgentSkill( + id="search", + name="Web Search", + description="Search the web", + tags=["search", "web", "information"], + ) + assert skill.tags == ["search", "web", "information"] + + def test_to_dict(self): + """Test AgentSkill serialization to dict.""" + skill = AgentSkill( + id="calc", + name="Calculator", + description="Math operations", + tags=["math"], + ) + result = skill.to_dict() + assert result == { + "id": "calc", + "name": "Calculator", + "description": "Math operations", + "tags": ["math"], + } + + +class TestAgentCard: + def test_basic_creation(self): + """Test creating a basic AgentCard.""" + card = AgentCard( + name="Test Agent", + description="A test agent", + ) + assert card.name == "Test Agent" + assert card.description == "A test agent" + assert card.version == "1.0.0" + assert card.protocol_version == "0.3.0" + assert card.preferred_transport == "JSONRPC" + assert card.capabilities == {"streaming": True} + assert card.default_input_modes == ["text"] + assert card.default_output_modes == ["text"] + assert card.skills == [] + + def test_creation_with_skills(self): + """Test AgentCard with skills.""" + skills = [ + AgentSkill(id="s1", name="Skill 1", description="First skill"), + AgentSkill(id="s2", name="Skill 2", description="Second skill"), + ] + card = AgentCard( + name="Multi-Skill Agent", + description="An agent with multiple skills", + skills=skills, + ) + assert len(card.skills) == 2 + assert card.skills[0].id == "s1" + assert card.skills[1].id == "s2" + + def test_to_dict_without_url(self): + """Test AgentCard serialization without URL.""" + card = AgentCard( + name="Test Agent", + description="A test agent", + ) + result = card.to_dict() + assert result["name"] == "Test Agent" + assert result["description"] == "A test agent" + assert result["protocolVersion"] == "0.3.0" + assert result["preferredTransport"] == "JSONRPC" + assert "url" not in result + + def test_to_dict_with_url(self): + """Test AgentCard serialization with URL.""" + card = AgentCard( + name="Test Agent", + description="A test agent", + ) + result = card.to_dict(url="https://example.com/agent") + assert result["url"] == "https://example.com/agent" + + def test_to_dict_with_skills(self): + """Test AgentCard serialization with skills.""" + skills = [AgentSkill(id="s1", name="Skill 1", description="First skill")] + card = AgentCard( + name="Test Agent", + description="A test agent", + skills=skills, + ) + result = card.to_dict() + assert len(result["skills"]) == 1 + assert result["skills"][0]["id"] == "s1" + + +class TestJsonRpcRequest: + def test_from_dict(self): + """Test creating JsonRpcRequest from dict.""" + data = { + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/send", + "params": {"message": {"text": "Hello"}}, + } + request = JsonRpcRequest.from_dict(data) + assert request.jsonrpc == "2.0" + assert request.id == "req-001" + assert request.method == "message/send" + assert request.params == {"message": {"text": "Hello"}} + + def test_from_dict_minimal(self): + """Test creating JsonRpcRequest with minimal data.""" + data = {"method": "test"} + request = JsonRpcRequest.from_dict(data) + assert request.jsonrpc == "2.0" + assert request.id is None + assert request.method == "test" + assert request.params is None + + +class TestJsonRpcResponse: + def test_success_response(self): + """Test creating a success response.""" + response = JsonRpcResponse.success("req-001", {"result": "success"}) + assert response.id == "req-001" + assert response.result == {"result": "success"} + assert response.error is None + + def test_error_response(self): + """Test creating an error response.""" + response = JsonRpcResponse.error_response( + "req-001", + JsonRpcErrorCode.INTERNAL_ERROR, + "Something went wrong", + ) + assert response.id == "req-001" + assert response.result is None + assert response.error is not None + assert response.error.code == JsonRpcErrorCode.INTERNAL_ERROR + assert response.error.message == "Something went wrong" + + def test_success_to_dict(self): + """Test success response serialization.""" + response = JsonRpcResponse.success("req-001", {"data": "test"}) + result = response.to_dict() + assert result == { + "jsonrpc": "2.0", + "id": "req-001", + "result": {"data": "test"}, + } + + def test_error_to_dict(self): + """Test error response serialization.""" + response = JsonRpcResponse.error_response( + "req-001", + -32600, + "Invalid request", + ) + result = response.to_dict() + assert result == { + "jsonrpc": "2.0", + "id": "req-001", + "error": { + "code": -32600, + "message": "Invalid request", + }, + } + + +class TestA2AMessagePart: + def test_text_part(self): + """Test creating a text message part.""" + part = A2AMessagePart(kind="text", text="Hello, world!") + assert part.kind == "text" + assert part.text == "Hello, world!" + assert part.file is None + assert part.data is None + + def test_to_dict(self): + """Test message part serialization.""" + part = A2AMessagePart(kind="text", text="Test message") + result = part.to_dict() + assert result == {"kind": "text", "text": "Test message"} + + def test_from_dict(self): + """Test creating message part from dict.""" + data = {"kind": "text", "text": "Hello"} + part = A2AMessagePart.from_dict(data) + assert part.kind == "text" + assert part.text == "Hello" + + +class TestA2AMessage: + def test_basic_creation(self): + """Test creating a basic A2A message.""" + parts = [A2AMessagePart(kind="text", text="Hello")] + message = A2AMessage(role="user", parts=parts) + assert message.role == "user" + assert len(message.parts) == 1 + assert message.message_id is None + + def test_with_message_id(self): + """Test message with ID.""" + parts = [A2AMessagePart(kind="text", text="Hello")] + message = A2AMessage(role="user", parts=parts, message_id="msg-001") + assert message.message_id == "msg-001" + + def test_get_text(self): + """Test extracting text from message.""" + parts = [ + A2AMessagePart(kind="text", text="Line 1"), + A2AMessagePart(kind="text", text="Line 2"), + ] + message = A2AMessage(role="user", parts=parts) + assert message.get_text() == "Line 1\nLine 2" + + def test_to_dict(self): + """Test message serialization.""" + parts = [A2AMessagePart(kind="text", text="Hello")] + message = A2AMessage(role="user", parts=parts, message_id="msg-001") + result = message.to_dict() + assert result == { + "role": "user", + "parts": [{"kind": "text", "text": "Hello"}], + "messageId": "msg-001", + } + + def test_from_dict(self): + """Test creating message from dict.""" + data = { + "role": "agent", + "parts": [{"kind": "text", "text": "Response"}], + "messageId": "msg-002", + } + message = A2AMessage.from_dict(data) + assert message.role == "agent" + assert len(message.parts) == 1 + assert message.parts[0].text == "Response" + assert message.message_id == "msg-002" + + +class TestA2AArtifact: + def test_basic_creation(self): + """Test creating a basic artifact.""" + parts = [A2AMessagePart(kind="text", text="Result")] + artifact = A2AArtifact( + artifact_id="art-001", + name="response", + parts=parts, + ) + assert artifact.artifact_id == "art-001" + assert artifact.name == "response" + assert len(artifact.parts) == 1 + + def test_from_text(self): + """Test creating text artifact.""" + artifact = A2AArtifact.from_text("art-001", "response", "Hello") + assert artifact.artifact_id == "art-001" + assert artifact.name == "response" + assert len(artifact.parts) == 1 + assert artifact.parts[0].kind == "text" + assert artifact.parts[0].text == "Hello" + + def test_to_dict(self): + """Test artifact serialization.""" + artifact = A2AArtifact.from_text("art-001", "response", "Result") + result = artifact.to_dict() + assert result == { + "artifactId": "art-001", + "name": "response", + "parts": [{"kind": "text", "text": "Result"}], + } + + +class TestBuildRuntimeUrl: + def test_basic_url(self): + """Test building runtime URL.""" + arn = "arn:aws:bedrock-agentcore:us-west-2:123456789012:runtime/my-agent" + url = build_runtime_url(arn) + # ARN should be URL-encoded + assert "us-west-2" in url + assert "arn%3Aaws%3Abedrock-agentcore" in url + assert url.startswith("https://bedrock-agentcore.us-west-2.amazonaws.com/runtimes/") + assert url.endswith("/invocations/") + + def test_region_parsed_from_arn(self): + """Test that region is automatically parsed from the ARN.""" + arn = "arn:aws:bedrock-agentcore:us-east-1:123456789012:runtime/my-agent" + url = build_runtime_url(arn) + assert "us-east-1" in url + assert "bedrock-agentcore.us-east-1.amazonaws.com" in url + + def test_special_characters_encoded(self): + """Test that special characters in ARN are properly encoded.""" + arn = "arn:aws:bedrock-agentcore:us-west-2:123456789012:runtime/agent-with-special" + url = build_runtime_url(arn) + # Colon and slash should be encoded + assert "%3A" in url # Encoded colon + assert "%2F" in url # Encoded slash + + +class TestJsonRpcErrorWithData: + def test_error_to_dict_with_data(self): + """Test error serialization with data field.""" + from bedrock_agentcore.runtime.a2a_models import JsonRpcError + + error = JsonRpcError(code=-32600, message="Invalid", data={"detail": "missing field"}) + result = error.to_dict() + assert result == { + "code": -32600, + "message": "Invalid", + "data": {"detail": "missing field"}, + } + + def test_error_to_dict_without_data(self): + """Test error serialization without data field.""" + from bedrock_agentcore.runtime.a2a_models import JsonRpcError + + error = JsonRpcError(code=-32600, message="Invalid") + result = error.to_dict() + assert "data" not in result + + def test_error_response_with_data(self): + """Test creating error response with data.""" + response = JsonRpcResponse.error_response( + "req-001", -32600, "Invalid request", data={"details": "extra info"} + ) + result = response.to_dict() + assert result["error"]["data"] == {"details": "extra info"} + + +class TestA2AMessagePartAdvanced: + def test_to_dict_with_file(self): + """Test message part serialization with file field.""" + part = A2AMessagePart(kind="file", file={"uri": "s3://bucket/file.pdf", "mimeType": "application/pdf"}) + result = part.to_dict() + assert result == { + "kind": "file", + "file": {"uri": "s3://bucket/file.pdf", "mimeType": "application/pdf"}, + } + + def test_to_dict_with_data(self): + """Test message part serialization with data field.""" + part = A2AMessagePart(kind="data", data={"key": "value", "count": 42}) + result = part.to_dict() + assert result == { + "kind": "data", + "data": {"key": "value", "count": 42}, + } + + def test_to_dict_text_only(self): + """Test message part with only text (no file/data).""" + part = A2AMessagePart(kind="text", text="Hello") + result = part.to_dict() + assert "file" not in result + assert "data" not in result + + def test_from_dict_with_file_and_data(self): + """Test creating message part from dict with file and data.""" + data = { + "kind": "file", + "file": {"uri": "https://example.com/file.txt"}, + "data": {"meta": "info"}, + } + part = A2AMessagePart.from_dict(data) + assert part.kind == "file" + assert part.file == {"uri": "https://example.com/file.txt"} + assert part.data == {"meta": "info"} + + +class TestA2AMessageAdvanced: + def test_to_dict_without_message_id(self): + """Test message serialization without message_id.""" + parts = [A2AMessagePart(kind="text", text="Hello")] + message = A2AMessage(role="user", parts=parts) + result = message.to_dict() + assert "messageId" not in result + + def test_get_text_skips_non_text_parts(self): + """Test get_text ignores non-text parts.""" + parts = [ + A2AMessagePart(kind="text", text="Hello"), + A2AMessagePart(kind="file", file={"uri": "s3://bucket/file"}), + A2AMessagePart(kind="text", text="World"), + ] + message = A2AMessage(role="user", parts=parts) + assert message.get_text() == "Hello\nWorld" + + def test_get_text_empty_parts(self): + """Test get_text with no parts.""" + message = A2AMessage(role="user", parts=[]) + assert message.get_text() == "" + + def test_get_text_skips_none_text(self): + """Test get_text skips parts with kind=text but text=None.""" + parts = [A2AMessagePart(kind="text", text=None)] + message = A2AMessage(role="user", parts=parts) + assert message.get_text() == "" + + def test_from_dict_without_message_id(self): + """Test creating message from dict without messageId.""" + data = {"role": "user", "parts": [{"kind": "text", "text": "Hi"}]} + message = A2AMessage.from_dict(data) + assert message.message_id is None + + +class TestConstants: + def test_default_port(self): + """Test A2A default port.""" + assert A2A_DEFAULT_PORT == 9000 + + def test_method_constants(self): + """Test A2A method constants.""" + assert A2A_METHOD_MESSAGE_SEND == "message/send" + + +class TestJsonRpcErrorCodes: + """Test that error code values match AWS Bedrock AgentCore documentation.""" + + def test_standard_jsonrpc_codes(self): + """Test standard JSON-RPC 2.0 error codes.""" + assert JsonRpcErrorCode.PARSE_ERROR == -32700 + assert JsonRpcErrorCode.INVALID_REQUEST == -32600 + assert JsonRpcErrorCode.METHOD_NOT_FOUND == -32601 + assert JsonRpcErrorCode.INVALID_PARAMS == -32602 + assert JsonRpcErrorCode.INTERNAL_ERROR == -32603 + + def test_agentcore_specific_codes(self): + """Test AgentCore-specific error codes per AWS documentation.""" + assert JsonRpcErrorCode.RESOURCE_NOT_FOUND == -32501 + assert JsonRpcErrorCode.VALIDATION_ERROR == -32052 + assert JsonRpcErrorCode.THROTTLING == -32053 + assert JsonRpcErrorCode.RESOURCE_CONFLICT == -32054 + assert JsonRpcErrorCode.RUNTIME_CLIENT_ERROR == -32055 diff --git a/tests_integ/runtime/a2a_client.py b/tests_integ/runtime/a2a_client.py new file mode 100644 index 00000000..ffcd0e72 --- /dev/null +++ b/tests_integ/runtime/a2a_client.py @@ -0,0 +1,75 @@ +import json +import logging + +import requests + + +class A2AClient: + """Local A2A client for invoking JSON-RPC endpoints.""" + + def __init__(self, endpoint: str): + self.endpoint = endpoint + self.logger = logging.getLogger("sdk-runtime-test-a2a-client") + + def get_agent_card(self): + """GET /.well-known/agent-card.json""" + url = f"{self.endpoint}/.well-known/agent-card.json" + self.logger.info("Fetching Agent Card from %s", url) + return requests.get(url, timeout=5).json() + + def ping(self): + """GET /ping""" + url = f"{self.endpoint}/ping" + self.logger.info("Pinging A2A server") + return requests.get(url, timeout=5).json() + + def send_message(self, text: str, request_id: str = "req-001", session_id: str = None): + """POST / with JSON-RPC message/send""" + url = self.endpoint + "/" + headers = {"Content-Type": "application/json"} + if session_id: + headers["X-Amzn-Bedrock-AgentCore-Runtime-Session-Id"] = session_id + + body = { + "jsonrpc": "2.0", + "id": request_id, + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [{"kind": "text", "text": text}], + } + }, + } + self.logger.info("Sending message: %s", text) + resp = requests.post(url, headers=headers, json=body, timeout=30) + return resp.json() + + def stream_message(self, text: str, request_id: str = "req-001"): + """POST / with JSON-RPC message/stream (SSE response)""" + url = self.endpoint + "/" + body = { + "jsonrpc": "2.0", + "id": request_id, + "method": "message/stream", + "params": { + "message": { + "role": "user", + "parts": [{"kind": "text", "text": text}], + } + }, + } + self.logger.info("Streaming message: %s", text) + resp = requests.post(url, headers={"Content-Type": "application/json"}, json=body, timeout=30, stream=True) + + events = [] + for line in resp.iter_lines(decode_unicode=True): + if line and line.startswith("data: "): + events.append(json.loads(line[6:])) + return events + + def send_raw(self, body: dict): + """POST / with raw JSON-RPC body""" + url = self.endpoint + "/" + resp = requests.post(url, headers={"Content-Type": "application/json"}, json=body, timeout=10) + return resp.json() diff --git a/tests_integ/runtime/test_a2a_agent.py b/tests_integ/runtime/test_a2a_agent.py new file mode 100644 index 00000000..12efae7d --- /dev/null +++ b/tests_integ/runtime/test_a2a_agent.py @@ -0,0 +1,242 @@ +"""E2E integration test for A2A protocol support. + +Tests the full A2A lifecycle: +- Agent Card discovery +- Ping health check +- JSON-RPC message/send (sync) +- JSON-RPC message/stream (SSE streaming) +- Error handling (invalid JSON-RPC, missing method, etc.) +- Session ID propagation +""" + +import logging +import textwrap + +from tests_integ.runtime.a2a_client import A2AClient +from tests_integ.runtime.base_test import BaseSDKRuntimeTest, start_agent_server + +logger = logging.getLogger("sdk-runtime-a2a-agent-test") + +A2A_SERVER_ENDPOINT = "http://127.0.0.1:9000" + + +class TestSDKA2AAgent(BaseSDKRuntimeTest): + def setup(self): + self.agent_module = "a2a_agent" + with open(self.agent_module + ".py", "w") as file: + content = textwrap.dedent(""" + import uuid + from bedrock_agentcore.runtime import ( + BedrockAgentCoreA2AApp, + AgentCard, + AgentSkill, + ) + + agent_card = AgentCard( + name="E2E Test Agent", + description="An agent for E2E integration testing", + skills=[ + AgentSkill( + id="echo", + name="Echo", + description="Echoes back user input", + tags=["test", "echo"], + ), + ], + ) + + app = BedrockAgentCoreA2AApp(agent_card=agent_card, debug=True) + + @app.entrypoint + def handle_message(request, context): + params = request.params or {} + message = params.get("message", {}) + parts = message.get("parts", []) + + user_text = "" + for part in parts: + if part.get("kind") == "text": + user_text = part.get("text", "") + break + + session_id = context.session_id if context else None + + return { + "artifacts": [ + { + "artifactId": str(uuid.uuid4()), + "name": "echo_response", + "parts": [ + {"kind": "text", "text": f"Echo: {user_text}"}, + {"kind": "data", "data": {"session_id": session_id}}, + ], + } + ] + } + + app.run() + """).strip() + file.write(content) + + def run_test(self): + with start_a2a_server(self.agent_module): + client = A2AClient(A2A_SERVER_ENDPOINT) + + self._test_agent_card(client) + self._test_ping(client) + self._test_message_send(client) + self._test_session_propagation(client) + self._test_invalid_jsonrpc(client) + self._test_missing_method(client) + + logger.info("All A2A E2E tests passed!") + + def _test_agent_card(self, client): + """Test Agent Card discovery endpoint.""" + logger.info("--- Testing Agent Card ---") + card = client.get_agent_card() + assert card["name"] == "E2E Test Agent", f"Expected 'E2E Test Agent', got {card['name']}" + assert card["description"] == "An agent for E2E integration testing" + assert card["protocolVersion"] == "0.3.0" + assert card["preferredTransport"] == "JSONRPC" + assert card["capabilities"]["streaming"] is True + assert len(card["skills"]) == 1 + assert card["skills"][0]["id"] == "echo" + assert card["skills"][0]["tags"] == ["test", "echo"] + assert "url" in card + logger.info("Agent Card test passed: %s", card["name"]) + + def _test_ping(self, client): + """Test ping health check endpoint.""" + logger.info("--- Testing Ping ---") + ping = client.ping() + assert ping["status"] in ["Healthy", "HEALTHY"], f"Unexpected status: {ping['status']}" + assert "time_of_last_update" in ping + logger.info("Ping test passed: %s", ping) + + def _test_message_send(self, client): + """Test JSON-RPC message/send.""" + logger.info("--- Testing message/send ---") + response = client.send_message("Hello, A2A!", request_id="test-001") + + assert response["jsonrpc"] == "2.0" + assert response["id"] == "test-001" + assert "result" in response + assert "error" not in response + + result = response["result"] + assert "artifacts" in result + assert len(result["artifacts"]) == 1 + + artifact = result["artifacts"][0] + assert artifact["name"] == "echo_response" + text_part = artifact["parts"][0] + assert text_part["text"] == "Echo: Hello, A2A!" + logger.info("message/send test passed: %s", text_part["text"]) + + def _test_session_propagation(self, client): + """Test session ID is propagated to handler.""" + logger.info("--- Testing Session Propagation ---") + response = client.send_message( + "session test", + request_id="test-session", + session_id="my-session-456", + ) + + result = response["result"] + data_part = result["artifacts"][0]["parts"][1] + assert data_part["data"]["session_id"] == "my-session-456", ( + f"Expected session_id 'my-session-456', got {data_part['data']['session_id']}" + ) + logger.info("Session propagation test passed") + + def _test_invalid_jsonrpc(self, client): + """Test error handling for invalid JSON-RPC version.""" + logger.info("--- Testing Invalid JSON-RPC ---") + response = client.send_raw({ + "jsonrpc": "1.0", + "id": "bad-001", + "method": "message/send", + }) + + assert "error" in response + assert response["error"]["code"] == -32600 # INVALID_REQUEST + logger.info("Invalid JSON-RPC test passed: code=%d", response["error"]["code"]) + + def _test_missing_method(self, client): + """Test error handling for missing method.""" + logger.info("--- Testing Missing Method ---") + response = client.send_raw({ + "jsonrpc": "2.0", + "id": "bad-002", + }) + + assert "error" in response + assert response["error"]["code"] == -32600 # INVALID_REQUEST + logger.info("Missing method test passed: code=%d", response["error"]["code"]) + + +import subprocess +import threading +import time +from contextlib import contextmanager + + +@contextmanager +def start_a2a_server(agent_module, timeout=10): + """Start an A2A agent server on port 9000.""" + logger.info("Starting A2A agent server...") + start_time = time.time() + + agent_server = subprocess.Popen( + ["python", "-m", agent_module], + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + + try: + while time.time() - start_time < timeout: + if agent_server.stdout is None: + raise RuntimeError("Agent server has no configured output") + + if agent_server.poll() is not None: + out = agent_server.stdout.read() + raise RuntimeError(f"Error when running agent server: {out}") + + line = agent_server.stdout.readline() + while line: + line = line.strip() + if line: + logger.info(line) + if "Uvicorn running on" in line and "9000" in line: + # Start logging thread + def log_output(): + for l in iter(agent_server.stdout.readline, ""): + if l.strip(): + logger.info(l.strip()) + t = threading.Thread(target=log_output, daemon=True) + t.start() + yield agent_server + return + line = agent_server.stdout.readline() + + time.sleep(0.5) + raise TimeoutError(f"A2A server did not start within {timeout} seconds") + finally: + logger.info("Stopping A2A agent server...") + if agent_server.poll() is None: + agent_server.terminate() + try: + agent_server.wait(timeout=5) + except subprocess.TimeoutExpired: + agent_server.kill() + agent_server.wait() + finally: + if agent_server.stdout: + agent_server.stdout.close() + logger.info("A2A agent server terminated") + + +def test(tmp_path): + TestSDKA2AAgent().run(tmp_path)