From 8d99df5b2fcd18296f74eb9d525a5021f23a08bd Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Wed, 5 Nov 2025 11:04:19 -0500
Subject: [PATCH 1/4] litellm prompt caching

---
 src/strands/models/litellm.py | 119 +++++++++++++++++++++++++++++++++-
 src/strands/models/openai.py  |  45 ++++++++++---
 2 files changed, 152 insertions(+), 12 deletions(-)

diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index 7a8c0ae03..1e82ab5c5 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -14,7 +14,7 @@
 from typing_extensions import Unpack, override
 
 from ..tools import convert_pydantic_to_tool_spec
-from ..types.content import ContentBlock, Messages
+from ..types.content import ContentBlock, Messages, SystemContentBlock
 from ..types.exceptions import ContextWindowOverflowException
 from ..types.streaming import StreamEvent
 from ..types.tools import ToolChoice, ToolSpec
@@ -131,6 +131,120 @@ def _stream_switch_content(self, data_type: str, prev_data_type: str | None) ->
 
         return chunks, data_type
 
+    @override
+    @classmethod
+    def format_request_messages(
+        cls,
+        messages: Messages,
+        system_prompt: Optional[str] = None,
+        *,
+        system_prompt_content: Optional[list[SystemContentBlock]] = None,
+        **kwargs: Any,
+    ) -> list[dict[str, Any]]:
+        """Format a LiteLLM compatible messages array with cache point support.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            system_prompt: System prompt to provide context to the model (for legacy compatibility).
+            system_prompt_content: System prompt content blocks to provide context to the model.
+            **kwargs: Additional keyword arguments for future extensibility.
+
+        Returns:
+            A LiteLLM compatible messages array.
+        """
+        formatted_messages: list[dict[str, Any]] = []
+
+        # Handle backward compatibility: if system_prompt is provided but system_prompt_content is None
+        if system_prompt and system_prompt_content is None:
+            system_prompt_content = [{"context": system_prompt}]
+
+        # For LiteLLM with Bedrock, we can support cache points
+        system_content = []
+        for block in system_prompt_content:
+            if "text" in block:
+                system_content.append({"type": "text", "text": block["text"]})
+            elif "cachePoint" in block and block["cachePoint"].get("type") == "default":
+                # Apply cache control to the immediately preceding content block
+                # for LiteLLM/Anthropic compatibility
+                if system_content:
+                    system_content[-1]["cache_control"] = {"type": "ephemeral"}
+
+        # Create single system message with content array
+        if system_content:
+            formatted_messages.append({"role": "system", "content": system_content})
+
+        # Process regular messages
+        for message in messages:
+            contents = message["content"]
+
+            formatted_contents = [
+                cls.format_request_message_content(content)
+                for content in contents
+                if not any(block_type in content for block_type in ["toolResult", "toolUse"])
+            ]
+            formatted_tool_calls = [
+                cls.format_request_message_tool_call(content["toolUse"]) for content in contents if "toolUse" in content
+            ]
+            formatted_tool_messages = [
+                cls.format_request_tool_message(content["toolResult"])
+                for content in contents
+                if "toolResult" in content
+            ]
+
+            formatted_message = {
+                "role": message["role"],
+                "content": formatted_contents,
+                **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}),
+            }
+            formatted_messages.append(formatted_message)
+            formatted_messages.extend(formatted_tool_messages)
+
+        return [message for message in formatted_messages if message["content"] or "tool_calls" in message]
+
+    @override
+    def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
+        """Format a LiteLLM response event into a standardized message chunk.
+
+        This method overrides OpenAI's format_chunk to handle the metadata case
+        with prompt caching support. All other chunk types use the parent implementation.
+
+        Args:
+            event: A response event from the LiteLLM model.
+
+        Returns:
+            The formatted chunk.
+
+        Raises:
+            RuntimeError: If chunk_type is not recognized.
+        """
+        # Handle metadata case with prompt caching support
+        if event["chunk_type"] == "metadata":
+            usage_data = {
+                "inputTokens": event["data"].prompt_tokens,
+                "outputTokens": event["data"].completion_tokens,
+                "totalTokens": event["data"].total_tokens,
+            }
+
+            # Only LiteLLM over Anthropic supports cache write tokens
+            # Waiting until a more general approach is available to set cacheWriteInputTokens
+            tokens_details = getattr(event["data"], "prompt_tokens_details", None)
+            if tokens_details and getattr(tokens_details, "cached_tokens", None):
+                usage_data["cacheReadInputTokens"] = event["data"].prompt_tokens_details.cached_tokens
+
+            return {
+                "metadata": {
+                    "usage": usage_data,
+                    "metrics": {
+                        "latencyMs": 0,  # TODO
+                    },
+                },
+            }
+
+        # For all other cases, use the parent implementation
+        return super().format_chunk(event)
+
     @override
     async def stream(
         self,
@@ -139,6 +253,7 @@ async def stream(
         system_prompt: Optional[str] = None,
         *,
         tool_choice: ToolChoice | None = None,
+        system_prompt_content: Optional[list[SystemContentBlock]] = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the LiteLLM model.
@@ -154,7 +269,7 @@ async def stream(
             Formatted message chunks from the model.
         """
         logger.debug("formatting request")
-        request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
+        request = self.format_request(messages, tool_specs, system_prompt, tool_choice, system_prompt_content=system_prompt_content)
         logger.debug("request=<%s>", request)
 
         logger.debug("invoking model")
diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py
index 1efe641e6..24d42e84b 100644
--- a/src/strands/models/openai.py
+++ b/src/strands/models/openai.py
@@ -89,7 +89,7 @@ def get_config(self) -> OpenAIConfig:
         return cast(OpenAIModel.OpenAIConfig, self.config)
 
     @classmethod
-    def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
+    def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> dict[str, Any]:
         """Format an OpenAI compatible content block.
 
         Args:
@@ -131,7 +131,7 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]
         raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
 
     @classmethod
-    def format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]:
+    def format_request_message_tool_call(cls, tool_use: ToolUse, **kwargs: Any) -> dict[str, Any]:
         """Format an OpenAI compatible tool call.
 
         Args:
@@ -150,7 +150,7 @@ def format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]:
         }
 
     @classmethod
-    def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]:
+    def format_request_tool_message(cls, tool_result: ToolResult, **kwargs: Any) -> dict[str, Any]:
         """Format an OpenAI compatible tool message.
 
         Args:
@@ -198,7 +198,14 @@ def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str
         return {"tool_choice": "auto"}
 
     @classmethod
-    def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
+    def format_request_messages(
+        cls,
+        messages: Messages,
+        system_prompt: Optional[str] = None,
+        *,
+        system_prompt_content: Optional[list[SystemContentBlock]] = None,
+        **kwargs
+    ) -> list[dict[str, Any]]:
         """Format an OpenAI compatible messages array.
 
         Args:
@@ -208,8 +215,22 @@ def format_request_messages(cls, messages: Messages, system_prompt: Optional[str
         Returns:
             An OpenAI compatible messages array.
         """
-        formatted_messages: list[dict[str, Any]]
-        formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
+        # Handle backward compatibility: if system_prompt is provided but system_prompt_content is None
+        if system_prompt and system_prompt_content is None:
+            system_prompt_content = [{"context": system_prompt}]
+
+        # TODO: Handle caching blocks in openai
+        # TODO Create tracking ticket
+        formatted_messages: list[dict[str, Any]] = [
+            {
+                "role": "system",
+                "content": [
+                    cls.format_request_message_content(content)
+                    for content in system_prompt_content
+                    if "text" in content
+                ],
+            }
+        ]
 
         for message in messages:
             contents = message["content"]
@@ -247,9 +268,12 @@ def format_request_messages(cls, messages: Messages, system_prompt: Optional[str
     def format_request(
         self,
         messages: Messages,
-        tool_specs: Optional[list[ToolSpec]] = None,
-        system_prompt: Optional[str] = None,
+        tool_specs: list[ToolSpec] | None = None,
+        system_prompt: str | None = None,
         tool_choice: ToolChoice | None = None,
+        *,
+        system_prompt_content: list[SystemContentBlock] | None = None,
+        **kwargs: Any,
     ) -> dict[str, Any]:
         """Format an OpenAI compatible chat streaming request.
 
@@ -267,7 +291,7 @@ def format_request(
             format.
         """
         return {
-            "messages": self.format_request_messages(messages, system_prompt),
+            "messages": self.format_request_messages(messages, system_prompt, system_prompt_content=system_prompt_content),
             "model": self.config["model_id"],
             "stream": True,
             "stream_options": {"include_usage": True},
@@ -286,7 +310,8 @@ def format_request(
             **cast(dict[str, Any], self.config.get("params", {})),
         }
 
-    def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
+
+    def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
         """Format an OpenAI response event into a standardized message chunk.
 
         Args:

From 264e137d81481a3421691ad7ca14b6005401f94f Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Wed, 5 Nov 2025 14:02:19 -0500
Subject: [PATCH 2/4] feat(models): allow SystemContentBlocks in LiteLLMModel

---
 src/strands/models/litellm.py            | 111 +++++++++++------------
 src/strands/models/openai.py             |  84 ++++++++++++-----
 src/strands/models/sagemaker.py          |   8 +-
 tests/strands/agent/test_agent.py        |   4 +-
 tests/strands/models/test_litellm.py     |  66 ++++++++++++++
 tests/strands/models/test_openai.py      |  42 +++++++++
 tests_integ/models/test_model_litellm.py |  22 +++++
 tests_integ/models/test_model_openai.py  |  26 ++++++
 8 files changed, 278 insertions(+), 85 deletions(-)

diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index 1e82ab5c5..52a47057c 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -15,8 +15,9 @@
 from ..tools import convert_pydantic_to_tool_spec
 from ..types.content import ContentBlock, Messages, SystemContentBlock
+from ..types.event_loop import Usage
 from ..types.exceptions import ContextWindowOverflowException
-from ..types.streaming import StreamEvent
+from ..types.streaming import MetadataEvent, StreamEvent
 from ..types.tools import ToolChoice, ToolSpec
 from ._validation import validate_config_keys
 from .openai import OpenAIModel
@@ -81,11 +82,12 @@ def get_config(self) -> LiteLLMConfig:
 
     @override
     @classmethod
-    def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
+    def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> dict[str, Any]:
         """Format a LiteLLM content block.
 
         Args:
             content: Message content.
+            **kwargs: Additional keyword arguments for future extensibility.
 
         Returns:
             LiteLLM formatted content block.
@@ -133,33 +135,28 @@ def _stream_switch_content(self, data_type: str, prev_data_type: str | None) ->
 
     @override
     @classmethod
-    def format_request_messages(
+    def _format_system_messages(
         cls,
-        messages: Messages,
         system_prompt: Optional[str] = None,
         *,
         system_prompt_content: Optional[list[SystemContentBlock]] = None,
-        **kwargs: Any,
     ) -> list[dict[str, Any]]:
-        """Format a LiteLLM compatible messages array with cache point support.
+        """Format system messages for LiteLLM with cache point support.
 
         Args:
-            messages: List of message objects to be processed by the model.
-            system_prompt: System prompt to provide context to the model (for legacy compatibility).
+            system_prompt: System prompt to provide context to the model.
             system_prompt_content: System prompt content blocks to provide context to the model.
-            **kwargs: Additional keyword arguments for future extensibility.
 
         Returns:
-            A LiteLLM compatible messages array.
+            List of formatted system messages.
""" - formatted_messages: list[dict[str, Any]] = [] # Handle backward compatibility: if system_prompt is provided but system_prompt_content is None if system_prompt and system_prompt_content is None: - system_prompt_content = [{"context": system_prompt}] + system_prompt_content = [{"text": system_prompt}] # For LiteLLM with Bedrock, we can support cache points - system_content = [] - for block in system_prompt_content: + system_content: list[dict[str, Any]] = [] + for block in system_prompt_content or []: if "text" in block: system_content.append({"type": "text", "text": block["text"]}) elif "cachePoint" in block and block["cachePoint"].get("type") == "default": @@ -169,39 +166,36 @@ def format_request_messages( system_content[-1]["cache_control"] = {"type": "ephemeral"} # Create single system message with content array - if system_content: - formatted_messages.append({"role": "system", "content": system_content}) - - # Process regular messages - for message in messages: - contents = message["content"] - - formatted_contents = [ - cls.format_request_message_content(content) - for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse"]) - ] - formatted_tool_calls = [ - cls.format_request_message_tool_call(content["toolUse"]) for content in contents if "toolUse" in content - ] - formatted_tool_messages = [ - cls.format_request_tool_message(content["toolResult"]) - for content in contents - if "toolResult" in content - ] - - formatted_message = { - "role": message["role"], - "content": formatted_contents, - **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}), - } - formatted_messages.append(formatted_message) - formatted_messages.extend(formatted_tool_messages) + return [{"role": "system", "content": system_content}] if system_content else [] + + @override + @classmethod + def format_request_messages( + cls, + messages: Messages, + system_prompt: Optional[str] = None, + *, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + **kwargs: Any, + ) -> list[dict[str, Any]]: + """Format a LiteLLM compatible messages array with cache point support. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model (for legacy compatibility). + system_prompt_content: System prompt content blocks to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + A LiteLLM compatible messages array. + """ + formatted_messages = cls._format_system_messages(system_prompt, system_prompt_content=system_prompt_content) + formatted_messages.extend(cls._format_regular_messages(messages)) return [message for message in formatted_messages if message["content"] or "tool_calls" in message] @override - def format_chunk(self, event: dict[str, Any]) -> StreamEvent: + def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: """Format a LiteLLM response event into a standardized message chunk. This method overrides OpenAI's format_chunk to handle the metadata case @@ -209,6 +203,7 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: Args: event: A response event from the LiteLLM model. + **kwargs: Additional keyword arguments for future extensibility. Returns: The formatted chunk. 
@@ -218,7 +213,7 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
         """
         # Handle metadata case with prompt caching support
         if event["chunk_type"] == "metadata":
-            usage_data = {
+            usage_data: Usage = {
                 "inputTokens": event["data"].prompt_tokens,
                 "outputTokens": event["data"].completion_tokens,
                 "totalTokens": event["data"].total_tokens,
@@ -226,22 +221,21 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
 
             # Only LiteLLM over Anthropic supports cache write tokens
             # Waiting until a more general approach is available to set cacheWriteInputTokens
-
-            tokens_details = getattr(event["data"], "prompt_tokens_details", None)
-            if tokens_details and getattr(tokens_details, "cached_tokens", None):
-                usage_data["cacheReadInputTokens"] = event["data"].prompt_tokens_details.cached_tokens
-
+            if tokens_details := getattr(event["data"], "prompt_tokens_details", None):
+                if cached := getattr(tokens_details, "cached_tokens", None):
+                    usage_data["cacheReadInputTokens"] = cached
+                if creation := getattr(tokens_details, "cache_creation_tokens", None):
+                    usage_data["cacheWriteInputTokens"] = creation
 
-            return {
-                "metadata": {
-                    "usage": usage_data,
-                    "metrics": {
+            return StreamEvent(
+                metadata=MetadataEvent(
+                    metrics={
                         "latencyMs": 0,  # TODO
                     },
-                },
-            }
-
+                    usage=usage_data,
+                )
+            )
         # For all other cases, use the parent implementation
         return super().format_chunk(event)
 
@@ -263,13 +257,16 @@ async def stream(
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
             tool_choice: Selection strategy for tool invocation.
+            system_prompt_content: System prompt content blocks to provide context to the model.
             **kwargs: Additional keyword arguments for future extensibility.
 
         Yields:
             Formatted message chunks from the model.
         """
         logger.debug("formatting request")
-        request = self.format_request(messages, tool_specs, system_prompt, tool_choice, system_prompt_content=system_prompt_content)
+        request = self.format_request(
+            messages, tool_specs, system_prompt, tool_choice, system_prompt_content=system_prompt_content
+        )
         logger.debug("request=<%s>", request)
 
         logger.debug("invoking model")
diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py
index 24d42e84b..f34748369 100644
--- a/src/strands/models/openai.py
+++ b/src/strands/models/openai.py
@@ -14,7 +14,7 @@
 from pydantic import BaseModel
 from typing_extensions import Unpack, override
 
-from ..types.content import ContentBlock, Messages
+from ..types.content import ContentBlock, Messages, SystemContentBlock
 from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
 from ..types.streaming import StreamEvent
 from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
@@ -94,6 +94,7 @@ def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) ->
 
         Args:
             content: Message content.
+            **kwargs: Additional keyword arguments for future extensibility.
 
         Returns:
             OpenAI compatible content block.
@@ -136,6 +137,7 @@ def format_request_message_tool_call(cls, tool_use: ToolUse, **kwargs: Any) -> d
 
         Args:
             tool_use: Tool use requested by the model.
+            **kwargs: Additional keyword arguments for future extensibility.
 
         Returns:
             OpenAI compatible tool call.
@@ -155,6 +157,7 @@ def format_request_tool_message(cls, tool_result: ToolResult, **kwargs: Any) ->
 
         Args:
             tool_result: Tool result collected from a tool execution.
+            **kwargs: Additional keyword arguments for future extensibility.
 
         Returns:
             OpenAI compatible tool message.
@@ -198,40 +201,44 @@ def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str
         return {"tool_choice": "auto"}
 
     @classmethod
-    def format_request_messages(
-        cls,
-        messages: Messages,
-        system_prompt: Optional[str] = None,
+    def _format_system_messages(
+        cls,
+        system_prompt: Optional[str] = None,
         *,
         system_prompt_content: Optional[list[SystemContentBlock]] = None,
-        **kwargs
     ) -> list[dict[str, Any]]:
-        """Format an OpenAI compatible messages array.
+        """Format system messages for OpenAI-compatible providers.
 
         Args:
-            messages: List of message objects to be processed by the model.
             system_prompt: System prompt to provide context to the model.
+            system_prompt_content: System prompt content blocks to provide context to the model.
 
         Returns:
-            An OpenAI compatible messages array.
+            List of formatted system messages.
         """
         # Handle backward compatibility: if system_prompt is provided but system_prompt_content is None
         if system_prompt and system_prompt_content is None:
-            system_prompt_content = [{"context": system_prompt}]
-
-        # TODO: Handle caching blocks in openai
-        # TODO Create tracking ticket
-        formatted_messages: list[dict[str, Any]] = [
-            {
-                "role": "system",
-                "content": [
-                    cls.format_request_message_content(content)
-                    for content in system_prompt_content
-                    if "text" in content
-                ],
-            }
+            system_prompt_content = [{"text": system_prompt}]
+
+        # TODO: Handle caching blocks https://github.com/strands-agents/sdk-python/issues/1140
+        return [
+            {"role": "system", "content": content["text"]}
+            for content in system_prompt_content or []
+            if "text" in content
         ]
 
+    @classmethod
+    def _format_regular_messages(cls, messages: Messages) -> list[dict[str, Any]]:
+        """Format regular messages for OpenAI-compatible providers.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+
+        Returns:
+            List of formatted messages.
+        """
+        formatted_messages = []
+
         for message in messages:
             contents = message["content"]
 
@@ -263,6 +270,31 @@ def format_request_messages(
             formatted_messages.append(formatted_message)
             formatted_messages.extend(formatted_tool_messages)
 
+        return formatted_messages
+
+    @classmethod
+    def format_request_messages(
+        cls,
+        messages: Messages,
+        system_prompt: Optional[str] = None,
+        *,
+        system_prompt_content: Optional[list[SystemContentBlock]] = None,
+        **kwargs: Any,
+    ) -> list[dict[str, Any]]:
+        """Format an OpenAI compatible messages array.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            system_prompt: System prompt to provide context to the model.
+            system_prompt_content: System prompt content blocks to provide context to the model.
+            **kwargs: Additional keyword arguments for future extensibility.
+
+        Returns:
+            An OpenAI compatible messages array.
+        """
+        formatted_messages = cls._format_system_messages(system_prompt, system_prompt_content=system_prompt_content)
+        formatted_messages.extend(cls._format_regular_messages(messages))
+
         return [message for message in formatted_messages if message["content"] or "tool_calls" in message]
 
     def format_request(
@@ -282,6 +314,8 @@ def format_request(
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
             tool_choice: Selection strategy for tool invocation.
+            system_prompt_content: System prompt content blocks to provide context to the model.
+            **kwargs: Additional keyword arguments for future extensibility.
 
         Returns:
             An OpenAI compatible chat streaming request.
@@ -291,7 +325,9 @@ def format_request(
             format.
         """
         return {
-            "messages": self.format_request_messages(messages, system_prompt, system_prompt_content=system_prompt_content),
+            "messages": self.format_request_messages(
+                messages, system_prompt, system_prompt_content=system_prompt_content
+            ),
             "model": self.config["model_id"],
             "stream": True,
             "stream_options": {"include_usage": True},
@@ -310,12 +346,12 @@ def format_request(
             **cast(dict[str, Any], self.config.get("params", {})),
         }
 
-
     def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
         """Format an OpenAI response event into a standardized message chunk.
 
         Args:
             event: A response event from the OpenAI compatible model.
+            **kwargs: Additional keyword arguments for future extensibility.
 
         Returns:
             The formatted chunk.
diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py
index 25b3ca7ce..7f8b8ff51 100644
--- a/src/strands/models/sagemaker.py
+++ b/src/strands/models/sagemaker.py
@@ -202,6 +202,7 @@ def format_request(
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
         tool_choice: ToolChoice | None = None,
+        **kwargs: Any,
     ) -> dict[str, Any]:
         """Format an Amazon SageMaker chat streaming request.
 
@@ -211,6 +212,7 @@ def format_request(
             system_prompt: System prompt to provide context to the model.
             tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for
                 interface consistency but is currently ignored for this model provider.**
+            **kwargs: Additional keyword arguments for future extensibility.
 
         Returns:
             An Amazon SageMaker chat streaming request.
@@ -501,11 +503,12 @@ async def stream(
 
     @override
     @classmethod
-    def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]:
+    def format_request_tool_message(cls, tool_result: ToolResult, **kwargs: Any) -> dict[str, Any]:
         """Format a SageMaker compatible tool message.
 
         Args:
             tool_result: Tool result collected from a tool execution.
+            **kwargs: Additional keyword arguments for future extensibility.
 
         Returns:
             SageMaker compatible tool message with content as a string.
@@ -531,11 +534,12 @@ def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]:
 
     @override
     @classmethod
-    def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
+    def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> dict[str, Any]:
         """Format a content block.
 
         Args:
             content: Message content.
+            **kwargs: Additional keyword arguments for future extensibility.
 
         Returns:
             Formatted content block.
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index 3a0bc2dfb..550422cfe 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -2240,8 +2240,8 @@ def test_agent_backwards_compatibility_single_text_block():
 
     # Should extract text for backwards compatibility
     assert agent.system_prompt == text
-
-
+
+
 @pytest.mark.parametrize(
     "content, expected",
     [
diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py
index 57a8593cd..f56438cf5 100644
--- a/tests/strands/models/test_litellm.py
+++ b/tests/strands/models/test_litellm.py
@@ -192,6 +192,8 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator,
     mock_event_7 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_7)])
     mock_event_8 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_8)])
     mock_event_9 = unittest.mock.Mock()
+    mock_event_9.usage.prompt_tokens_details.cached_tokens = 10
+    mock_event_9.usage.prompt_tokens_details.cache_creation_tokens = 10
 
     litellm_acompletion.side_effect = unittest.mock.AsyncMock(
         return_value=agenerator(
@@ -252,6 +254,8 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator,
         {
             "metadata": {
                 "usage": {
+                    "cacheReadInputTokens": mock_event_9.usage.prompt_tokens_details.cached_tokens,
+                    "cacheWriteInputTokens": mock_event_9.usage.prompt_tokens_details.cache_creation_tokens,
                     "inputTokens": mock_event_9.usage.prompt_tokens,
                     "outputTokens": mock_event_9.usage.completion_tokens,
                     "totalTokens": mock_event_9.usage.total_tokens,
@@ -402,3 +406,65 @@ async def test_context_window_maps_to_typed_exception(litellm_acompletion, model
     with pytest.raises(ContextWindowOverflowException):
         async for _ in model.stream([{"role": "user", "content": [{"text": "x"}]}]):
             pass
+
+
+def test_format_request_messages_with_system_prompt_content():
+    """Test format_request_messages with system_prompt_content parameter."""
+    messages = [{"role": "user", "content": [{"text": "Hello"}]}]
+    system_prompt_content = [{"text": "You are a helpful assistant."}, {"cachePoint": {"type": "default"}}]
+
+    result = LiteLLMModel.format_request_messages(messages, system_prompt_content=system_prompt_content)
+
+    expected = [
+        {
+            "role": "system",
+            "content": [
+                {"type": "text", "text": "You are a helpful assistant.", "cache_control": {"type": "ephemeral"}}
+            ],
+        },
+        {"role": "user", "content": [{"text": "Hello", "type": "text"}]},
+    ]
+
+    assert result == expected
+
+
+def test_format_request_messages_backward_compatibility_system_prompt():
+    """Test that system_prompt is converted to system_prompt_content when system_prompt_content is None."""
+    messages = [{"role": "user", "content": [{"text": "Hello"}]}]
+    system_prompt = "You are a helpful assistant."
+
+    result = LiteLLMModel.format_request_messages(messages, system_prompt=system_prompt)
+
+    expected = [
+        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+        {"role": "user", "content": [{"text": "Hello", "type": "text"}]},
+    ]
+
+    assert result == expected
+
+
+def test_format_request_messages_cache_point_support():
+    """Test that cache points are properly applied to preceding content blocks."""
+    messages = [{"role": "user", "content": [{"text": "Hello"}]}]
+    system_prompt_content = [
+        {"text": "First instruction."},
+        {"text": "Second instruction."},
+        {"cachePoint": {"type": "default"}},
+        {"text": "Third instruction."},
+    ]
+
+    result = LiteLLMModel.format_request_messages(messages, system_prompt_content=system_prompt_content)
+
+    expected = [
+        {
+            "role": "system",
+            "content": [
+                {"type": "text", "text": "First instruction."},
+                {"type": "text", "text": "Second instruction.", "cache_control": {"type": "ephemeral"}},
+                {"type": "text", "text": "Third instruction."},
+            ],
+        },
+        {"role": "user", "content": [{"text": "Hello", "type": "text"}]},
+    ]
+
+    assert result == expected
diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py
index cc30b7420..0de0c4ebc 100644
--- a/tests/strands/models/test_openai.py
+++ b/tests/strands/models/test_openai.py
@@ -944,3 +944,45 @@ async def test_structured_output_rate_limit_as_throttle(openai_client, model, me
 
     # Verify the exception message contains the original error
     assert "tokens per min" in str(exc_info.value)
     assert exc_info.value.__cause__ == mock_error
+
+
+def test_format_request_messages_with_system_prompt_content():
+    """Test format_request_messages with system_prompt_content parameter."""
+    messages = [{"role": "user", "content": [{"text": "Hello"}]}]
+    system_prompt_content = [{"text": "You are a helpful assistant."}]
+
+    result = OpenAIModel.format_request_messages(messages, system_prompt_content=system_prompt_content)
+
+    expected = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": [{"text": "Hello", "type": "text"}]},
+    ]
+
+    assert result == expected
+
+
+def test_format_request_messages_with_none_system_prompt_content():
+    """Test format_request_messages when no system prompt is provided."""
+    messages = [{"role": "user", "content": [{"text": "Hello"}]}]
+
+    result = OpenAIModel.format_request_messages(messages)
+
+    expected = [{"role": "user", "content": [{"text": "Hello", "type": "text"}]}]
+
+    assert result == expected
+
+
+def test_format_request_messages_drops_cache_points():
+    """Test that cache points are dropped in OpenAI format_request_messages."""
+    messages = [{"role": "user", "content": [{"text": "Hello"}]}]
+    system_prompt_content = [{"text": "You are a helpful assistant."}, {"cachePoint": {"type": "default"}}]
+
+    result = OpenAIModel.format_request_messages(messages, system_prompt_content=system_prompt_content)
+
+    # Cache points should be dropped, only text content included
+    expected = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": [{"text": "Hello", "type": "text"}]},
+    ]
+
+    assert result == expected
diff --git a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py
index b348c29f4..f177c08a4 100644
--- a/tests_integ/models/test_model_litellm.py
+++ b/tests_integ/models/test_model_litellm.py
@@ -211,3 +211,25 @@ def test_structured_output_unsupported_model(model, nested_weather):
     # Verify that the tool method was called and schema method was not
     mock_tool.assert_called_once()
     mock_schema.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_cache_read_tokens_multi_turn(model):
+    """Integration test for cache read tokens in multi-turn conversation."""
+    from strands.types.content import SystemContentBlock
+
+    system_prompt_content: list[SystemContentBlock] = [
+        # Caching only works when prompts are large
+        {"text": "You are a helpful assistant. Always be concise." * 200},
+        {"cachePoint": {"type": "default"}},
+    ]
+
+    agent = Agent(model=model, system_prompt=system_prompt_content)
+
+    # First turn - establishes cache
+    agent("Hello, what's 2+2?")
+    result = agent("What's 3+3?")
+    result.metrics.accumulated_usage["cacheReadInputTokens"]
+
+    assert result.metrics.accumulated_usage["cacheReadInputTokens"] > 0
+    assert result.metrics.accumulated_usage["cacheWriteInputTokens"] > 0
diff --git a/tests_integ/models/test_model_openai.py b/tests_integ/models/test_model_openai.py
index 7beb3013c..6c65d0240 100644
--- a/tests_integ/models/test_model_openai.py
+++ b/tests_integ/models/test_model_openai.py
@@ -231,3 +231,29 @@ def test_content_blocks_handling(model):
 
     result = agent(content)
     assert "4" in result.message["content"][0]["text"]
+
+
+def test_system_prompt_content_integration(model):
+    """Integration test for system_prompt_content parameter."""
+    from strands.types.content import SystemContentBlock
+
+    system_prompt_content: list[SystemContentBlock] = [
+        {"text": "You are a helpful assistant that always responds with 'SYSTEM_TEST_RESPONSE'."}
+    ]
+
+    agent = Agent(model=model, system_prompt_content=system_prompt_content)
+    result = agent("Hello")
+
+    # The response should contain our specific system prompt instruction
+    assert "SYSTEM_TEST_RESPONSE" in result.message["content"][0]["text"]
+
+
+def test_system_prompt_backward_compatibility_integration(model):
+    """Integration test for backward compatibility with system_prompt parameter."""
+    system_prompt = "You are a helpful assistant that always responds with 'BACKWARD_COMPAT_TEST'."
+
+    agent = Agent(model=model, system_prompt=system_prompt)
+    result = agent("Hello")
+
+    # The response should contain our specific system prompt instruction
+    assert "BACKWARD_COMPAT_TEST" in result.message["content"][0]["text"]

From 8b4b7ad6f4f896ac1c3af64f727c316c4edc2298 Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Wed, 5 Nov 2025 14:08:23 -0500
Subject: [PATCH 3/4] cleanup pr

---
 src/strands/models/litellm.py     | 3 +--
 tests/strands/agent/test_agent.py | 4 ++--
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index 52a47057c..fa03e89a6 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -154,7 +154,6 @@ def _format_system_messages(
         if system_prompt and system_prompt_content is None:
             system_prompt_content = [{"text": system_prompt}]
 
-        # For LiteLLM with Bedrock, we can support cache points
         system_content: list[dict[str, Any]] = []
         for block in system_prompt_content or []:
             if "text" in block:
@@ -165,7 +164,7 @@ def _format_system_messages(
                 if system_content:
                     system_content[-1]["cache_control"] = {"type": "ephemeral"}
 
-        # Create single system message with content array
+        # Create single system message with content array rather than multiple system messages
        return [{"role": "system", "content": system_content}] if system_content else []
 
     @override
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index 550422cfe..3a0bc2dfb 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -2240,8 +2240,8 @@ def test_agent_backwards_compatibility_single_text_block():
 
     # Should extract text for backwards compatibility
     assert agent.system_prompt == text
-
-
+
+
 @pytest.mark.parametrize(
     "content, expected",
     [

From 5daebc5ab68699d26eecefe34a716cbb641b9789 Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Wed, 5 Nov 2025 14:16:40 -0500
Subject: [PATCH 4/4] fix test

---
 tests_integ/models/test_model_openai.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests_integ/models/test_model_openai.py b/tests_integ/models/test_model_openai.py
index 6c65d0240..feb591d1a 100644
--- a/tests_integ/models/test_model_openai.py
+++ b/tests_integ/models/test_model_openai.py
@@ -241,7 +241,7 @@ def test_system_prompt_content_integration(model):
         {"text": "You are a helpful assistant that always responds with 'SYSTEM_TEST_RESPONSE'."}
     ]
 
-    agent = Agent(model=model, system_prompt_content=system_prompt_content)
+    agent = Agent(model=model, system_prompt=system_prompt_content)
     result = agent("Hello")
 
     # The response should contain our specific system prompt instruction
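
For reviewers who want to exercise the series end to end, here is a minimal usage sketch of the
system_prompt_content cache-point flow these patches enable. The Bedrock Claude model id, the
credential setup, and the 200x-repeated prompt are illustrative assumptions, not part of the
patches themselves; provider caching thresholds vary.

    # Sketch only: assumes valid Bedrock credentials and a Claude model id served via LiteLLM.
    from strands import Agent
    from strands.models.litellm import LiteLLMModel
    from strands.types.content import SystemContentBlock

    model = LiteLLMModel(model_id="bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0")

    system_prompt_content: list[SystemContentBlock] = [
        # Cache points only pay off for large, stable system prompts.
        {"text": "You are a helpful assistant. Always be concise." * 200},
        {"cachePoint": {"type": "default"}},
    ]

    agent = Agent(model=model, system_prompt=system_prompt_content)

    agent("Hello, what's 2+2?")      # first turn should write the prompt cache
    result = agent("What's 3+3?")    # later turns should read from it

    # Cache usage surfaces through the metadata path added in PATCH 1/2.
    print(result.metrics.accumulated_usage.get("cacheReadInputTokens"))
    print(result.metrics.accumulated_usage.get("cacheWriteInputTokens"))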