diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py
index 574b24200..1887bb9d3 100644
--- a/src/strands/models/ollama.py
+++ b/src/strands/models/ollama.py
@@ -116,6 +116,9 @@ def _format_request_message_contents(self, role: str, content: ContentBlock) ->
         if "image" in content:
             return [{"role": role, "images": [content["image"]["source"]["bytes"]]}]
 
+        if "reasoningContent" in content:
+            return []
+
         if "toolUse" in content:
             return [
                 {
@@ -237,13 +240,16 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
                 return {"messageStart": {"role": "assistant"}}
 
             case "content_start":
-                if event["data_type"] == "text":
+                if event["data_type"] == "text" or event["data_type"] == "reasoning_content":
                     return {"contentBlockStart": {"start": {}}}
 
                 tool_name = event["data"].function.name
                 return {"contentBlockStart": {"start": {"toolUse": {"name": tool_name, "toolUseId": tool_name}}}}
 
             case "content_delta":
+                if event["data_type"] == "reasoning_content":
+                    return {"contentBlockDelta": {"delta": {"reasoningContent": {"text": event["data"]}}}}
+
                 if event["data_type"] == "text":
                     return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
 
@@ -320,14 +326,29 @@ async def stream(
         yield self.format_chunk({"chunk_type": "message_start"})
         yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
 
+        is_thinking = False
         async for event in response:
+            if event.message.thinking:
+                if not is_thinking:
+                    is_thinking = True
+                    yield self.format_chunk({"chunk_type": "content_start", "data_type": "reasoning_content"})
+                yield self.format_chunk(
+                    {"chunk_type": "content_delta", "data_type": "reasoning_content", "data": event.message.thinking}
+                )
+            elif is_thinking:
+                is_thinking = False
+                yield self.format_chunk({"chunk_type": "content_stop", "data_type": "reasoning_content"})
+
             for tool_call in event.message.tool_calls or []:
                 yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_call})
                 yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_call})
                 yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool", "data": tool_call})
                 tool_requested = True
 
-            yield self.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": event.message.content})
+            if event.message.content:
+                yield self.format_chunk(
+                    {"chunk_type": "content_delta", "data_type": "text", "data": event.message.content}
+                )
 
         yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
         yield self.format_chunk(
diff --git a/tests/strands/models/test_ollama.py b/tests/strands/models/test_ollama.py
index 14db63a24..60b106997 100644
--- a/tests/strands/models/test_ollama.py
+++ b/tests/strands/models/test_ollama.py
@@ -418,6 +418,7 @@ async def test_stream(ollama_client, model, agenerator, alist, captured_warnings
     mock_event = unittest.mock.Mock()
     mock_event.message.tool_calls = None
     mock_event.message.content = "Hello"
+    mock_event.message.thinking = None
     mock_event.done_reason = "stop"
     mock_event.eval_count = 10
     mock_event.prompt_eval_count = 5
@@ -457,6 +458,63 @@ async def test_stream(ollama_client, model, agenerator, alist, captured_warnings
     assert len(captured_warnings) == 0
 
 
+@pytest.mark.asyncio
+async def test_stream_thinking(ollama_client, model, agenerator, alist, captured_warnings):
+    think_event = unittest.mock.Mock()
+    think_event.message.tool_calls = None
+    think_event.message.content = None
+    think_event.message.thinking = "t1"
+    think_event.done_reason = "stop"
+    think_event.eval_count = 10
+    think_event.prompt_eval_count = 5
+    think_event.total_duration = 1000000  # 1ms in nanoseconds
+
+    text_event = unittest.mock.Mock()
+    text_event.message.tool_calls = None
+    text_event.message.content = "Hello"
+    text_event.message.thinking = None
+    text_event.done_reason = "stop"
+    text_event.eval_count = 10
+    text_event.prompt_eval_count = 5
+    text_event.total_duration = 1000000  # 1ms in nanoseconds
+
+    ollama_client.chat = unittest.mock.AsyncMock(return_value=agenerator([think_event, text_event]))
+
+    messages = [{"role": "user", "content": [{"text": "Hello"}]}]
+    response = model.stream(messages)
+
+    tru_events = await alist(response)
+    exp_events = [
+        {"messageStart": {"role": "assistant"}},
+        {"contentBlockStart": {"start": {}}},
+        {"contentBlockStart": {"start": {}}},
+        {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "t1"}}}},
+        {"contentBlockStop": {}},
+        {"contentBlockDelta": {"delta": {"text": "Hello"}}},
+        {"contentBlockStop": {}},
+        {"messageStop": {"stopReason": "end_turn"}},
+        {
+            "metadata": {
+                "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15},
+                "metrics": {"latencyMs": 1.0},
+            }
+        },
+    ]
+
+    assert tru_events == exp_events
+    expected_request = {
+        "model": "m1",
+        "messages": [{"role": "user", "content": "Hello"}],
+        "options": {},
+        "stream": True,
+        "tools": [],
+    }
+    ollama_client.chat.assert_called_once_with(**expected_request)
+
+    # Ensure no warnings emitted
+    assert len(captured_warnings) == 0
+
+
 @pytest.mark.asyncio
 async def test_tool_choice_not_supported_warns(ollama_client, model, agenerator, alist, captured_warnings):
     """Test that non-None toolChoice emits warning for unsupported providers."""
@@ -465,6 +523,7 @@ async def test_tool_choice_not_supported_warns(ollama_client, model, agenerator,
     mock_event = unittest.mock.Mock()
     mock_event.message.tool_calls = None
     mock_event.message.content = "Hello"
+    mock_event.message.thinking = None
     mock_event.done_reason = "stop"
     mock_event.eval_count = 10
     mock_event.prompt_eval_count = 5
@@ -487,6 +546,7 @@ async def test_stream_with_tool_calls(ollama_client, model, agenerator, alist):
     mock_tool_call.function.arguments = {"expression": "2+2"}
     mock_event.message.tool_calls = [mock_tool_call]
     mock_event.message.content = "I'll calculate that for you"
+    mock_event.message.thinking = None
     mock_event.done_reason = "stop"
     mock_event.eval_count = 15
     mock_event.prompt_eval_count = 8
@@ -559,3 +619,20 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings
     assert len(captured_warnings) == 1
     assert "Invalid configuration parameters" in str(captured_warnings[0].message)
     assert "wrong_param" in str(captured_warnings[0].message)
+
+
+def test_format_chunk_content_block_delta_thinking_delta(model):
+    event = {"chunk_type": "content_delta", "data_type": "reasoning_content", "data": "t1"}
+
+    tru_chunk = model.format_chunk(event)
+    exp_chunk = {
+        "contentBlockDelta": {
+            "delta": {
+                "reasoningContent": {
+                    "text": "t1",
+                },
+            },
+        },
+    }
+
+    assert tru_chunk == exp_chunk
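
For reviewers, a minimal sketch of how the new reasoning stream surfaces to a caller. This is untested and makes assumptions beyond the diff: a local Ollama server at http://localhost:11434, a thinking-capable model ("qwen3:4b" is a placeholder tag), and the OllamaModel(host, model_id=...) constructor. Only the contentBlockDelta event shapes it matches are guaranteed by the format_chunk changes above.

    # Sketch only: collect thinking deltas separately from answer deltas.
    # Host, model tag, and constructor arguments are assumptions, not part
    # of this diff; the event dicts are the ones format_chunk emits above.
    import asyncio

    from strands.models.ollama import OllamaModel


    async def main() -> None:
        model = OllamaModel("http://localhost:11434", model_id="qwen3:4b")  # placeholder tag
        messages = [{"role": "user", "content": [{"text": "Why is the sky blue?"}]}]

        reasoning: list[str] = []
        answer: list[str] = []
        async for event in model.stream(messages):
            delta = event.get("contentBlockDelta", {}).get("delta", {})
            if "reasoningContent" in delta:
                # Produced by the new content_delta / reasoning_content branch.
                reasoning.append(delta["reasoningContent"]["text"])
            elif "text" in delta:
                answer.append(delta["text"])

        print("thinking:", "".join(reasoning))
        print("answer:", "".join(answer))


    asyncio.run(main())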