diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 7a8c0ae03..fa03e89a6 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -14,9 +14,10 @@ from typing_extensions import Unpack, override from ..tools import convert_pydantic_to_tool_spec -from ..types.content import ContentBlock, Messages +from ..types.content import ContentBlock, Messages, SystemContentBlock +from ..types.event_loop import Usage from ..types.exceptions import ContextWindowOverflowException -from ..types.streaming import StreamEvent +from ..types.streaming import MetadataEvent, StreamEvent from ..types.tools import ToolChoice, ToolSpec from ._validation import validate_config_keys from .openai import OpenAIModel @@ -81,11 +82,12 @@ def get_config(self) -> LiteLLMConfig: @override @classmethod - def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> dict[str, Any]: """Format a LiteLLM content block. Args: content: Message content. + **kwargs: Additional keyword arguments for future extensibility. Returns: LiteLLM formatted content block. @@ -131,6 +133,111 @@ def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> return chunks, data_type + @override + @classmethod + def _format_system_messages( + cls, + system_prompt: Optional[str] = None, + *, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + ) -> list[dict[str, Any]]: + """Format system messages for LiteLLM with cache point support. + + Args: + system_prompt: System prompt to provide context to the model. + system_prompt_content: System prompt content blocks to provide context to the model. + + Returns: + List of formatted system messages. + """ + # Handle backward compatibility: if system_prompt is provided but system_prompt_content is None + if system_prompt and system_prompt_content is None: + system_prompt_content = [{"text": system_prompt}] + + system_content: list[dict[str, Any]] = [] + for block in system_prompt_content or []: + if "text" in block: + system_content.append({"type": "text", "text": block["text"]}) + elif "cachePoint" in block and block["cachePoint"].get("type") == "default": + # Apply cache control to the immediately preceding content block + # for LiteLLM/Anthropic compatibility + if system_content: + system_content[-1]["cache_control"] = {"type": "ephemeral"} + + # Create single system message with content array rather than mulitple system messages + return [{"role": "system", "content": system_content}] if system_content else [] + + @override + @classmethod + def format_request_messages( + cls, + messages: Messages, + system_prompt: Optional[str] = None, + *, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + **kwargs: Any, + ) -> list[dict[str, Any]]: + """Format a LiteLLM compatible messages array with cache point support. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model (for legacy compatibility). + system_prompt_content: System prompt content blocks to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + A LiteLLM compatible messages array. + """ + formatted_messages = cls._format_system_messages(system_prompt, system_prompt_content=system_prompt_content) + formatted_messages.extend(cls._format_regular_messages(messages)) + + return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + + @override + def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: + """Format a LiteLLM response event into a standardized message chunk. + + This method overrides OpenAI's format_chunk to handle the metadata case + with prompt caching support. All other chunk types use the parent implementation. + + Args: + event: A response event from the LiteLLM model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. + """ + # Handle metadata case with prompt caching support + if event["chunk_type"] == "metadata": + usage_data: Usage = { + "inputTokens": event["data"].prompt_tokens, + "outputTokens": event["data"].completion_tokens, + "totalTokens": event["data"].total_tokens, + } + + # Only LiteLLM over Anthropic supports cache cache write tokens + # Waiting until a more general approach is available to set cacheWriteInputTokens + + if tokens_details := getattr(event["data"], "prompt_tokens_details", None): + if cached := getattr(tokens_details, "cached_tokens", None): + usage_data["cacheReadInputTokens"] = cached + if creation := getattr(tokens_details, "cache_creation_tokens", None): + usage_data["cacheWriteInputTokens"] = creation + + return StreamEvent( + metadata=MetadataEvent( + metrics={ + "latencyMs": 0, # TODO + }, + usage=usage_data, + ) + ) + # For all other cases, use the parent implementation + return super().format_chunk(event) + @override async def stream( self, @@ -139,6 +246,7 @@ async def stream( system_prompt: Optional[str] = None, *, tool_choice: ToolChoice | None = None, + system_prompt_content: Optional[list[SystemContentBlock]] = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the LiteLLM model. @@ -148,13 +256,16 @@ async def stream( tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. + system_prompt_content: System prompt content blocks to provide context to the model. **kwargs: Additional keyword arguments for future extensibility. Yields: Formatted message chunks from the model. """ logger.debug("formatting request") - request = self.format_request(messages, tool_specs, system_prompt, tool_choice) + request = self.format_request( + messages, tool_specs, system_prompt, tool_choice, system_prompt_content=system_prompt_content + ) logger.debug("request=<%s>", request) logger.debug("invoking model") diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 1efe641e6..f34748369 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -14,7 +14,7 @@ from pydantic import BaseModel from typing_extensions import Unpack, override -from ..types.content import ContentBlock, Messages +from ..types.content import ContentBlock, Messages, SystemContentBlock from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse @@ -89,11 +89,12 @@ def get_config(self) -> OpenAIConfig: return cast(OpenAIModel.OpenAIConfig, self.config) @classmethod - def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> dict[str, Any]: """Format an OpenAI compatible content block. Args: content: Message content. + **kwargs: Additional keyword arguments for future extensibility. Returns: OpenAI compatible content block. @@ -131,11 +132,12 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") @classmethod - def format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]: + def format_request_message_tool_call(cls, tool_use: ToolUse, **kwargs: Any) -> dict[str, Any]: """Format an OpenAI compatible tool call. Args: tool_use: Tool use requested by the model. + **kwargs: Additional keyword arguments for future extensibility. Returns: OpenAI compatible tool call. @@ -150,11 +152,12 @@ def format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]: } @classmethod - def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: + def format_request_tool_message(cls, tool_result: ToolResult, **kwargs: Any) -> dict[str, Any]: """Format an OpenAI compatible tool message. Args: tool_result: Tool result collected from a tool execution. + **kwargs: Additional keyword arguments for future extensibility. Returns: OpenAI compatible tool message. @@ -198,18 +201,43 @@ def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str return {"tool_choice": "auto"} @classmethod - def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: - """Format an OpenAI compatible messages array. + def _format_system_messages( + cls, + system_prompt: Optional[str] = None, + *, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + ) -> list[dict[str, Any]]: + """Format system messages for OpenAI-compatible providers. Args: - messages: List of message objects to be processed by the model. system_prompt: System prompt to provide context to the model. + system_prompt_content: System prompt content blocks to provide context to the model. Returns: - An OpenAI compatible messages array. + List of formatted system messages. + """ + # Handle backward compatibility: if system_prompt is provided but system_prompt_content is None + if system_prompt and system_prompt_content is None: + system_prompt_content = [{"text": system_prompt}] + + # TODO: Handle caching blocks https://github.com/strands-agents/sdk-python/issues/1140 + return [ + {"role": "system", "content": content["text"]} + for content in system_prompt_content or [] + if "text" in content + ] + + @classmethod + def _format_regular_messages(cls, messages: Messages) -> list[dict[str, Any]]: + """Format regular messages for OpenAI-compatible providers. + + Args: + messages: List of message objects to be processed by the model. + + Returns: + List of formatted messages. """ - formatted_messages: list[dict[str, Any]] - formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] + formatted_messages = [] for message in messages: contents = message["content"] @@ -242,14 +270,42 @@ def format_request_messages(cls, messages: Messages, system_prompt: Optional[str formatted_messages.append(formatted_message) formatted_messages.extend(formatted_tool_messages) + return formatted_messages + + @classmethod + def format_request_messages( + cls, + messages: Messages, + system_prompt: Optional[str] = None, + *, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + **kwargs: Any, + ) -> list[dict[str, Any]]: + """Format an OpenAI compatible messages array. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. + system_prompt_content: System prompt content blocks to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + An OpenAI compatible messages array. + """ + formatted_messages = cls._format_system_messages(system_prompt, system_prompt_content=system_prompt_content) + formatted_messages.extend(cls._format_regular_messages(messages)) + return [message for message in formatted_messages if message["content"] or "tool_calls" in message] def format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, tool_choice: ToolChoice | None = None, + *, + system_prompt_content: list[SystemContentBlock] | None = None, + **kwargs: Any, ) -> dict[str, Any]: """Format an OpenAI compatible chat streaming request. @@ -258,6 +314,8 @@ def format_request( tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. + system_prompt_content: System prompt content blocks to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. Returns: An OpenAI compatible chat streaming request. @@ -267,7 +325,9 @@ def format_request( format. """ return { - "messages": self.format_request_messages(messages, system_prompt), + "messages": self.format_request_messages( + messages, system_prompt, system_prompt_content=system_prompt_content + ), "model": self.config["model_id"], "stream": True, "stream_options": {"include_usage": True}, @@ -286,11 +346,12 @@ def format_request( **cast(dict[str, Any], self.config.get("params", {})), } - def format_chunk(self, event: dict[str, Any]) -> StreamEvent: + def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: """Format an OpenAI response event into a standardized message chunk. Args: event: A response event from the OpenAI compatible model. + **kwargs: Additional keyword arguments for future extensibility. Returns: The formatted chunk. diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index 25b3ca7ce..7f8b8ff51 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -202,6 +202,7 @@ def format_request( tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, tool_choice: ToolChoice | None = None, + **kwargs: Any, ) -> dict[str, Any]: """Format an Amazon SageMaker chat streaming request. @@ -211,6 +212,7 @@ def format_request( system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for interface consistency but is currently ignored for this model provider.** + **kwargs: Additional keyword arguments for future extensibility. Returns: An Amazon SageMaker chat streaming request. @@ -501,11 +503,12 @@ async def stream( @override @classmethod - def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: + def format_request_tool_message(cls, tool_result: ToolResult, **kwargs: Any) -> dict[str, Any]: """Format a SageMaker compatible tool message. Args: tool_result: Tool result collected from a tool execution. + **kwargs: Additional keyword arguments for future extensibility. Returns: SageMaker compatible tool message with content as a string. @@ -531,11 +534,12 @@ def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: @override @classmethod - def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> dict[str, Any]: """Format a content block. Args: content: Message content. + **kwargs: Additional keyword arguments for future extensibility. Returns: Formatted content block. diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 57a8593cd..f56438cf5 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -192,6 +192,8 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, mock_event_7 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_7)]) mock_event_8 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_8)]) mock_event_9 = unittest.mock.Mock() + mock_event_9.usage.prompt_tokens_details.cached_tokens = 10 + mock_event_9.usage.prompt_tokens_details.cache_creation_tokens = 10 litellm_acompletion.side_effect = unittest.mock.AsyncMock( return_value=agenerator( @@ -252,6 +254,8 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, { "metadata": { "usage": { + "cacheReadInputTokens": mock_event_9.usage.prompt_tokens_details.cached_tokens, + "cacheWriteInputTokens": mock_event_9.usage.prompt_tokens_details.cache_creation_tokens, "inputTokens": mock_event_9.usage.prompt_tokens, "outputTokens": mock_event_9.usage.completion_tokens, "totalTokens": mock_event_9.usage.total_tokens, @@ -402,3 +406,65 @@ async def test_context_window_maps_to_typed_exception(litellm_acompletion, model with pytest.raises(ContextWindowOverflowException): async for _ in model.stream([{"role": "user", "content": [{"text": "x"}]}]): pass + + +def test_format_request_messages_with_system_prompt_content(): + """Test format_request_messages with system_prompt_content parameter.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt_content = [{"text": "You are a helpful assistant."}, {"cachePoint": {"type": "default"}}] + + result = LiteLLMModel.format_request_messages(messages, system_prompt_content=system_prompt_content) + + expected = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant.", "cache_control": {"type": "ephemeral"}} + ], + }, + {"role": "user", "content": [{"text": "Hello", "type": "text"}]}, + ] + + assert result == expected + + +def test_format_request_messages_backward_compatibility_system_prompt(): + """Test that system_prompt is converted to system_prompt_content when system_prompt_content is None.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt = "You are a helpful assistant." + + result = LiteLLMModel.format_request_messages(messages, system_prompt=system_prompt) + + expected = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + {"role": "user", "content": [{"text": "Hello", "type": "text"}]}, + ] + + assert result == expected + + +def test_format_request_messages_cache_point_support(): + """Test that cache points are properly applied to preceding content blocks.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt_content = [ + {"text": "First instruction."}, + {"text": "Second instruction."}, + {"cachePoint": {"type": "default"}}, + {"text": "Third instruction."}, + ] + + result = LiteLLMModel.format_request_messages(messages, system_prompt_content=system_prompt_content) + + expected = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "First instruction."}, + {"type": "text", "text": "Second instruction.", "cache_control": {"type": "ephemeral"}}, + {"type": "text", "text": "Third instruction."}, + ], + }, + {"role": "user", "content": [{"text": "Hello", "type": "text"}]}, + ] + + assert result == expected diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index cc30b7420..0de0c4ebc 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -944,3 +944,45 @@ async def test_structured_output_rate_limit_as_throttle(openai_client, model, me # Verify the exception message contains the original error assert "tokens per min" in str(exc_info.value) assert exc_info.value.__cause__ == mock_error + + +def test_format_request_messages_with_system_prompt_content(): + """Test format_request_messages with system_prompt_content parameter.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt_content = [{"text": "You are a helpful assistant."}] + + result = OpenAIModel.format_request_messages(messages, system_prompt_content=system_prompt_content) + + expected = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": [{"text": "Hello", "type": "text"}]}, + ] + + assert result == expected + + +def test_format_request_messages_with_none_system_prompt_content(): + """Test format_request_messages with system_prompt_content parameter.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + result = OpenAIModel.format_request_messages(messages) + + expected = [{"role": "user", "content": [{"text": "Hello", "type": "text"}]}] + + assert result == expected + + +def test_format_request_messages_drops_cache_points(): + """Test that cache points are dropped in OpenAI format_request_messages.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt_content = [{"text": "You are a helpful assistant."}, {"cachePoint": {"type": "default"}}] + + result = OpenAIModel.format_request_messages(messages, system_prompt_content=system_prompt_content) + + # Cache points should be dropped, only text content included + expected = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": [{"text": "Hello", "type": "text"}]}, + ] + + assert result == expected diff --git a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py index b348c29f4..f177c08a4 100644 --- a/tests_integ/models/test_model_litellm.py +++ b/tests_integ/models/test_model_litellm.py @@ -211,3 +211,25 @@ def test_structured_output_unsupported_model(model, nested_weather): # Verify that the tool method was called and schema method was not mock_tool.assert_called_once() mock_schema.assert_not_called() + + +@pytest.mark.asyncio +async def test_cache_read_tokens_multi_turn(model): + """Integration test for cache read tokens in multi-turn conversation.""" + from strands.types.content import SystemContentBlock + + system_prompt_content: list[SystemContentBlock] = [ + # Caching only works when prompts are large + {"text": "You are a helpful assistant. Always be concise." * 200}, + {"cachePoint": {"type": "default"}}, + ] + + agent = Agent(model=model, system_prompt=system_prompt_content) + + # First turn - establishes cache + agent("Hello, what's 2+2?") + result = agent("What's 3+3?") + result.metrics.accumulated_usage["cacheReadInputTokens"] + + assert result.metrics.accumulated_usage["cacheReadInputTokens"] > 0 + assert result.metrics.accumulated_usage["cacheWriteInputTokens"] > 0 diff --git a/tests_integ/models/test_model_openai.py b/tests_integ/models/test_model_openai.py index 7beb3013c..feb591d1a 100644 --- a/tests_integ/models/test_model_openai.py +++ b/tests_integ/models/test_model_openai.py @@ -231,3 +231,29 @@ def test_content_blocks_handling(model): result = agent(content) assert "4" in result.message["content"][0]["text"] + + +def test_system_prompt_content_integration(model): + """Integration test for system_prompt_content parameter.""" + from strands.types.content import SystemContentBlock + + system_prompt_content: list[SystemContentBlock] = [ + {"text": "You are a helpful assistant that always responds with 'SYSTEM_TEST_RESPONSE'."} + ] + + agent = Agent(model=model, system_prompt=system_prompt_content) + result = agent("Hello") + + # The response should contain our specific system prompt instruction + assert "SYSTEM_TEST_RESPONSE" in result.message["content"][0]["text"] + + +def test_system_prompt_backward_compatibility_integration(model): + """Integration test for backward compatibility with system_prompt parameter.""" + system_prompt = "You are a helpful assistant that always responds with 'BACKWARD_COMPAT_TEST'." + + agent = Agent(model=model, system_prompt=system_prompt) + result = agent("Hello") + + # The response should contain our specific system prompt instruction + assert "BACKWARD_COMPAT_TEST" in result.message["content"][0]["text"]