From 57b0304bdc9b3889f0c5666ee93d16151d8839c8 Mon Sep 17 00:00:00 2001 From: Filinto Duran <1373693+filintod@users.noreply.github.com> Date: Mon, 20 Oct 2025 23:16:04 -0500 Subject: [PATCH 1/4] add dapr session option Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com> --- .gitignore | 3 +- pyproject.toml | 7 +- src/strands/session/__init__.py | 40 ++ src/strands/session/dapr_session_manager.py | 538 +++++++++++++++++ .../session/test_dapr_session_manager.py | 542 ++++++++++++++++++ tests_integ/test_dapr_session.py | 428 ++++++++++++++ 6 files changed, 1556 insertions(+), 2 deletions(-) create mode 100644 src/strands/session/dapr_session_manager.py create mode 100644 tests/strands/session/test_dapr_session_manager.py create mode 100644 tests_integ/test_dapr_session.py diff --git a/.gitignore b/.gitignore index 888a96bbc..5c9f496ae 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ __pycache__* .vscode dist repl_state -.kiro \ No newline at end of file +.kiro +.idea \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index af8e45ffc..060e2f5da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ sagemaker = [ "openai>=1.68.0,<2.0.0", # SageMaker uses OpenAI-compatible interface ] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] +dapr = ["dapr>=1.14.0", "grpcio>=1.60.0"] docs = [ "sphinx>=5.0.0,<9.0.0", "sphinx-rtd-theme>=1.0.0,<2.0.0", @@ -68,7 +69,7 @@ a2a = [ "fastapi>=0.115.12,<1.0.0", "starlette>=0.46.2,<1.0.0", ] -all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] +all = ["strands-agents[a2a,anthropic,dapr,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] dev = [ "commitizen>=4.4.0,<5.0.0", @@ -81,6 +82,7 @@ dev = [ "pytest-asyncio>=1.0.0,<1.3.0", "pytest-xdist>=3.0.0,<4.0.0", "ruff>=0.13.0,<0.14.0", + "testcontainers[redis]>=4.0.0", ] [project.urls] @@ -134,6 +136,9 @@ dependencies = [ "pytest-asyncio>=1.0.0,<1.3.0", "pytest-xdist>=3.0.0,<4.0.0", "moto>=5.1.0,<6.0.0", + "dapr>=1.14.0", + "grpcio>=1.60.0", + "testcontainers[redis]>=4.0.0", ] [[tool.hatch.envs.hatch-test.matrix]] diff --git a/src/strands/session/__init__.py b/src/strands/session/__init__.py index 7b5310190..a8b58a717 100644 --- a/src/strands/session/__init__.py +++ b/src/strands/session/__init__.py @@ -3,6 +3,8 @@ This module provides session management functionality. """ +from typing import Any + from .file_session_manager import FileSessionManager from .repository_session_manager import RepositorySessionManager from .s3_session_manager import S3SessionManager @@ -10,9 +12,47 @@ from .session_repository import SessionRepository __all__ = [ + "DAPR_CONSISTENCY_EVENTUAL", + "DAPR_CONSISTENCY_STRONG", + "DaprSessionManager", "FileSessionManager", "RepositorySessionManager", "S3SessionManager", "SessionManager", "SessionRepository", ] + + +def __getattr__(name: str) -> Any: + """Lazy import for optional dependencies.""" + if name == "DaprSessionManager": + try: + from .dapr_session_manager import DaprSessionManager + + return DaprSessionManager + except ModuleNotFoundError as e: + raise ImportError( + "DaprSessionManager requires the 'dapr' extra. " "Install it with: pip install strands-agents[dapr]" + ) from e + + if name == "DAPR_CONSISTENCY_EVENTUAL": + try: + from .dapr_session_manager import DAPR_CONSISTENCY_EVENTUAL + + return DAPR_CONSISTENCY_EVENTUAL + except ModuleNotFoundError as e: + raise ImportError( + "DAPR_CONSISTENCY_EVENTUAL requires the 'dapr' extra. " "Install it with: pip install strands-agents[dapr]" + ) from e + + if name == "DAPR_CONSISTENCY_STRONG": + try: + from .dapr_session_manager import DAPR_CONSISTENCY_STRONG + + return DAPR_CONSISTENCY_STRONG + except ModuleNotFoundError as e: + raise ImportError( + "DAPR_CONSISTENCY_STRONG requires the 'dapr' extra. " "Install it with: pip install strands-agents[dapr]" + ) from e + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/strands/session/dapr_session_manager.py b/src/strands/session/dapr_session_manager.py new file mode 100644 index 000000000..18eef1e79 --- /dev/null +++ b/src/strands/session/dapr_session_manager.py @@ -0,0 +1,538 @@ +"""Dapr state store session manager for distributed storage.""" + +import json +import logging +from typing import Any, Dict, List, Literal, Optional, cast + +from dapr.clients import DaprClient +from dapr.clients.grpc._state import Consistency, StateOptions + +from .. import _identifier +from ..types.exceptions import SessionException +from ..types.session import Session, SessionAgent, SessionMessage +from .repository_session_manager import RepositorySessionManager +from .session_repository import SessionRepository + +logger = logging.getLogger(__name__) + +# Type-safe consistency constants +ConsistencyLevel = Literal["eventual", "strong"] +DAPR_CONSISTENCY_EVENTUAL: ConsistencyLevel = "eventual" +DAPR_CONSISTENCY_STRONG: ConsistencyLevel = "strong" + + +class DaprSessionManager(RepositorySessionManager, SessionRepository): + """Dapr state store session manager for distributed storage. + + Stores session data in Dapr state stores (Redis, PostgreSQL, MongoDB, Cosmos DB, etc.) + with support for TTL and consistency levels. + + Key structure: + - `{session_id}:session` - Session metadata + - `{session_id}:agents:{agent_id}` - Agent metadata + - `{session_id}:messages:{agent_id}` - Message list (JSON array) + """ + + def __init__( + self, + session_id: str, + state_store_name: str, + dapr_client: DaprClient, + ttl: Optional[int] = None, + consistency: ConsistencyLevel = DAPR_CONSISTENCY_EVENTUAL, + **kwargs: Any, + ): + """Initialize DaprSessionManager. + + Args: + session_id: ID for the session. + ID is not allowed to contain path separators (e.g., a/b). + state_store_name: Name of the Dapr state store component. + dapr_client: DaprClient instance for state operations. + ttl: Optional time-to-live in seconds for state items. + consistency: Consistency level for state operations ("eventual" or "strong"). + **kwargs: Additional keyword arguments for future extensibility. + """ + self._state_store_name = state_store_name + self._dapr_client = dapr_client + self._ttl = ttl + self._consistency = consistency + self._owns_client = False + + super().__init__(session_id=session_id, session_repository=self) + + @classmethod + def from_address( + cls, + session_id: str, + state_store_name: str, + dapr_address: str = "localhost:50001", + **kwargs: Any, + ) -> "DaprSessionManager": + """Create DaprSessionManager from Dapr address. + + Args: + session_id: ID for the session. + state_store_name: Name of the Dapr state store component. + dapr_address: Dapr gRPC endpoint (default: localhost:50001). + **kwargs: Additional arguments passed to __init__ (ttl, consistency, etc.). + + Returns: + DaprSessionManager instance with owned client. + """ + dapr_client = DaprClient(address=dapr_address) + manager = cls(session_id, state_store_name=state_store_name, dapr_client=dapr_client, **kwargs) + manager._owns_client = True + return manager + + def _get_session_key(self, session_id: str) -> str: + """Get session state key. + + Args: + session_id: ID for the session. + + Returns: + State store key for the session. + + Raises: + ValueError: If session id contains a path separator. + """ + session_id = _identifier.validate(session_id, _identifier.Identifier.SESSION) + return f"{session_id}:session" + + def _get_agent_key(self, session_id: str, agent_id: str) -> str: + """Get agent state key. + + Args: + session_id: ID for the session. + agent_id: ID for the agent. + + Returns: + State store key for the agent. + + Raises: + ValueError: If session id or agent id contains a path separator. + """ + session_id = _identifier.validate(session_id, _identifier.Identifier.SESSION) + agent_id = _identifier.validate(agent_id, _identifier.Identifier.AGENT) + return f"{session_id}:agents:{agent_id}" + + def _get_messages_key(self, session_id: str, agent_id: str) -> str: + """Get messages list state key. + + Args: + session_id: ID for the session. + agent_id: ID for the agent. + + Returns: + State store key for the messages list. + + Raises: + ValueError: If session id or agent id contains a path separator. + """ + session_id = _identifier.validate(session_id, _identifier.Identifier.SESSION) + agent_id = _identifier.validate(agent_id, _identifier.Identifier.AGENT) + return f"{session_id}:messages:{agent_id}" + + def _get_read_metadata(self) -> Dict[str, str]: + """Get metadata for read operations (consistency). + + Returns: + Metadata dictionary for state reads. + """ + metadata: Dict[str, str] = {} + if self._consistency: + metadata["consistency"] = self._consistency + return metadata + + def _get_write_metadata(self) -> Dict[str, str]: + """Get metadata for write operations (TTL). + + Returns: + Metadata dictionary for state writes. + """ + metadata: Dict[str, str] = {} + if self._ttl is not None: + metadata["ttlInSeconds"] = str(self._ttl) + return metadata + + def _get_state_options(self) -> Optional[StateOptions]: + """Get state options for write/delete operations (consistency). + + Returns: + StateOptions for consistency or None. + """ + if self._consistency == DAPR_CONSISTENCY_STRONG: + return StateOptions(consistency=Consistency.strong) + elif self._consistency == DAPR_CONSISTENCY_EVENTUAL: + return StateOptions(consistency=Consistency.eventual) + return None + + def _read_state(self, key: str) -> Optional[Dict[str, Any]]: + """Read and parse JSON state from Dapr. + + Args: + key: State store key. + + Returns: + Parsed JSON dictionary or None if not found. + + Raises: + SessionException: If state is corrupted or read fails. + """ + try: + response = self._dapr_client.get_state( + store_name=self._state_store_name, + key=key, + state_metadata=self._get_read_metadata(), + ) + + if not response.data: + return None + + content = response.data.decode("utf-8") + return cast(Dict[str, Any], json.loads(content)) + + except json.JSONDecodeError as e: + raise SessionException(f"Invalid JSON in state key {key}: {e}") from e + except Exception as e: + raise SessionException(f"Failed to read state key {key}: {e}") from e + + def _write_state(self, key: str, data: Dict[str, Any]) -> None: + """Write JSON state to Dapr. + + Args: + key: State store key. + data: Dictionary to serialize and store. + + Raises: + SessionException: If write fails. + """ + try: + content = json.dumps(data, ensure_ascii=False) + self._dapr_client.save_state( + store_name=self._state_store_name, + key=key, + value=content, + state_metadata=self._get_write_metadata(), + options=self._get_state_options(), + ) + except Exception as e: + raise SessionException(f"Failed to write state key {key}: {e}") from e + + def _delete_state(self, key: str) -> None: + """Delete state from Dapr. + + Args: + key: State store key. + + Raises: + SessionException: If delete fails. + """ + try: + self._dapr_client.delete_state( + store_name=self._state_store_name, + key=key, + state_metadata=self._get_read_metadata(), + options=self._get_state_options(), + ) + except Exception as e: + raise SessionException(f"Failed to delete state key {key}: {e}") from e + + def _get_manifest_key(self, session_id: str) -> str: + """Get session manifest key (tracks agent_ids for deletion).""" + session_id = _identifier.validate(session_id, _identifier.Identifier.SESSION) + return f"{session_id}:manifest" + + def create_session(self, session: Session, **kwargs: Any) -> Session: + """Create a new session. + + Args: + session: Session to create. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + Created session. + + Raises: + SessionException: If session already exists or creation fails. + """ + session_key = self._get_session_key(session.session_id) + + # Check if session already exists + existing = self.read_session(session.session_id) + if existing is not None: + raise SessionException(f"Session {session.session_id} already exists") + + # Write session data + session_dict = session.to_dict() + self._write_state(session_key, session_dict) + return session + + def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: + """Read session data. + + Args: + session_id: ID of the session to read. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + Session if found, None otherwise. + + Raises: + SessionException: If read fails. + """ + session_key = self._get_session_key(session_id) + + session_data = self._read_state(session_key) + if session_data is None: + return None + + return Session.from_dict(session_data) + + def delete_session(self, session_id: str, **kwargs: Any) -> None: + """Delete session and all associated data. + + Uses a session manifest to discover agent IDs for cleanup. + """ + session_key = self._get_session_key(session_id) + manifest_key = self._get_manifest_key(session_id) + + # Read manifest (may be missing if no agents created) + manifest = self._read_state(manifest_key) + agent_ids: list[str] = manifest.get("agents", []) if manifest else [] + + # Delete agent and message keys + for agent_id in agent_ids: + agent_key = self._get_agent_key(session_id, agent_id) + messages_key = self._get_messages_key(session_id, agent_id) + self._delete_state(agent_key) + self._delete_state(messages_key) + + # Delete manifest and session + self._delete_state(manifest_key) + self._delete_state(session_key) + + def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: Any) -> None: + """Create a new agent in the session. + + Args: + session_id: ID of the session. + session_agent: Agent to create. + **kwargs: Additional keyword arguments for future extensibility. + + Raises: + SessionException: If creation fails. + """ + agent_key = self._get_agent_key(session_id, session_agent.agent_id) + agent_dict = session_agent.to_dict() + + self._write_state(agent_key, agent_dict) + + # Initialize empty messages list + messages_key = self._get_messages_key(session_id, session_agent.agent_id) + self._write_state(messages_key, {"messages": []}) + + # Update manifest with this agent + manifest_key = self._get_manifest_key(session_id) + manifest = self._read_state(manifest_key) or {"agents": []} + if session_agent.agent_id not in manifest["agents"]: + manifest["agents"].append(session_agent.agent_id) + self._write_state(manifest_key, manifest) + + def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]: + """Read agent data. + + Args: + session_id: ID of the session. + agent_id: ID of the agent. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + SessionAgent if found, None otherwise. + + Raises: + SessionException: If read fails. + """ + agent_key = self._get_agent_key(session_id, agent_id) + + agent_data = self._read_state(agent_key) + if agent_data is None: + return None + + return SessionAgent.from_dict(agent_data) + + def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: Any) -> None: + """Update agent data. + + Args: + session_id: ID of the session. + session_agent: Agent to update. + **kwargs: Additional keyword arguments for future extensibility. + + Raises: + SessionException: If agent doesn't exist or update fails. + """ + previous_agent = self.read_agent(session_id=session_id, agent_id=session_agent.agent_id) + if previous_agent is None: + raise SessionException(f"Agent {session_agent.agent_id} in session {session_id} does not exist") + + # Preserve creation timestamp + session_agent.created_at = previous_agent.created_at + + agent_key = self._get_agent_key(session_id, session_agent.agent_id) + + self._write_state(agent_key, session_agent.to_dict()) + + def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any) -> None: + """Create a new message for the agent. + + Args: + session_id: ID of the session. + agent_id: ID of the agent. + session_message: Message to create. + **kwargs: Additional keyword arguments for future extensibility. + + Raises: + SessionException: If creation fails. + """ + messages_key = self._get_messages_key(session_id, agent_id) + + # Read existing messages + messages_data = self._read_state(messages_key) + if messages_data is None: + messages_list = [] + else: + messages_list = messages_data.get("messages", []) + if not isinstance(messages_list, list): + messages_list = [] + + # Append new message + messages_list.append(session_message.to_dict()) + + # Write back + self._write_state(messages_key, {"messages": messages_list}) + + def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]: + """Read message data. + + Args: + session_id: ID of the session. + agent_id: ID of the agent. + message_id: Index of the message. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + SessionMessage if found, None otherwise. + + Raises: + ValueError: If message_id is not an integer. + SessionException: If read fails. + """ + if not isinstance(message_id, int): + raise ValueError(f"message_id=<{message_id}> | message id must be an integer") + + messages_key = self._get_messages_key(session_id, agent_id) + + messages_data = self._read_state(messages_key) + if messages_data is None: + return None + + messages_list = messages_data.get("messages", []) + if not isinstance(messages_list, list): + messages_list = [] + + # Find message by ID + for msg_dict in messages_list: + if msg_dict.get("message_id") == message_id: + return SessionMessage.from_dict(msg_dict) + + return None + + def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any) -> None: + """Update message data. + + Args: + session_id: ID of the session. + agent_id: ID of the agent. + session_message: Message to update. + **kwargs: Additional keyword arguments for future extensibility. + + Raises: + SessionException: If message doesn't exist or update fails. + """ + previous_message = self.read_message( + session_id=session_id, agent_id=agent_id, message_id=session_message.message_id + ) + if previous_message is None: + raise SessionException(f"Message {session_message.message_id} does not exist") + + # Preserve creation timestamp + session_message.created_at = previous_message.created_at + + messages_key = self._get_messages_key(session_id, agent_id) + + # Read existing messages + messages_data = self._read_state(messages_key) + if messages_data is None: + raise SessionException(f"Messages not found for agent {agent_id} in session {session_id}") + + messages_list = messages_data.get("messages", []) + if not isinstance(messages_list, list): + messages_list = [] + + # Find and update message + updated = False + for i, msg_dict in enumerate(messages_list): + if msg_dict.get("message_id") == session_message.message_id: + messages_list[i] = session_message.to_dict() + updated = True + break + + if not updated: + raise SessionException(f"Message {session_message.message_id} not found in list") + + # Write back + self._write_state(messages_key, {"messages": messages_list}) + + def list_messages( + self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any + ) -> List[SessionMessage]: + """List messages for an agent with pagination. + + Args: + session_id: ID of the session. + agent_id: ID of the agent. + limit: Maximum number of messages to return. + offset: Number of messages to skip. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + List of SessionMessage objects. + + Raises: + SessionException: If read fails. + """ + messages_key = self._get_messages_key(session_id, agent_id) + + messages_data = self._read_state(messages_key) + if messages_data is None: + return [] + + messages_list = messages_data.get("messages", []) + if not isinstance(messages_list, list): + messages_list = [] + + # Apply pagination + if limit is not None: + messages_list = messages_list[offset : offset + limit] + else: + messages_list = messages_list[offset:] + + # Convert to SessionMessage objects + return [SessionMessage.from_dict(msg_dict) for msg_dict in messages_list] + + def close(self) -> None: + """Close the Dapr client if owned by this manager.""" + if self._owns_client: + self._dapr_client.close() diff --git a/tests/strands/session/test_dapr_session_manager.py b/tests/strands/session/test_dapr_session_manager.py new file mode 100644 index 000000000..2a216ed20 --- /dev/null +++ b/tests/strands/session/test_dapr_session_manager.py @@ -0,0 +1,542 @@ +"""Tests for DaprSessionManager.""" + +from typing import Any, Optional +from unittest.mock import Mock + +import pytest +from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager +from strands.session.dapr_session_manager import ( + DAPR_CONSISTENCY_EVENTUAL, + DAPR_CONSISTENCY_STRONG, + DaprSessionManager, +) +from strands.types.content import ContentBlock +from strands.types.exceptions import SessionException +from strands.types.session import Session, SessionAgent, SessionMessage, SessionType + + +class FakeDaprClient: + """Sync fake Dapr client for testing.""" + + def __init__(self) -> None: + """Initialize fake client with in-memory state.""" + self._state: dict[str, bytes] = {} + self._closed = False + + def get_state( + self, + store_name: str, + key: str, + state_metadata: Optional[dict[str, str]] = None, + **kwargs: Any, + ) -> Mock: + """Get state from in-memory store.""" + response = Mock() + response.data = self._state.get(key) + return response + + def save_state( + self, + store_name: str, + key: str, + value: str | bytes, + state_metadata: Optional[dict[str, str]] = None, + options: Any = None, + ) -> None: + """Save state to in-memory store.""" + if isinstance(value, str): + self._state[key] = value.encode("utf-8") + else: + self._state[key] = value + + def delete_state( + self, + store_name: str, + key: str, + state_metadata: Optional[dict[str, str]] = None, + options: Any = None, + ) -> None: + """Delete state from in-memory store.""" + self._state.pop(key, None) + + def close(self) -> None: + """Close the client.""" + self._closed = True + + +@pytest.fixture +def fake_dapr_client() -> FakeDaprClient: + """Create fake Dapr client for testing.""" + return FakeDaprClient() + + +@pytest.fixture +def dapr_manager(fake_dapr_client: FakeDaprClient) -> DaprSessionManager: + """Create DaprSessionManager for testing.""" + return DaprSessionManager( + session_id="test", state_store_name="statestore", dapr_client=fake_dapr_client, consistency=DAPR_CONSISTENCY_EVENTUAL + ) + + +@pytest.fixture +def sample_session() -> Session: + """Create sample session for testing.""" + return Session(session_id="test-session", session_type=SessionType.AGENT) + + +@pytest.fixture +def sample_agent() -> SessionAgent: + """Create sample agent for testing.""" + return SessionAgent( + agent_id="test-agent", state={"key": "value"}, conversation_manager_state=NullConversationManager().get_state() + ) + + +@pytest.fixture +def sample_message() -> SessionMessage: + """Create sample message for testing.""" + return SessionMessage.from_message( + message={ + "role": "user", + "content": [ContentBlock(text="Hello world")], + }, + index=0, + ) + + +def test_consistency_constants(): + """Test consistency constants are properly defined.""" + assert DAPR_CONSISTENCY_EVENTUAL == "eventual" + assert DAPR_CONSISTENCY_STRONG == "strong" + + +def test_messages_shape_non_list_handling(dapr_manager: DaprSessionManager, sample_session: Session, sample_agent: SessionAgent): + """Seed a non-list messages payload and verify graceful handling and overwrite by create_message.""" + dapr_manager.create_session(sample_session) + dapr_manager.create_agent(sample_session.session_id, sample_agent) + + # Manually write an invalid messages payload (object instead of list) + messages_key = dapr_manager._get_messages_key(sample_session.session_id, sample_agent.agent_id) + # Directly mutate internal client state for test + assert hasattr(dapr_manager._dapr_client, "_state") + dapr_manager._dapr_client._state[messages_key] = b'{"messages": {}}' + + # list_messages should return empty + assert dapr_manager.list_messages(sample_session.session_id, sample_agent.agent_id) == [] + + # Now create a message; this should overwrite messages with a proper list + new_msg = SessionMessage.from_message({"role": "user", "content": [ContentBlock(text="こんにちは 😃")]}, 0) + dapr_manager.create_message(sample_session.session_id, sample_agent.agent_id, new_msg) + + listed = dapr_manager.list_messages(sample_session.session_id, sample_agent.agent_id) + assert len(listed) == 1 + assert "こんにちは" in str(listed[0].message["content"]) # unicode round-trip + + +def test_init_with_consistency_levels(fake_dapr_client: FakeDaprClient): + """Test initialization with different consistency levels.""" + # Test eventual consistency + manager_eventual = DaprSessionManager( + session_id="test", + state_store_name="statestore", + dapr_client=fake_dapr_client, + consistency=DAPR_CONSISTENCY_EVENTUAL, + ) + assert manager_eventual._consistency == DAPR_CONSISTENCY_EVENTUAL + + # Test strong consistency + manager_strong = DaprSessionManager( + session_id="test", + state_store_name="statestore", + dapr_client=fake_dapr_client, + consistency=DAPR_CONSISTENCY_STRONG, + ) + assert manager_strong._consistency == DAPR_CONSISTENCY_STRONG + + +def test_init_with_ttl(fake_dapr_client: FakeDaprClient): + """Test initialization with TTL.""" + manager = DaprSessionManager( + session_id="test", state_store_name="statestore", dapr_client=fake_dapr_client, ttl=3600 + ) + assert manager._ttl == 3600 + + +def test_create_session(dapr_manager: DaprSessionManager, sample_session: Session): + """Test creating a session.""" + dapr_manager.create_session(sample_session) + + # Verify session was stored + result = dapr_manager.read_session(sample_session.session_id) + assert result is not None + assert result.session_id == sample_session.session_id + assert result.session_type == sample_session.session_type + + +def test_create_session_already_exists(dapr_manager: DaprSessionManager, sample_session: Session): + """Test creating a session that already exists.""" + dapr_manager.create_session(sample_session) + + # Try to create again + with pytest.raises(SessionException, match="already exists"): + dapr_manager.create_session(sample_session) + + +def test_read_session(dapr_manager: DaprSessionManager, sample_session: Session): + """Test reading a session.""" + dapr_manager.create_session(sample_session) + result = dapr_manager.read_session(sample_session.session_id) + + assert result is not None + assert result.session_id == sample_session.session_id + assert result.session_type == sample_session.session_type + + +def test_read_nonexistent_session(dapr_manager: DaprSessionManager): + """Test reading a session that doesn't exist.""" + result = dapr_manager.read_session("nonexistent-session") + assert result is None + + +def test_create_agent(dapr_manager: DaprSessionManager, sample_session: Session, sample_agent: SessionAgent): + """Test creating an agent in a session.""" + dapr_manager.create_session(sample_session) + dapr_manager.create_agent(sample_session.session_id, sample_agent) + + # Verify agent was stored + result = dapr_manager.read_agent(sample_session.session_id, sample_agent.agent_id) + assert result is not None + assert result.agent_id == sample_agent.agent_id + assert result.state == sample_agent.state + + +def test_read_agent(dapr_manager: DaprSessionManager, sample_session: Session, sample_agent: SessionAgent): + """Test reading an agent from a session.""" + dapr_manager.create_session(sample_session) + dapr_manager.create_agent(sample_session.session_id, sample_agent) + + result = dapr_manager.read_agent(sample_session.session_id, sample_agent.agent_id) + + assert result is not None + assert result.agent_id == sample_agent.agent_id + assert result.state == sample_agent.state + + +def test_read_nonexistent_agent(dapr_manager: DaprSessionManager, sample_session: Session): + """Test reading an agent that doesn't exist.""" + result = dapr_manager.read_agent(sample_session.session_id, "nonexistent_agent") + assert result is None + + +def test_update_agent(dapr_manager: DaprSessionManager, sample_session: Session, sample_agent: SessionAgent): + """Test updating an agent.""" + dapr_manager.create_session(sample_session) + dapr_manager.create_agent(sample_session.session_id, sample_agent) + + # Update agent + sample_agent.state = {"updated": "value"} + dapr_manager.update_agent(sample_session.session_id, sample_agent) + + # Verify update + result = dapr_manager.read_agent(sample_session.session_id, sample_agent.agent_id) + assert result is not None + assert result.state == {"updated": "value"} + + +def test_update_nonexistent_agent( + dapr_manager: DaprSessionManager, sample_session: Session, sample_agent: SessionAgent +): + """Test updating an agent that doesn't exist.""" + dapr_manager.create_session(sample_session) + + # Try to update non-existent agent + with pytest.raises(SessionException, match="does not exist"): + dapr_manager.update_agent(sample_session.session_id, sample_agent) + + +def test_create_message( + dapr_manager: DaprSessionManager, + sample_session: Session, + sample_agent: SessionAgent, + sample_message: SessionMessage, +): + """Test creating a message for an agent.""" + dapr_manager.create_session(sample_session) + dapr_manager.create_agent(sample_session.session_id, sample_agent) + + dapr_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Verify message was stored + result = dapr_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) + assert result is not None + assert result.message_id == sample_message.message_id + + +def test_read_message( + dapr_manager: DaprSessionManager, + sample_session: Session, + sample_agent: SessionAgent, + sample_message: SessionMessage, +): + """Test reading a message.""" + dapr_manager.create_session(sample_session) + dapr_manager.create_agent(sample_session.session_id, sample_agent) + dapr_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Create additional message + sample_message_2 = SessionMessage.from_message( + message={ + "role": "assistant", + "content": [ContentBlock(text="Hi there")], + }, + index=1, + ) + dapr_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message_2) + + # Read specific message + result = dapr_manager.read_message(sample_session.session_id, sample_agent.agent_id, 1) + + assert result is not None + assert result.message_id == 1 + assert result.message["role"] == "assistant" + + +def test_read_nonexistent_message( + dapr_manager: DaprSessionManager, sample_session: Session, sample_agent: SessionAgent +): + """Test reading a message that doesn't exist.""" + dapr_manager.create_session(sample_session) + dapr_manager.create_agent(sample_session.session_id, sample_agent) + + result = dapr_manager.read_message(sample_session.session_id, sample_agent.agent_id, 999) + assert result is None + + +def test_list_messages_all(dapr_manager: DaprSessionManager, sample_session: Session, sample_agent: SessionAgent): + """Test listing all messages for an agent.""" + dapr_manager.create_session(sample_session) + dapr_manager.create_agent(sample_session.session_id, sample_agent) + + # Create multiple messages + for i in range(5): + message = SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message {i}")], + }, + message_id=i, + ) + dapr_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # List all messages + result = dapr_manager.list_messages(sample_session.session_id, sample_agent.agent_id) + + assert len(result) == 5 + + +def test_list_messages_with_limit( + dapr_manager: DaprSessionManager, sample_session: Session, sample_agent: SessionAgent +): + """Test listing messages with limit.""" + dapr_manager.create_session(sample_session) + dapr_manager.create_agent(sample_session.session_id, sample_agent) + + # Create multiple messages + for i in range(10): + message = SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message {i}")], + }, + message_id=i, + ) + dapr_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # List with limit + result = dapr_manager.list_messages(sample_session.session_id, sample_agent.agent_id, limit=3) + + assert len(result) == 3 + + +def test_list_messages_with_offset( + dapr_manager: DaprSessionManager, sample_session: Session, sample_agent: SessionAgent +): + """Test listing messages with offset.""" + dapr_manager.create_session(sample_session) + dapr_manager.create_agent(sample_session.session_id, sample_agent) + + # Create multiple messages + for i in range(10): + message = SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message {i}")], + }, + message_id=i, + ) + dapr_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # List with offset + result = dapr_manager.list_messages(sample_session.session_id, sample_agent.agent_id, offset=5) + + assert len(result) == 5 + assert result[0].message_id == 5 + + +def test_list_messages_empty(dapr_manager: DaprSessionManager, sample_session: Session, sample_agent: SessionAgent): + """Test listing messages when none exist.""" + dapr_manager.create_session(sample_session) + dapr_manager.create_agent(sample_session.session_id, sample_agent) + + result = dapr_manager.list_messages(sample_session.session_id, sample_agent.agent_id) + + assert len(result) == 0 + + +def test_update_message( + dapr_manager: DaprSessionManager, + sample_session: Session, + sample_agent: SessionAgent, + sample_message: SessionMessage, +): + """Test updating a message.""" + dapr_manager.create_session(sample_session) + dapr_manager.create_agent(sample_session.session_id, sample_agent) + dapr_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Update message + sample_message.message["content"] = [ContentBlock(text="Updated content")] + dapr_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Verify update + result = dapr_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) + assert result is not None + assert result.message["content"][0]["text"] == "Updated content" + + +def test_update_nonexistent_message( + dapr_manager: DaprSessionManager, + sample_session: Session, + sample_agent: SessionAgent, + sample_message: SessionMessage, +): + """Test updating a message that doesn't exist.""" + dapr_manager.create_session(sample_session) + dapr_manager.create_agent(sample_session.session_id, sample_agent) + + # Try to update non-existent message + with pytest.raises(SessionException, match="does not exist"): + dapr_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + +def test_corrupted_json(dapr_manager: DaprSessionManager, fake_dapr_client: FakeDaprClient): + """Test handling of corrupted JSON data.""" + # Store invalid JSON (use string key to match how _get_session_key works) + fake_dapr_client._state["test:session"] = b"invalid json content" + + # Should raise SessionException + with pytest.raises(SessionException, match="Invalid JSON"): + dapr_manager.read_session("test") + + +@pytest.mark.parametrize( + "session_id", + [ + "a/../b", + "a/b", + ], +) +def test_invalid_session_id(session_id: str, fake_dapr_client: FakeDaprClient): + """Test that session IDs with path separators are rejected.""" + manager = DaprSessionManager(session_id="test", state_store_name="statestore", dapr_client=fake_dapr_client) + + with pytest.raises(ValueError, match=f"session_id={session_id} | id cannot contain path separators"): + manager._get_session_key(session_id) + + +@pytest.mark.parametrize( + "agent_id", + [ + "a/../b", + "a/b", + ], +) +def test_invalid_agent_id(agent_id: str, dapr_manager: DaprSessionManager): + """Test that agent IDs with path separators are rejected.""" + with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"): + dapr_manager._get_agent_key("session1", agent_id) + + +@pytest.mark.parametrize( + "message_id", + [ + "../../../secret", + "../../attack", + "../escape", + "path/traversal", + "not_an_int", + None, + [], + ], +) +def test_invalid_message_id(message_id: Any, dapr_manager: DaprSessionManager, sample_session: Session): + """Test that non-integer message IDs are rejected.""" + with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"): + dapr_manager.read_message(sample_session.session_id, "agent1", message_id) + + +def test_client_ownership(fake_dapr_client: FakeDaprClient): + """Test client ownership tracking.""" + # Manager created with external client + manager = DaprSessionManager(session_id="test", state_store_name="statestore", dapr_client=fake_dapr_client) + assert manager._owns_client is False + + # Manager created with from_address + # We can't test actual client creation without mocking, but we can verify the flag + assert hasattr(manager, "_owns_client") + + +def test_close_owned_client(fake_dapr_client: FakeDaprClient): + """Test closing owned client.""" + # Create manager with existing session + fake_dapr_client._state["test-close:session"] = b'{"session_id": "test-close", "session_type": "AGENT"}' + manager = DaprSessionManager(session_id="test-close", state_store_name="statestore", dapr_client=fake_dapr_client) + manager._owns_client = True + + manager.close() + assert fake_dapr_client._closed is True + + +def test_close_unowned_client(fake_dapr_client: FakeDaprClient): + """Test not closing unowned client.""" + # Create manager with existing session + fake_dapr_client._state["test-close2:session"] = b'{"session_id": "test-close2", "session_type": "AGENT"}' + manager = DaprSessionManager(session_id="test-close2", state_store_name="statestore", dapr_client=fake_dapr_client) + manager._owns_client = False + + manager.close() + assert fake_dapr_client._closed is False + + +def test_delete_session_parity(fake_dapr_client: FakeDaprClient): + """Test delete_session removes session, agents, messages and manifest.""" + manager = DaprSessionManager(session_id="sess", state_store_name="statestore", dapr_client=fake_dapr_client) + # Create agent and messages (session already created by __init__ if needed) + agent = SessionAgent(agent_id="a1", state={}, conversation_manager_state={}) + manager.create_agent("sess", agent) + manager.create_message( + "sess", "a1", SessionMessage.from_message({"role": "user", "content": [ContentBlock(text="hi")]}, 0) + ) + + # Sanity + assert manager.read_session("sess") is not None + assert manager.read_agent("sess", "a1") is not None + assert len(manager.list_messages("sess", "a1")) == 1 + + # Delete + manager.delete_session("sess") + + # All gone + assert manager.read_session("sess") is None + assert manager.read_agent("sess", "a1") is None + assert manager.list_messages("sess", "a1") == [] diff --git a/tests_integ/test_dapr_session.py b/tests_integ/test_dapr_session.py new file mode 100644 index 000000000..f4d4bb237 --- /dev/null +++ b/tests_integ/test_dapr_session.py @@ -0,0 +1,428 @@ +"""Integration tests for DaprSessionManager with real Dapr and Redis.""" + +import os +import shutil +import tempfile +import time +import urllib.request +from typing import Any +from urllib.error import URLError +from uuid import uuid4 + +# pytestmark = [pytest.mark.asyncio] +import dotenv +import pytest +from testcontainers.core.container import DockerContainer # type: ignore[import-untyped] +from testcontainers.core.network import Network # type: ignore[import-untyped] +from testcontainers.redis import RedisContainer # type: ignore[import-untyped] + +from strands import Agent +from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager +from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager +from strands.session.dapr_session_manager import DAPR_CONSISTENCY_STRONG, DaprSessionManager +from tests.fixtures.mocked_model_provider import MockedModelProvider + +dotenv.load_dotenv() + + +@pytest.fixture(scope="module") +def docker_network(): + """Create a Docker network for container-to-container communication.""" + network = Network() + network.create() + try: + yield network + finally: + try: + network.remove() + except Exception: + pass + + +@pytest.fixture(scope="module") +def redis_container(docker_network: Any) -> Any: + """Redis container on shared network with network alias.""" + container = RedisContainer("redis:7-alpine") + container = container.with_network(docker_network) + container = container.with_network_aliases("redis") + container.start() + yield container + try: + container.stop() + except Exception: + pass + + +@pytest.fixture(scope="module") +def dapr_container(redis_container: Any, docker_network: Any) -> Any: + """Dapr sidecar container with Redis state store.""" + # Create Dapr component config + temp_dir = tempfile.mkdtemp() + component_path = os.path.join(temp_dir, "statestore.yaml") + + with open(component_path, "w") as f: + f.write( + """apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: statestore +spec: + type: state.redis + version: v1 + metadata: + - name: redisHost + value: redis:6379 + - name: redisPassword + value: "" + - name: enableTLS + value: "false" +""" + ) + + # Start Dapr container + container = DockerContainer("daprio/daprd:latest") + container = container.with_network(docker_network) + container = container.with_volume_mapping(temp_dir, "/components", mode="ro") + container = container.with_command( + "./daprd " + "--app-id test-app " + "--dapr-grpc-port 50001 " + "--dapr-http-port 3500 " + "--components-path /components " + "--log-level debug" + ) + container = container.with_exposed_ports(50001, 3500) + container.start() + + # Wait for Dapr to be ready + http_host = container.get_container_host_ip() + http_port = container.get_exposed_port(3500) + if not _wait_for_dapr_health(http_host, http_port, timeout=60): + container.stop() + pytest.fail("Dapr container failed to become healthy") + + # Set environment variables for Dapr SDK + os.environ["DAPR_HTTP_PORT"] = str(http_port) + os.environ["DAPR_RUNTIME_HOST"] = http_host + + yield container + + container.stop() + os.environ.pop("DAPR_HTTP_PORT", None) + os.environ.pop("DAPR_RUNTIME_HOST", None) + shutil.rmtree(temp_dir, ignore_errors=True) + + +def _wait_for_dapr_health(host: str, port: int, timeout: int = 60) -> bool: + """Poll Dapr HTTP health endpoint until ready.""" + health_url = f"http://{host}:{port}/v1.0/healthz/outbound" + start_time = time.time() + print(f"Waiting for Dapr health at {health_url}") + + while time.time() - start_time < timeout: + try: + with urllib.request.urlopen(health_url, timeout=5) as response: + if 200 <= response.status < 300: + return True + print(f"Dapr health check failed with status {response.status}") + except URLError: + print(f"Dapr health check failed with URLError") + pass + except Exception as e: + print(f"Dapr health check failed with exception {e}") + print(f"Dapr health check failed with timeout") + time.sleep(1) + return False + + +def test_agent_with_dapr_session(dapr_container: Any, monkeypatch: Any): + """Test agent with DaprSessionManager using real Dapr and Redis.""" + # Bypass SDK's internal health check (already done in fixture) + from dapr.clients.health import DaprHealth + + monkeypatch.setattr(DaprHealth, "wait_until_ready", lambda: None) + + dapr_host = dapr_container.get_container_host_ip() + dapr_port = dapr_container.get_exposed_port(50001) + test_session_id = str(uuid4()) + + session_manager = DaprSessionManager.from_address( + session_id=test_session_id, + state_store_name="statestore", + dapr_address=f"{dapr_host}:{dapr_port}", + consistency=DAPR_CONSISTENCY_STRONG, + ) + + session_manager_2 = None + try: + # Use mocked model to avoid real provider calls + model1 = MockedModelProvider([ + {"role": "assistant", "content": [{"text": "ok"}]}, + ]) + agent = Agent(session_manager=session_manager, model=model1) + agent("Hello!") + assert len(session_manager.list_messages(test_session_id, agent.agent_id)) == 2 + + # After agent is persisted and run, restore the agent and run it again + session_manager_2 = DaprSessionManager.from_address( + session_id=test_session_id, + state_store_name="statestore", + dapr_address=f"{dapr_host}:{dapr_port}", + consistency=DAPR_CONSISTENCY_STRONG, + ) + model2 = MockedModelProvider([ + {"role": "assistant", "content": [{"text": "ok"}]}, + ]) + agent_2 = Agent(session_manager=session_manager_2, model=model2) + assert len(agent_2.messages) == 2 + agent_2("Hello again!") + assert len(agent_2.messages) == 4 + assert len(session_manager_2.list_messages(test_session_id, agent_2.agent_id)) == 4 + finally: + # Cleanup + session_manager.close() + if session_manager_2 is not None: + session_manager_2.close() + + +def test_agent_with_dapr_session_and_conversation_manager(dapr_container: Any, monkeypatch: Any): + """Test agent with DaprSessionManager and SlidingWindowConversationManager.""" + from dapr.clients.health import DaprHealth + + monkeypatch.setattr(DaprHealth, "wait_until_ready", lambda: None) + + dapr_host = dapr_container.get_container_host_ip() + dapr_port = dapr_container.get_exposed_port(50001) + test_session_id = str(uuid4()) + + session_manager = DaprSessionManager.from_address( + session_id=test_session_id, + state_store_name="statestore", + dapr_address=f"{dapr_host}:{dapr_port}", + consistency=DAPR_CONSISTENCY_STRONG, + ) + + session_manager_2 = None + try: + model1 = MockedModelProvider([ + {"role": "assistant", "content": [{"text": "ok"}]}, + ]) + agent = Agent( + session_manager=session_manager, + model=model1, + conversation_manager=SlidingWindowConversationManager(window_size=1) + ) + agent("Hello!") + assert len(session_manager.list_messages(test_session_id, agent.agent_id)) == 2 + # Conversation Manager reduced messages + assert len(agent.messages) == 1 + + # After agent is persisted and run, restore the agent and run it again + session_manager_2 = DaprSessionManager.from_address( + session_id=test_session_id, + state_store_name="statestore", + dapr_address=f"{dapr_host}:{dapr_port}", + consistency=DAPR_CONSISTENCY_STRONG, + ) + model2 = MockedModelProvider([ + {"role": "assistant", "content": [{"text": "ok"}]}, + ]) + agent_2 = Agent( + session_manager=session_manager_2, + model=model2, + conversation_manager=SlidingWindowConversationManager(window_size=1) + ) + assert len(agent_2.messages) == 1 + assert agent_2.conversation_manager.removed_message_count == 1 + agent_2("Hello again!") + assert len(agent_2.messages) == 1 + assert len(session_manager_2.list_messages(test_session_id, agent_2.agent_id)) == 4 + finally: + # Cleanup + session_manager.close() + if session_manager_2 is not None: + session_manager_2.close() + + +def test_agent_with_dapr_session_with_image(dapr_container: Any, yellow_img: bytes, monkeypatch: Any): + """Test agent with DaprSessionManager handling image content.""" + from dapr.clients.health import DaprHealth + + monkeypatch.setattr(DaprHealth, "wait_until_ready", lambda: None) + + dapr_host = dapr_container.get_container_host_ip() + dapr_port = dapr_container.get_exposed_port(50001) + test_session_id = str(uuid4()) + + session_manager = DaprSessionManager.from_address( + session_id=test_session_id, + state_store_name="statestore", + dapr_address=f"{dapr_host}:{dapr_port}", + consistency=DAPR_CONSISTENCY_STRONG, + ) + + session_manager_2 = None + try: + model1 = MockedModelProvider([ + {"role": "assistant", "content": [{"text": "ok"}]}, + ]) + agent = Agent(session_manager=session_manager, model=model1) + agent([{"image": {"format": "png", "source": {"bytes": yellow_img}}}]) + assert len(session_manager.list_messages(test_session_id, agent.agent_id)) == 2 + + # After agent is persisted and run, restore the agent and run it again + session_manager_2 = DaprSessionManager.from_address( + session_id=test_session_id, + state_store_name="statestore", + dapr_address=f"{dapr_host}:{dapr_port}", + consistency=DAPR_CONSISTENCY_STRONG, + ) + model2 = MockedModelProvider([ + {"role": "assistant", "content": [{"text": "ok"}]}, + ]) + agent_2 = Agent(session_manager=session_manager_2, model=model2) + assert len(agent_2.messages) == 2 + agent_2("Hello!") + assert len(agent_2.messages) == 4 + assert len(session_manager_2.list_messages(test_session_id, agent_2.agent_id)) == 4 + finally: + # Cleanup + session_manager.close() + if session_manager_2 is not None: + session_manager_2.close() + + +def test_agent_with_dapr_session_forced_summarization(dapr_container: Any, monkeypatch: Any): + """Force summarization via SummarizingConversationManager and verify persistence/restoration with Dapr.""" + from dapr.clients.health import DaprHealth + + monkeypatch.setattr(DaprHealth, "wait_until_ready", lambda: None) + + dapr_host = dapr_container.get_container_host_ip() + dapr_port = dapr_container.get_exposed_port(50001) + test_session_id = str(uuid4()) + + session_manager = DaprSessionManager.from_address( + session_id=test_session_id, + state_store_name="statestore", + dapr_address=f"{dapr_host}:{dapr_port}", + consistency=DAPR_CONSISTENCY_STRONG, + ) + + session_manager_2 = None + try: + # Use a separate summarizer Agent without a session manager to avoid persisting summarization messages + summarizer_agent = Agent(model=MockedModelProvider([ + {"role": "assistant", "content": [{"text": "Summary"}]} + ])) + convo_manager = SummarizingConversationManager( + summarization_agent=summarizer_agent, summary_ratio=0.5, preserve_recent_messages=1 + ) + model1 = MockedModelProvider([ + {"role": "assistant", "content": [{"text": "ok"}]}, + {"role": "assistant", "content": [{"text": "ok"}]}, + {"role": "assistant", "content": [{"text": "ok"}]}, + ]) + agent = Agent(session_manager=session_manager, conversation_manager=convo_manager, model=model1) + + # Add enough messages + agent("m1") + agent("m2") + agent("m3") + + # Explicitly trigger summarization and persist the updated state + agent.conversation_manager.reduce_context(agent) + session_manager.sync_agent(agent) + + # Validate summary inserted and state updated + assert agent.conversation_manager.removed_message_count > 0 + assert agent.messages[0]["role"] == "user" + assert "Summary" in str(agent.messages[0]["content"]) # summary message + + # Restore with a new manager and agent + session_manager_2 = DaprSessionManager.from_address( + session_id=test_session_id, + state_store_name="statestore", + dapr_address=f"{dapr_host}:{dapr_port}", + consistency=DAPR_CONSISTENCY_STRONG, + ) + agent_2 = Agent(session_manager=session_manager_2, conversation_manager=convo_manager) + + # After restore, messages should reflect trimmed history (summary + remaining) + assert len(agent_2.messages) <= 4 + assert agent_2.conversation_manager.removed_message_count == agent.conversation_manager.removed_message_count + finally: + # Cleanup + session_manager.delete_session(test_session_id) + if session_manager_2 is not None: + session_manager_2.delete_session(test_session_id) + session_manager.close() + if session_manager_2 is not None: + session_manager_2.close() + + +def test_agent_with_dapr_session_and_summarizing_conversation_manager(dapr_container: Any, monkeypatch: Any): + """Test agent with DaprSessionManager and SummarizingConversationManager.""" + from dapr.clients.health import DaprHealth + + monkeypatch.setattr(DaprHealth, "wait_until_ready", lambda: None) + + dapr_host = dapr_container.get_container_host_ip() + dapr_port = dapr_container.get_exposed_port(50001) + test_session_id = str(uuid4()) + + session_manager = DaprSessionManager.from_address( + session_id=test_session_id, + state_store_name="statestore", + dapr_address=f"{dapr_host}:{dapr_port}", + consistency=DAPR_CONSISTENCY_STRONG, + ) + + session_manager_2 = None + try: + # Create agent with summarizing conversation manager + model1 = MockedModelProvider([ + {"role": "assistant", "content": [{"text": "ok"}]}, + ]) + agent = Agent( + session_manager=session_manager, + model=model1, + conversation_manager=SummarizingConversationManager(summary_ratio=0.5, preserve_recent_messages=2), + ) + agent("Hello!") + messages_count = len(session_manager.list_messages(test_session_id, agent.agent_id)) + assert messages_count == 2 + + # Restore the agent with the same conversation manager + session_manager_2 = DaprSessionManager.from_address( + session_id=test_session_id, + state_store_name="statestore", + dapr_address=f"{dapr_host}:{dapr_port}", + consistency=DAPR_CONSISTENCY_STRONG, + ) + model2 = MockedModelProvider([ + {"role": "assistant", "content": [{"text": "ok"}]}, + {"role": "assistant", "content": [{"text": "ok"}]}, + ]) + agent_2 = Agent( + session_manager=session_manager_2, + model=model2, + conversation_manager=SummarizingConversationManager(summary_ratio=0.5, preserve_recent_messages=2), + ) + + # Verify state was restored correctly + assert len(agent_2.messages) == 2 + assert isinstance(agent_2.conversation_manager, SummarizingConversationManager) + + # Add more messages to trigger summarization if needed + agent_2("Tell me a story") + agent_2("Continue the story") + + # Verify messages were persisted + final_messages_count = len(session_manager_2.list_messages(test_session_id, agent_2.agent_id)) + assert final_messages_count >= 4 + finally: + # Cleanup + session_manager.close() + if session_manager_2 is not None: + session_manager_2.close() From 3fc0e92aa5794e2e0bf162dfb3f775fe9b5420d0 Mon Sep 17 00:00:00 2001 From: Filinto Duran <1373693+filintod@users.noreply.github.com> Date: Mon, 20 Oct 2025 23:23:39 -0500 Subject: [PATCH 2/4] feat: add dapr session Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com> --- src/strands/session/__init__.py | 6 +- .../session/test_dapr_session_manager.py | 10 ++- tests_integ/test_dapr_session.py | 90 +++++++++++-------- 3 files changed, 64 insertions(+), 42 deletions(-) diff --git a/src/strands/session/__init__.py b/src/strands/session/__init__.py index a8b58a717..88e4733de 100644 --- a/src/strands/session/__init__.py +++ b/src/strands/session/__init__.py @@ -32,7 +32,7 @@ def __getattr__(name: str) -> Any: return DaprSessionManager except ModuleNotFoundError as e: raise ImportError( - "DaprSessionManager requires the 'dapr' extra. " "Install it with: pip install strands-agents[dapr]" + "DaprSessionManager requires the 'dapr' extra. Install it with: pip install strands-agents[dapr]" ) from e if name == "DAPR_CONSISTENCY_EVENTUAL": @@ -42,7 +42,7 @@ def __getattr__(name: str) -> Any: return DAPR_CONSISTENCY_EVENTUAL except ModuleNotFoundError as e: raise ImportError( - "DAPR_CONSISTENCY_EVENTUAL requires the 'dapr' extra. " "Install it with: pip install strands-agents[dapr]" + "DAPR_CONSISTENCY_EVENTUAL requires the 'dapr' extra. Install it with: pip install strands-agents[dapr]" ) from e if name == "DAPR_CONSISTENCY_STRONG": @@ -52,7 +52,7 @@ def __getattr__(name: str) -> Any: return DAPR_CONSISTENCY_STRONG except ModuleNotFoundError as e: raise ImportError( - "DAPR_CONSISTENCY_STRONG requires the 'dapr' extra. " "Install it with: pip install strands-agents[dapr]" + "DAPR_CONSISTENCY_STRONG requires the 'dapr' extra. Install it with: pip install strands-agents[dapr]" ) from e raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/tests/strands/session/test_dapr_session_manager.py b/tests/strands/session/test_dapr_session_manager.py index 2a216ed20..f68c651d1 100644 --- a/tests/strands/session/test_dapr_session_manager.py +++ b/tests/strands/session/test_dapr_session_manager.py @@ -4,6 +4,7 @@ from unittest.mock import Mock import pytest + from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.session.dapr_session_manager import ( DAPR_CONSISTENCY_EVENTUAL, @@ -74,7 +75,10 @@ def fake_dapr_client() -> FakeDaprClient: def dapr_manager(fake_dapr_client: FakeDaprClient) -> DaprSessionManager: """Create DaprSessionManager for testing.""" return DaprSessionManager( - session_id="test", state_store_name="statestore", dapr_client=fake_dapr_client, consistency=DAPR_CONSISTENCY_EVENTUAL + session_id="test", + state_store_name="statestore", + dapr_client=fake_dapr_client, + consistency=DAPR_CONSISTENCY_EVENTUAL, ) @@ -110,7 +114,9 @@ def test_consistency_constants(): assert DAPR_CONSISTENCY_STRONG == "strong" -def test_messages_shape_non_list_handling(dapr_manager: DaprSessionManager, sample_session: Session, sample_agent: SessionAgent): +def test_messages_shape_non_list_handling( + dapr_manager: DaprSessionManager, sample_session: Session, sample_agent: SessionAgent +): """Seed a non-list messages payload and verify graceful handling and overwrite by create_message.""" dapr_manager.create_session(sample_session) dapr_manager.create_agent(sample_session.session_id, sample_agent) diff --git a/tests_integ/test_dapr_session.py b/tests_integ/test_dapr_session.py index f4d4bb237..686221b30 100644 --- a/tests_integ/test_dapr_session.py +++ b/tests_integ/test_dapr_session.py @@ -126,11 +126,11 @@ def _wait_for_dapr_health(host: str, port: int, timeout: int = 60) -> bool: return True print(f"Dapr health check failed with status {response.status}") except URLError: - print(f"Dapr health check failed with URLError") + print("Dapr health check failed with URLError") pass except Exception as e: print(f"Dapr health check failed with exception {e}") - print(f"Dapr health check failed with timeout") + print("Dapr health check failed with timeout") time.sleep(1) return False @@ -156,9 +156,11 @@ def test_agent_with_dapr_session(dapr_container: Any, monkeypatch: Any): session_manager_2 = None try: # Use mocked model to avoid real provider calls - model1 = MockedModelProvider([ - {"role": "assistant", "content": [{"text": "ok"}]}, - ]) + model1 = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "ok"}]}, + ] + ) agent = Agent(session_manager=session_manager, model=model1) agent("Hello!") assert len(session_manager.list_messages(test_session_id, agent.agent_id)) == 2 @@ -170,9 +172,11 @@ def test_agent_with_dapr_session(dapr_container: Any, monkeypatch: Any): dapr_address=f"{dapr_host}:{dapr_port}", consistency=DAPR_CONSISTENCY_STRONG, ) - model2 = MockedModelProvider([ - {"role": "assistant", "content": [{"text": "ok"}]}, - ]) + model2 = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "ok"}]}, + ] + ) agent_2 = Agent(session_manager=session_manager_2, model=model2) assert len(agent_2.messages) == 2 agent_2("Hello again!") @@ -204,13 +208,15 @@ def test_agent_with_dapr_session_and_conversation_manager(dapr_container: Any, m session_manager_2 = None try: - model1 = MockedModelProvider([ - {"role": "assistant", "content": [{"text": "ok"}]}, - ]) + model1 = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "ok"}]}, + ] + ) agent = Agent( session_manager=session_manager, model=model1, - conversation_manager=SlidingWindowConversationManager(window_size=1) + conversation_manager=SlidingWindowConversationManager(window_size=1), ) agent("Hello!") assert len(session_manager.list_messages(test_session_id, agent.agent_id)) == 2 @@ -224,13 +230,15 @@ def test_agent_with_dapr_session_and_conversation_manager(dapr_container: Any, m dapr_address=f"{dapr_host}:{dapr_port}", consistency=DAPR_CONSISTENCY_STRONG, ) - model2 = MockedModelProvider([ - {"role": "assistant", "content": [{"text": "ok"}]}, - ]) + model2 = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "ok"}]}, + ] + ) agent_2 = Agent( session_manager=session_manager_2, model=model2, - conversation_manager=SlidingWindowConversationManager(window_size=1) + conversation_manager=SlidingWindowConversationManager(window_size=1), ) assert len(agent_2.messages) == 1 assert agent_2.conversation_manager.removed_message_count == 1 @@ -263,9 +271,11 @@ def test_agent_with_dapr_session_with_image(dapr_container: Any, yellow_img: byt session_manager_2 = None try: - model1 = MockedModelProvider([ - {"role": "assistant", "content": [{"text": "ok"}]}, - ]) + model1 = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "ok"}]}, + ] + ) agent = Agent(session_manager=session_manager, model=model1) agent([{"image": {"format": "png", "source": {"bytes": yellow_img}}}]) assert len(session_manager.list_messages(test_session_id, agent.agent_id)) == 2 @@ -277,9 +287,11 @@ def test_agent_with_dapr_session_with_image(dapr_container: Any, yellow_img: byt dapr_address=f"{dapr_host}:{dapr_port}", consistency=DAPR_CONSISTENCY_STRONG, ) - model2 = MockedModelProvider([ - {"role": "assistant", "content": [{"text": "ok"}]}, - ]) + model2 = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "ok"}]}, + ] + ) agent_2 = Agent(session_manager=session_manager_2, model=model2) assert len(agent_2.messages) == 2 agent_2("Hello!") @@ -312,17 +324,17 @@ def test_agent_with_dapr_session_forced_summarization(dapr_container: Any, monke session_manager_2 = None try: # Use a separate summarizer Agent without a session manager to avoid persisting summarization messages - summarizer_agent = Agent(model=MockedModelProvider([ - {"role": "assistant", "content": [{"text": "Summary"}]} - ])) + summarizer_agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "Summary"}]}])) convo_manager = SummarizingConversationManager( summarization_agent=summarizer_agent, summary_ratio=0.5, preserve_recent_messages=1 ) - model1 = MockedModelProvider([ - {"role": "assistant", "content": [{"text": "ok"}]}, - {"role": "assistant", "content": [{"text": "ok"}]}, - {"role": "assistant", "content": [{"text": "ok"}]}, - ]) + model1 = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "ok"}]}, + {"role": "assistant", "content": [{"text": "ok"}]}, + {"role": "assistant", "content": [{"text": "ok"}]}, + ] + ) agent = Agent(session_manager=session_manager, conversation_manager=convo_manager, model=model1) # Add enough messages @@ -381,9 +393,11 @@ def test_agent_with_dapr_session_and_summarizing_conversation_manager(dapr_conta session_manager_2 = None try: # Create agent with summarizing conversation manager - model1 = MockedModelProvider([ - {"role": "assistant", "content": [{"text": "ok"}]}, - ]) + model1 = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "ok"}]}, + ] + ) agent = Agent( session_manager=session_manager, model=model1, @@ -400,10 +414,12 @@ def test_agent_with_dapr_session_and_summarizing_conversation_manager(dapr_conta dapr_address=f"{dapr_host}:{dapr_port}", consistency=DAPR_CONSISTENCY_STRONG, ) - model2 = MockedModelProvider([ - {"role": "assistant", "content": [{"text": "ok"}]}, - {"role": "assistant", "content": [{"text": "ok"}]}, - ]) + model2 = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "ok"}]}, + {"role": "assistant", "content": [{"text": "ok"}]}, + ] + ) agent_2 = Agent( session_manager=session_manager_2, model=model2, From d2a65df76b935242baf539a4da3d849e2e650191 Mon Sep 17 00:00:00 2001 From: Filinto Duran <1373693+filintod@users.noreply.github.com> Date: Wed, 22 Oct 2025 19:47:42 -0500 Subject: [PATCH 3/4] chore: update dapr dependency to version 1.16.0 and remove unused state_metadata parameter in DaprSessionManager Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com> --- pyproject.toml | 2 +- src/strands/session/dapr_session_manager.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 060e2f5da..95ab4bcfc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ sagemaker = [ "openai>=1.68.0,<2.0.0", # SageMaker uses OpenAI-compatible interface ] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] -dapr = ["dapr>=1.14.0", "grpcio>=1.60.0"] +dapr = ["dapr>=1.16.0", "grpcio>=1.60.0"] docs = [ "sphinx>=5.0.0,<9.0.0", "sphinx-rtd-theme>=1.0.0,<2.0.0", diff --git a/src/strands/session/dapr_session_manager.py b/src/strands/session/dapr_session_manager.py index 18eef1e79..87d16fa9e 100644 --- a/src/strands/session/dapr_session_manager.py +++ b/src/strands/session/dapr_session_manager.py @@ -233,7 +233,6 @@ def _delete_state(self, key: str) -> None: self._dapr_client.delete_state( store_name=self._state_store_name, key=key, - state_metadata=self._get_read_metadata(), options=self._get_state_options(), ) except Exception as e: From 69e88e243b2b3fa717734bbd415c7701cafbabee Mon Sep 17 00:00:00 2001 From: Filinto Duran <1373693+filintod@users.noreply.github.com> Date: Wed, 22 Oct 2025 21:52:41 -0500 Subject: [PATCH 4/4] chore: update dpar version in pyproject Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6058b34b9..83bd47958 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,7 +137,7 @@ dependencies = [ "pytest-asyncio>=1.0.0,<1.3.0", "pytest-xdist>=3.0.0,<4.0.0", "moto>=5.1.0,<6.0.0", - "dapr>=1.14.0", + "dapr>=1.16.0", "grpcio>=1.60.0", "testcontainers[redis]>=4.0.0", ]