|
20 | 20 | import java.util.Base64; |
21 | 21 | import java.util.List; |
22 | 22 |
|
| 23 | +import okhttp3.mockwebserver.MockResponse; |
| 24 | +import okhttp3.mockwebserver.MockWebServer; |
| 25 | +import okhttp3.mockwebserver.RecordedRequest; |
23 | 26 | import org.junit.jupiter.api.Disabled; |
24 | 27 | import org.junit.jupiter.api.Test; |
25 | 28 | import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; |
|
35 | 38 | import org.springframework.ai.openai.api.OpenAiApi.Embedding; |
36 | 39 | import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList; |
37 | 40 | import org.springframework.core.io.ClassPathResource; |
| 41 | +import org.springframework.http.HttpHeaders; |
| 42 | +import org.springframework.http.MediaType; |
38 | 43 | import org.springframework.http.ResponseEntity; |
39 | 44 |
|
40 | 45 | import static org.assertj.core.api.Assertions.assertThat; |
@@ -237,4 +242,58 @@ void chatCompletionEntityWithServiceTier(OpenAiApi.ServiceTier serviceTier) { |
237 | 242 | assertThat(response.getBody().serviceTier()).containsIgnoringCase(serviceTier.getValue()); |
238 | 243 | } |
239 | 244 |
|
| 245 | + @Test |
| 246 | + void userAgentHeaderIsSentInChatCompletionRequests() throws Exception { |
| 247 | + try (MockWebServer mockWebServer = new MockWebServer()) { |
| 248 | + mockWebServer.start(); |
| 249 | + |
| 250 | + // Mock response from OpenAI |
| 251 | + mockWebServer.enqueue(new MockResponse().setResponseCode(200) |
| 252 | + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) |
| 253 | + .setBody(""" |
| 254 | + { |
| 255 | + "id": "chatcmpl-123", |
| 256 | + "object": "chat.completion", |
| 257 | + "created": 1677652288, |
| 258 | + "model": "gpt-3.5-turbo", |
| 259 | + "choices": [{ |
| 260 | + "index": 0, |
| 261 | + "message": { |
| 262 | + "role": "assistant", |
| 263 | + "content": "Hello there!" |
| 264 | + }, |
| 265 | + "finish_reason": "stop" |
| 266 | + }], |
| 267 | + "usage": { |
| 268 | + "prompt_tokens": 9, |
| 269 | + "completion_tokens": 2, |
| 270 | + "total_tokens": 11 |
| 271 | + } |
| 272 | + } |
| 273 | + """)); |
| 274 | + |
| 275 | + // Create OpenAiApi instance pointing to mock server |
| 276 | + OpenAiApi testApi = OpenAiApi.builder() |
| 277 | + .apiKey(System.getenv("OPENAI_API_KEY")) |
| 278 | + .baseUrl(mockWebServer.url("/").toString()) |
| 279 | + .build(); |
| 280 | + |
| 281 | + // Make a request |
| 282 | + ChatCompletionMessage message = new ChatCompletionMessage("Hello world", Role.USER); |
| 283 | + ResponseEntity<ChatCompletion> response = testApi |
| 284 | + .chatCompletionEntity(new ChatCompletionRequest(List.of(message), "gpt-3.5-turbo", 0.8, false)); |
| 285 | + |
| 286 | + // Verify the response succeeded |
| 287 | + assertThat(response).isNotNull(); |
| 288 | + assertThat(response.getBody()).isNotNull(); |
| 289 | + |
| 290 | + // Verify the User-Agent header was sent in the request |
| 291 | + RecordedRequest recordedRequest = mockWebServer.takeRequest(); |
| 292 | + assertThat(recordedRequest.getHeader(OpenAiApi.HTTP_USER_AGENT_HEADER)) |
| 293 | + .isEqualTo(OpenAiApi.SPRING_AI_USER_AGENT); |
| 294 | + |
| 295 | + mockWebServer.shutdown(); |
| 296 | + } |
| 297 | + } |
| 298 | + |
240 | 299 | } |
0 commit comments