diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index dd3562289..b3eed6474 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -52,8 +52,11 @@ def __init__( # Create a session if it does not exist yet if session is None: logger.debug("session_id=<%s> | session not found, creating new session", self.session_id) + self._is_new_session = True session = Session(session_id=session_id, session_type=SessionType.AGENT) session_repository.create_session(session) + else: + self._is_new_session = False self.session = session @@ -124,8 +127,8 @@ def sync_agent(self, agent: "Agent", **kwargs: Any) -> None: else: state_changed = current_state_version != last_synced.get("state_version") internal_state_changed = current_interrupt_state_version != last_synced.get("interrupt_state_version") - conversation_manager_state_changed = ( - current_conversation_manager_state != last_synced.get("conversation_manager_state") + conversation_manager_state_changed = current_conversation_manager_state != last_synced.get( + "conversation_manager_state" ) if not state_changed and not internal_state_changed and not conversation_manager_state_changed: @@ -170,7 +173,11 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: raise SessionException("The `agent_id` of an agent must be unique in a session.") self._latest_agent_message[agent.agent_id] = None - session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id) + # Skip read_agent call for new sessions since no agents can exist yet + if self._is_new_session: + session_agent = None + else: + session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id) if session_agent is None: logger.debug( @@ -299,7 +306,12 @@ def initialize_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> Non source: Multi-agent source object to restore state into **kwargs: Additional keyword arguments for future extensibility. """ - state = self.session_repository.read_multi_agent(self.session_id, source.id, **kwargs) + # Skip read_multi_agent call for new sessions since no multi-agents can exist yet + if self._is_new_session: + state = None + else: + state = self.session_repository.read_multi_agent(self.session_id, source.id, **kwargs) + if state is None: self.session_repository.create_multi_agent(self.session_id, source, **kwargs) else: @@ -317,7 +329,11 @@ def initialize_bidi_agent(self, agent: "BidiAgent", **kwargs: Any) -> None: raise SessionException("The `agent_id` of an agent must be unique in a session.") self._latest_agent_message[agent.agent_id] = None - session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id) + # Skip read_agent call for new sessions since no agents can exist yet + if self._is_new_session: + session_agent = None + else: + session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id) if session_agent is None: logger.debug( diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index f8f044a9b..9b2d84a51 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -28,6 +28,15 @@ def session_manager(mock_repository): return RepositorySessionManager(session_id="test-session", session_repository=mock_repository) +@pytest.fixture +def existing_session_manager(mock_repository): + """Create a session manager with a pre-existing session in the repository.""" + # Create session first so the manager sees it as existing + session = Session(session_id="test-session", session_type=SessionType.AGENT) + mock_repository.create_session(session) + return RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + @pytest.fixture def agent(): """Create a mock agent.""" @@ -100,7 +109,7 @@ def test_initialize_multiple_agents_without_id(session_manager, agent): session_manager.initialize(agent2) -def test_initialize_restores_existing_agent(session_manager, agent): +def test_initialize_restores_existing_agent(existing_session_manager, agent): """Test that initializing an existing agent restores its state.""" # Set agent ID agent.agent_id = "existing-agent" @@ -112,7 +121,7 @@ def test_initialize_restores_existing_agent(session_manager, agent): conversation_manager_state=SlidingWindowConversationManager().get_state(), _internal_state={"interrupt_state": {"interrupts": {}, "context": {"test": "init"}, "activated": False}}, ) - session_manager.session_repository.create_agent("test-session", session_agent) + existing_session_manager.session_repository.create_agent("test-session", session_agent) # Create some messages message = SessionMessage( @@ -122,10 +131,10 @@ def test_initialize_restores_existing_agent(session_manager, agent): }, message_id=0, ) - session_manager.session_repository.create_message("test-session", "existing-agent", message) + existing_session_manager.session_repository.create_message("test-session", "existing-agent", message) # Initialize agent - session_manager.initialize(agent) + existing_session_manager.initialize(agent) # Verify agent state restored assert agent.state.get("key") == "value" @@ -135,7 +144,7 @@ def test_initialize_restores_existing_agent(session_manager, agent): assert agent._interrupt_state == _InterruptState(interrupts={}, context={"test": "init"}, activated=False) -def test_initialize_restores_existing_agent_with_summarizing_conversation_manager(session_manager): +def test_initialize_restores_existing_agent_with_summarizing_conversation_manager(existing_session_manager): """Test that initializing an existing agent restores its state.""" conversation_manager = SummarizingConversationManager() conversation_manager.removed_message_count = 1 @@ -147,7 +156,7 @@ def test_initialize_restores_existing_agent_with_summarizing_conversation_manage state={"key": "value"}, conversation_manager_state=conversation_manager.get_state(), ) - session_manager.session_repository.create_agent("test-session", session_agent) + existing_session_manager.session_repository.create_agent("test-session", session_agent) # Create some messages message = SessionMessage( @@ -158,13 +167,13 @@ def test_initialize_restores_existing_agent_with_summarizing_conversation_manage message_id=0, ) # Create two messages as one will be removed by the conversation manager - session_manager.session_repository.create_message("test-session", "existing-agent", message) + existing_session_manager.session_repository.create_message("test-session", "existing-agent", message) message.message_id = 1 - session_manager.session_repository.create_message("test-session", "existing-agent", message) + existing_session_manager.session_repository.create_message("test-session", "existing-agent", message) # Initialize agent agent = Agent(agent_id="existing-agent", conversation_manager=SummarizingConversationManager()) - session_manager.initialize(agent) + existing_session_manager.initialize(agent) # Verify agent state restored assert agent.state.get("key") == "value" @@ -217,26 +226,26 @@ def test_initialize_multi_agent_new(session_manager, mock_multi_agent): assert state["state"] == {"key": "value"} -def test_initialize_multi_agent_existing(session_manager, mock_multi_agent): +def test_initialize_multi_agent_existing(existing_session_manager, mock_multi_agent): """Test initializing existing multi-agent state.""" # Create existing state first - session_manager.session_repository.create_multi_agent("test-session", mock_multi_agent) + existing_session_manager.session_repository.create_multi_agent("test-session", mock_multi_agent) # Create a mock with updated state for the update call updated_mock = Mock() updated_mock.id = "test-multi-agent" existing_state = {"id": "test-multi-agent", "state": {"restored": "data"}} updated_mock.serialize_state.return_value = existing_state - session_manager.session_repository.update_multi_agent("test-session", updated_mock) + existing_session_manager.session_repository.update_multi_agent("test-session", updated_mock) # Initialize multi-agent - session_manager.initialize_multi_agent(mock_multi_agent) + existing_session_manager.initialize_multi_agent(mock_multi_agent) # Verify deserialize_state was called with existing state mock_multi_agent.deserialize_state.assert_called_once_with(existing_state) -def test_fix_broken_tool_use_adds_missing_tool_results(session_manager): +def test_fix_broken_tool_use_adds_missing_tool_results(existing_session_manager): """Test that _fix_broken_tool_use adds missing toolResult messages.""" conversation_manager = SlidingWindowConversationManager() @@ -246,7 +255,7 @@ def test_fix_broken_tool_use_adds_missing_tool_results(session_manager): state={"key": "value"}, conversation_manager_state=conversation_manager.get_state(), ) - session_manager.session_repository.create_agent("test-session", session_agent) + existing_session_manager.session_repository.create_agent("test-session", session_agent) broken_messages = [ { @@ -261,11 +270,13 @@ def test_fix_broken_tool_use_adds_missing_tool_results(session_manager): message=broken_message, message_id=index, ) - session_manager.session_repository.create_message("test-session", "existing-agent", broken_session_message) + existing_session_manager.session_repository.create_message( + "test-session", "existing-agent", broken_session_message + ) # Initialize agent agent = Agent(agent_id="existing-agent") - session_manager.initialize(agent) + existing_session_manager.initialize(agent) fixed_messages = agent.messages @@ -277,7 +288,7 @@ def test_fix_broken_tool_use_adds_missing_tool_results(session_manager): assert fixed_messages[1]["content"][0]["toolResult"]["content"][0]["text"] == "Tool was interrupted." -def test_fix_broken_tool_use_extends_partial_tool_results(session_manager): +def test_fix_broken_tool_use_extends_partial_tool_results(existing_session_manager): """Test fixing messages where some toolResults are missing.""" conversation_manager = SlidingWindowConversationManager() # Create agent in repository first @@ -286,7 +297,7 @@ def test_fix_broken_tool_use_extends_partial_tool_results(session_manager): state={"key": "value"}, conversation_manager_state=conversation_manager.get_state(), ) - session_manager.session_repository.create_agent("test-session", session_agent) + existing_session_manager.session_repository.create_agent("test-session", session_agent) broken_messages = [ { @@ -309,11 +320,13 @@ def test_fix_broken_tool_use_extends_partial_tool_results(session_manager): message=broken_message, message_id=index, ) - session_manager.session_repository.create_message("test-session", "existing-agent", broken_session_message) + existing_session_manager.session_repository.create_message( + "test-session", "existing-agent", broken_session_message + ) # Initialize agent agent = Agent(agent_id="existing-agent") - session_manager.initialize(agent) + existing_session_manager.initialize(agent) fixed_messages = agent.messages @@ -330,7 +343,7 @@ def test_fix_broken_tool_use_extends_partial_tool_results(session_manager): assert missing_result["toolResult"]["content"][0]["text"] == "Tool was interrupted." -def test_fix_broken_tool_use_handles_multiple_orphaned_tools(session_manager): +def test_fix_broken_tool_use_handles_multiple_orphaned_tools(existing_session_manager): """Test fixing multiple orphaned toolUse messages.""" conversation_manager = SlidingWindowConversationManager() @@ -340,7 +353,7 @@ def test_fix_broken_tool_use_handles_multiple_orphaned_tools(session_manager): state={"key": "value"}, conversation_manager_state=conversation_manager.get_state(), ) - session_manager.session_repository.create_agent("test-session", session_agent) + existing_session_manager.session_repository.create_agent("test-session", session_agent) broken_messages = [ { @@ -358,11 +371,13 @@ def test_fix_broken_tool_use_handles_multiple_orphaned_tools(session_manager): message=broken_message, message_id=index, ) - session_manager.session_repository.create_message("test-session", "existing-agent", broken_session_message) + existing_session_manager.session_repository.create_message( + "test-session", "existing-agent", broken_session_message + ) # Initialize agent agent = Agent(agent_id="existing-agent") - session_manager.initialize(agent) + existing_session_manager.initialize(agent) fixed_messages = agent.messages @@ -449,7 +464,7 @@ def test_initialize_bidi_agent_creates_new(session_manager, mock_bidi_agent): assert messages[0].message["role"] == "user" -def test_initialize_bidi_agent_restores_existing(session_manager, mock_bidi_agent): +def test_initialize_bidi_agent_restores_existing(existing_session_manager, mock_bidi_agent): """Test initializing BidiAgent restores from existing session.""" # Create existing session data session_agent = SessionAgent( @@ -457,16 +472,16 @@ def test_initialize_bidi_agent_restores_existing(session_manager, mock_bidi_agen state={"restored": "state"}, conversation_manager_state={}, # Empty for BidiAgent ) - session_manager.session_repository.create_agent("test-session", session_agent) + existing_session_manager.session_repository.create_agent("test-session", session_agent) # Add messages msg1 = SessionMessage.from_message({"role": "user", "content": [{"text": "Message 1"}]}, 0) msg2 = SessionMessage.from_message({"role": "assistant", "content": [{"text": "Response 1"}]}, 1) - session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg1) - session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg2) + existing_session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg1) + existing_session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg2) # Initialize agent - session_manager.initialize_bidi_agent(mock_bidi_agent) + existing_session_manager.initialize_bidi_agent(mock_bidi_agent) # Verify state restored assert mock_bidi_agent.state.get() == {"restored": "state"} @@ -532,7 +547,7 @@ def test_bidi_agent_unique_id_constraint(session_manager, mock_bidi_agent): session_manager.initialize_bidi_agent(agent2) -def test_bidi_agent_messages_with_offset_zero(session_manager, mock_bidi_agent): +def test_bidi_agent_messages_with_offset_zero(existing_session_manager, mock_bidi_agent): """Test that BidiAgent uses offset=0 for message restoration (no conversation_manager).""" # Create session with messages session_agent = SessionAgent( @@ -540,15 +555,15 @@ def test_bidi_agent_messages_with_offset_zero(session_manager, mock_bidi_agent): state={}, conversation_manager_state={}, ) - session_manager.session_repository.create_agent("test-session", session_agent) + existing_session_manager.session_repository.create_agent("test-session", session_agent) # Add 5 messages for i in range(5): msg = SessionMessage.from_message({"role": "user", "content": [{"text": f"Message {i}"}]}, i) - session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg) + existing_session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg) # Initialize agent - session_manager.initialize_bidi_agent(mock_bidi_agent) + existing_session_manager.initialize_bidi_agent(mock_bidi_agent) # Verify all messages restored (offset=0, no removed_message_count) assert len(mock_bidi_agent.messages) == 5 @@ -811,3 +826,208 @@ def tracking_update_agent(session_id, session_agent): # First sync should always update (no previous state) session_manager.sync_agent(agent) assert len(update_agent_calls) == 1 + + +# ============================================================================ +# New Session Optimization Tests (Issue #1828) +# ============================================================================ + + +def test_is_new_session_true_when_session_created(mock_repository): + """Test that _is_new_session is True when creating a new session.""" + # Session doesn't exist yet + assert mock_repository.read_session("new-session") is None + + # Creating manager should set _is_new_session to True + manager = RepositorySessionManager(session_id="new-session", session_repository=mock_repository) + + assert manager._is_new_session is True + + +def test_is_new_session_false_when_session_exists(mock_repository): + """Test that _is_new_session is False when using an existing session.""" + # Create session first + session = Session(session_id="existing-session", session_type=SessionType.AGENT) + mock_repository.create_session(session) + + # Creating manager should set _is_new_session to False + manager = RepositorySessionManager(session_id="existing-session", session_repository=mock_repository) + + assert manager._is_new_session is False + + +def test_initialize_skips_read_agent_for_new_session(mock_repository): + """Test that initialize() skips read_agent() call when _is_new_session is True.""" + # Create manager (new session) + manager = RepositorySessionManager(session_id="new-session", session_repository=mock_repository) + assert manager._is_new_session is True + + # Track read_agent calls + read_agent_calls = [] + original_read_agent = mock_repository.read_agent + + def tracking_read_agent(session_id, agent_id): + read_agent_calls.append((session_id, agent_id)) + return original_read_agent(session_id, agent_id) + + mock_repository.read_agent = tracking_read_agent + + # Initialize agent + agent = Agent(agent_id="test-agent") + manager.initialize(agent) + + # read_agent should NOT be called for new session + assert len(read_agent_calls) == 0 + + +def test_initialize_calls_read_agent_for_existing_session(mock_repository): + """Test that initialize() calls read_agent() when _is_new_session is False.""" + # Create session first + session = Session(session_id="existing-session", session_type=SessionType.AGENT) + mock_repository.create_session(session) + + # Create manager (existing session) + manager = RepositorySessionManager(session_id="existing-session", session_repository=mock_repository) + assert manager._is_new_session is False + + # Track read_agent calls + read_agent_calls = [] + original_read_agent = mock_repository.read_agent + + def tracking_read_agent(session_id, agent_id): + read_agent_calls.append((session_id, agent_id)) + return original_read_agent(session_id, agent_id) + + mock_repository.read_agent = tracking_read_agent + + # Initialize agent + agent = Agent(agent_id="test-agent") + manager.initialize(agent) + + # read_agent should be called for existing session + assert len(read_agent_calls) == 1 + assert read_agent_calls[0] == ("existing-session", "test-agent") + + +def test_initialize_bidi_agent_skips_read_agent_for_new_session(mock_repository): + """Test that initialize_bidi_agent() skips read_agent() call when _is_new_session is True.""" + # Create manager (new session) + manager = RepositorySessionManager(session_id="new-session", session_repository=mock_repository) + assert manager._is_new_session is True + + # Track read_agent calls + read_agent_calls = [] + original_read_agent = mock_repository.read_agent + + def tracking_read_agent(session_id, agent_id): + read_agent_calls.append((session_id, agent_id)) + return original_read_agent(session_id, agent_id) + + mock_repository.read_agent = tracking_read_agent + + # Create mock BidiAgent + bidi_agent = Mock() + bidi_agent.agent_id = "bidi-agent-1" + bidi_agent.messages = [{"role": "user", "content": [{"text": "Hello!"}]}] + bidi_agent.state = AgentState({}) + + # Initialize bidi agent + manager.initialize_bidi_agent(bidi_agent) + + # read_agent should NOT be called for new session + assert len(read_agent_calls) == 0 + + +def test_initialize_bidi_agent_calls_read_agent_for_existing_session(mock_repository): + """Test that initialize_bidi_agent() calls read_agent() when _is_new_session is False.""" + # Create session first + session = Session(session_id="existing-session", session_type=SessionType.AGENT) + mock_repository.create_session(session) + + # Create manager (existing session) + manager = RepositorySessionManager(session_id="existing-session", session_repository=mock_repository) + assert manager._is_new_session is False + + # Track read_agent calls + read_agent_calls = [] + original_read_agent = mock_repository.read_agent + + def tracking_read_agent(session_id, agent_id): + read_agent_calls.append((session_id, agent_id)) + return original_read_agent(session_id, agent_id) + + mock_repository.read_agent = tracking_read_agent + + # Create mock BidiAgent + bidi_agent = Mock() + bidi_agent.agent_id = "bidi-agent-1" + bidi_agent.messages = [{"role": "user", "content": [{"text": "Hello!"}]}] + bidi_agent.state = AgentState({}) + + # Initialize bidi agent + manager.initialize_bidi_agent(bidi_agent) + + # read_agent should be called for existing session + assert len(read_agent_calls) == 1 + assert read_agent_calls[0] == ("existing-session", "bidi-agent-1") + + +def test_initialize_multi_agent_skips_read_for_new_session(mock_repository): + """Test that initialize_multi_agent() skips read_multi_agent() call when _is_new_session is True.""" + # Create manager (new session) + manager = RepositorySessionManager(session_id="new-session", session_repository=mock_repository) + assert manager._is_new_session is True + + # Track read_multi_agent calls + read_multi_agent_calls = [] + original_read_multi_agent = mock_repository.read_multi_agent + + def tracking_read_multi_agent(session_id, multi_agent_id, **kwargs): + read_multi_agent_calls.append((session_id, multi_agent_id)) + return original_read_multi_agent(session_id, multi_agent_id, **kwargs) + + mock_repository.read_multi_agent = tracking_read_multi_agent + + # Create mock multi-agent + multi_agent = Mock() + multi_agent.id = "test-multi-agent" + multi_agent.serialize_state.return_value = {"id": "test-multi-agent", "state": {}} + + # Initialize multi-agent + manager.initialize_multi_agent(multi_agent) + + # read_multi_agent should NOT be called for new session + assert len(read_multi_agent_calls) == 0 + + +def test_initialize_multi_agent_calls_read_for_existing_session(mock_repository): + """Test that initialize_multi_agent() calls read_multi_agent() when _is_new_session is False.""" + # Create session first + session = Session(session_id="existing-session", session_type=SessionType.AGENT) + mock_repository.create_session(session) + + # Create manager (existing session) + manager = RepositorySessionManager(session_id="existing-session", session_repository=mock_repository) + assert manager._is_new_session is False + + # Track read_multi_agent calls + read_multi_agent_calls = [] + original_read_multi_agent = mock_repository.read_multi_agent + + def tracking_read_multi_agent(session_id, multi_agent_id, **kwargs): + read_multi_agent_calls.append((session_id, multi_agent_id)) + return original_read_multi_agent(session_id, multi_agent_id, **kwargs) + + mock_repository.read_multi_agent = tracking_read_multi_agent + + # Create mock multi-agent + multi_agent = Mock() + multi_agent.id = "test-multi-agent" + multi_agent.serialize_state.return_value = {"id": "test-multi-agent", "state": {}} + + # Initialize multi-agent + manager.initialize_multi_agent(multi_agent) + + # read_multi_agent should be called for existing session + assert len(read_multi_agent_calls) == 1 + assert read_multi_agent_calls[0] == ("existing-session", "test-multi-agent")