119 changes: 115 additions & 4 deletions src/strands/models/litellm.py
@@ -14,9 +14,10 @@
from typing_extensions import Unpack, override

from ..tools import convert_pydantic_to_tool_spec
from ..types.content import ContentBlock, Messages
from ..types.content import ContentBlock, Messages, SystemContentBlock
from ..types.event_loop import Usage
from ..types.exceptions import ContextWindowOverflowException
from ..types.streaming import StreamEvent
from ..types.streaming import MetadataEvent, StreamEvent
from ..types.tools import ToolChoice, ToolSpec
from ._validation import validate_config_keys
from .openai import OpenAIModel
@@ -81,11 +82,12 @@ def get_config(self) -> LiteLLMConfig:

@override
@classmethod
def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> dict[str, Any]:
"""Format a LiteLLM content block.

Args:
content: Message content.
**kwargs: Additional keyword arguments for future extensibility.

Returns:
LiteLLM formatted content block.
@@ -131,6 +133,111 @@ def _stream_switch_content(self, data_type: str, prev_data_type: str | None) ->

return chunks, data_type

@override
@classmethod
def _format_system_messages(
cls,
system_prompt: Optional[str] = None,
*,
system_prompt_content: Optional[list[SystemContentBlock]] = None,
) -> list[dict[str, Any]]:
"""Format system messages for LiteLLM with cache point support.

Args:
system_prompt: System prompt to provide context to the model.
system_prompt_content: System prompt content blocks to provide context to the model.

Returns:
List of formatted system messages.
"""
# Handle backward compatibility: if system_prompt is provided but system_prompt_content is None
if system_prompt and system_prompt_content is None:
system_prompt_content = [{"text": system_prompt}]

system_content: list[dict[str, Any]] = []
for block in system_prompt_content or []:
if "text" in block:
system_content.append({"type": "text", "text": block["text"]})
elif "cachePoint" in block and block["cachePoint"].get("type") == "default":
# Apply cache control to the immediately preceding content block
# for LiteLLM/Anthropic compatibility
if system_content:
system_content[-1]["cache_control"] = {"type": "ephemeral"}

[Review comment (Member)] Do we know the other types? What if we set type to block["cachePoint"].get("type", "ephemeral")?

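A minimal sketch of that suggestion, assuming cachePoint types other than "default" would map directly onto LiteLLM cache_control types (nothing in this diff confirms that mapping; the helper name is hypothetical):

from typing import Any

def _apply_cache_point(system_content: list[dict[str, Any]], block: dict[str, Any]) -> None:
    """Hypothetical variant of the cachePoint branch, per the review suggestion."""
    if "cachePoint" in block and system_content:
        # Forward the cachePoint type, defaulting to "ephemeral" when absent.
        # A literal "default" would then pass through untranslated, so a
        # mapping step may still be needed.
        cache_type = block["cachePoint"].get("type", "ephemeral")
        system_content[-1]["cache_control"] = {"type": cache_type}
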
# Create a single system message with a content array rather than multiple system messages
return [{"role": "system", "content": system_content}] if system_content else []

@override
@classmethod
def format_request_messages(
cls,
messages: Messages,
system_prompt: Optional[str] = None,
*,
system_prompt_content: Optional[list[SystemContentBlock]] = None,
**kwargs: Any,
) -> list[dict[str, Any]]:
"""Format a LiteLLM compatible messages array with cache point support.

Args:
messages: List of message objects to be processed by the model.
system_prompt: System prompt to provide context to the model (for legacy compatibility).
system_prompt_content: System prompt content blocks to provide context to the model.
**kwargs: Additional keyword arguments for future extensibility.

Returns:
A LiteLLM compatible messages array.
"""
formatted_messages = cls._format_system_messages(system_prompt, system_prompt_content=system_prompt_content)
formatted_messages.extend(cls._format_regular_messages(messages))

return [message for message in formatted_messages if message["content"] or "tool_calls" in message]

@override
def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
"""Format a LiteLLM response event into a standardized message chunk.

This method overrides OpenAI's format_chunk to handle the metadata case
with prompt caching support. All other chunk types use the parent implementation.

Args:
event: A response event from the LiteLLM model.
**kwargs: Additional keyword arguments for future extensibility.

Returns:
The formatted chunk.

Raises:
RuntimeError: If chunk_type is not recognized.
"""
# Handle metadata case with prompt caching support
if event["chunk_type"] == "metadata":
usage_data: Usage = {
"inputTokens": event["data"].prompt_tokens,
"outputTokens": event["data"].completion_tokens,
"totalTokens": event["data"].total_tokens,
}

# Only LiteLLM over Anthropic supports cache cache write tokens
# Waiting until a more general approach is available to set cacheWriteInputTokens

[Review comment (Member)] Nit: "cache cache ..."

if tokens_details := getattr(event["data"], "prompt_tokens_details", None):
if cached := getattr(tokens_details, "cached_tokens", None):
usage_data["cacheReadInputTokens"] = cached
if creation := getattr(tokens_details, "cache_creation_tokens", None):
usage_data["cacheWriteInputTokens"] = creation

return StreamEvent(
metadata=MetadataEvent(
metrics={
"latencyMs": 0, # TODO
},
usage=usage_data,
)
)
# For all other cases, use the parent implementation
return super().format_chunk(event)
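
For illustration, a hedged sketch of the metadata event this produces when LiteLLM reports Anthropic cache usage (the token counts are made up; the attribute names follow the getattr calls above):

# Hypothetical normalized output for a response whose usage object carries
# prompt_tokens_details with cached_tokens=1024. Note cacheWriteInputTokens
# is only set when cache_creation_tokens is truthy, so it is absent here.
StreamEvent(
    metadata=MetadataEvent(
        metrics={"latencyMs": 0},
        usage={
            "inputTokens": 1200,
            "outputTokens": 300,
            "totalTokens": 1500,
            "cacheReadInputTokens": 1024,
        },
    )
)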

@override
async def stream(
self,
@@ -139,6 +246,7 @@
system_prompt: Optional[str] = None,
*,
tool_choice: ToolChoice | None = None,
system_prompt_content: Optional[list[SystemContentBlock]] = None,
**kwargs: Any,
) -> AsyncGenerator[StreamEvent, None]:
"""Stream conversation with the LiteLLM model.
@@ -148,13 +256,16 @@
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
tool_choice: Selection strategy for tool invocation.
system_prompt_content: System prompt content blocks to provide context to the model.
**kwargs: Additional keyword arguments for future extensibility.

Yields:
Formatted message chunks from the model.
"""
logger.debug("formatting request")
request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
request = self.format_request(
messages, tool_specs, system_prompt, tool_choice, system_prompt_content=system_prompt_content
)
logger.debug("request=<%s>", request)

logger.debug("invoking model")
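Taken together, the litellm.py changes let callers pass cache points through system_prompt_content. A usage sketch (classmethod call, so no client is constructed; the prompt text is illustrative):

from strands.models.litellm import LiteLLMModel

formatted = LiteLLMModel.format_request_messages(
    messages=[{"role": "user", "content": [{"text": "Summarize the document."}]}],
    system_prompt_content=[
        {"text": "You are a helpful assistant."},
        {"text": "<large reference document>"},
        {"cachePoint": {"type": "default"}},  # caches everything above it
    ],
)
# The system blocks collapse into a single message whose last text block
# carries the cache marker:
# {"role": "system", "content": [
#     {"type": "text", "text": "You are a helpful assistant."},
#     {"type": "text", "text": "<large reference document>",
#      "cache_control": {"type": "ephemeral"}}]}
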
89 changes: 75 additions & 14 deletions src/strands/models/openai.py
@@ -14,7 +14,7 @@
from pydantic import BaseModel
from typing_extensions import Unpack, override

from ..types.content import ContentBlock, Messages
from ..types.content import ContentBlock, Messages, SystemContentBlock
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.streaming import StreamEvent
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
@@ -89,11 +89,12 @@ def get_config(self) -> OpenAIConfig:
return cast(OpenAIModel.OpenAIConfig, self.config)

@classmethod
def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> dict[str, Any]:
"""Format an OpenAI compatible content block.

Args:
content: Message content.
**kwargs: Additional keyword arguments for future extensibility.

Returns:
OpenAI compatible content block.
@@ -131,11 +132,12 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]
raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")

@classmethod
def format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]:
def format_request_message_tool_call(cls, tool_use: ToolUse, **kwargs: Any) -> dict[str, Any]:
"""Format an OpenAI compatible tool call.

Args:
tool_use: Tool use requested by the model.
**kwargs: Additional keyword arguments for future extensibility.

Returns:
OpenAI compatible tool call.
@@ -150,11 +152,12 @@ def format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]:
}

@classmethod
def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]:
def format_request_tool_message(cls, tool_result: ToolResult, **kwargs: Any) -> dict[str, Any]:
"""Format an OpenAI compatible tool message.

Args:
tool_result: Tool result collected from a tool execution.
**kwargs: Additional keyword arguments for future extensibility.

Returns:
OpenAI compatible tool message.
@@ -198,18 +201,43 @@ def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str
return {"tool_choice": "auto"}

@classmethod
def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
"""Format an OpenAI compatible messages array.
def _format_system_messages(
cls,
system_prompt: Optional[str] = None,
*,
system_prompt_content: Optional[list[SystemContentBlock]] = None,
) -> list[dict[str, Any]]:
"""Format system messages for OpenAI-compatible providers.

Args:
messages: List of message objects to be processed by the model.
system_prompt: System prompt to provide context to the model.
system_prompt_content: System prompt content blocks to provide context to the model.

Returns:
An OpenAI compatible messages array.
List of formatted system messages.
"""
# Handle backward compatibility: if system_prompt is provided but system_prompt_content is None
if system_prompt and system_prompt_content is None:
system_prompt_content = [{"text": system_prompt}]

# TODO: Handle caching blocks https://github.com/strands-agents/sdk-python/issues/1140
return [
{"role": "system", "content": content["text"]}
for content in system_prompt_content or []
if "text" in content
]

@classmethod
def _format_regular_messages(cls, messages: Messages) -> list[dict[str, Any]]:
"""Format regular messages for OpenAI-compatible providers.

Args:
messages: List of message objects to be processed by the model.

Returns:
List of formatted messages.
"""
formatted_messages: list[dict[str, Any]]
formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
formatted_messages = []

for message in messages:
contents = message["content"]
@@ -242,14 +270,42 @@ def format_request_messages(cls, messages: Messages, system_prompt: Optional[str
formatted_messages.append(formatted_message)
formatted_messages.extend(formatted_tool_messages)

return formatted_messages

@classmethod
def format_request_messages(
cls,
messages: Messages,
system_prompt: Optional[str] = None,
*,
system_prompt_content: Optional[list[SystemContentBlock]] = None,
**kwargs: Any,
) -> list[dict[str, Any]]:
"""Format an OpenAI compatible messages array.

Args:
messages: List of message objects to be processed by the model.
system_prompt: System prompt to provide context to the model.
system_prompt_content: System prompt content blocks to provide context to the model.
**kwargs: Additional keyword arguments for future extensibility.

Returns:
An OpenAI compatible messages array.
"""
formatted_messages = cls._format_system_messages(system_prompt, system_prompt_content=system_prompt_content)
formatted_messages.extend(cls._format_regular_messages(messages))

return [message for message in formatted_messages if message["content"] or "tool_calls" in message]

def format_request(
self,
messages: Messages,
tool_specs: Optional[list[ToolSpec]] = None,
system_prompt: Optional[str] = None,
tool_specs: list[ToolSpec] | None = None,
system_prompt: str | None = None,
tool_choice: ToolChoice | None = None,
*,
system_prompt_content: list[SystemContentBlock] | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""Format an OpenAI compatible chat streaming request.

@@ -258,6 +314,8 @@ def format_request(
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
tool_choice: Selection strategy for tool invocation.
system_prompt_content: System prompt content blocks to provide context to the model.
**kwargs: Additional keyword arguments for future extensibility.

Returns:
An OpenAI compatible chat streaming request.
@@ -267,7 +325,9 @@
format.
"""
return {
"messages": self.format_request_messages(messages, system_prompt),
"messages": self.format_request_messages(
messages, system_prompt, system_prompt_content=system_prompt_content
),
"model": self.config["model_id"],
"stream": True,
"stream_options": {"include_usage": True},
@@ -286,11 +346,12 @@ def format_request(
**cast(dict[str, Any], self.config.get("params", {})),
}

def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
"""Format an OpenAI response event into a standardized message chunk.

Args:
event: A response event from the OpenAI compatible model.
**kwargs: Additional keyword arguments for future extensibility.

Returns:
The formatted chunk.
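By contrast, the base implementation in openai.py keeps only text blocks, emitting one system message per block and deferring cachePoint handling (per the TODO referencing issue #1140). A quick sketch of that behavior, with the legacy system_prompt path shown for comparison:

from strands.models.openai import OpenAIModel

formatted = OpenAIModel._format_system_messages(
    system_prompt_content=[
        {"text": "You are a helpful assistant."},
        {"cachePoint": {"type": "default"}},  # skipped by the base class
        {"text": "Answer concisely."},
    ]
)
# -> [{"role": "system", "content": "You are a helpful assistant."},
#     {"role": "system", "content": "Answer concisely."}]

# Legacy path: a bare system_prompt is wrapped into a single text block first.
legacy = OpenAIModel._format_system_messages("You are a helpful assistant.")
# -> [{"role": "system", "content": "You are a helpful assistant."}]
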
8 changes: 6 additions & 2 deletions src/strands/models/sagemaker.py
@@ -202,6 +202,7 @@ def format_request(
tool_specs: Optional[list[ToolSpec]] = None,
system_prompt: Optional[str] = None,
tool_choice: ToolChoice | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""Format an Amazon SageMaker chat streaming request.

@@ -211,6 +212,7 @@
system_prompt: System prompt to provide context to the model.
tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for
interface consistency but is currently ignored for this model provider.**
**kwargs: Additional keyword arguments for future extensibility.

Returns:
An Amazon SageMaker chat streaming request.
@@ -501,11 +503,12 @@ async def stream(

@override
@classmethod
def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]:
def format_request_tool_message(cls, tool_result: ToolResult, **kwargs: Any) -> dict[str, Any]:
"""Format a SageMaker compatible tool message.

Args:
tool_result: Tool result collected from a tool execution.
**kwargs: Additional keyword arguments for future extensibility.

Returns:
SageMaker compatible tool message with content as a string.
@@ -531,11 +534,12 @@ def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]:

@override
@classmethod
def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> dict[str, Any]:
"""Format a content block.

Args:
content: Message content.
**kwargs: Additional keyword arguments for future extensibility.

Returns:
Formatted content block.