diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java
index c5a99f99eb3..2cd86a267b8 100644
--- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java
+++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java
@@ -1077,13 +1077,31 @@ public record TopLogProbs(// @formatter:off
* @param promptTokens Number of tokens in the prompt.
* @param totalTokens Total number of tokens used in the request (prompt +
* completion).
+ * @param promptTokensDetails Details about the prompt tokens used. Supported by
+ * GLM-4.5 and later models.
*/
@JsonInclude(Include.NON_NULL)
@JsonIgnoreProperties(ignoreUnknown = true)
public record Usage(// @formatter:off
@JsonProperty("completion_tokens") Integer completionTokens,
@JsonProperty("prompt_tokens") Integer promptTokens,
- @JsonProperty("total_tokens") Integer totalTokens) { // @formatter:on
+ @JsonProperty("total_tokens") Integer totalTokens,
+ @JsonProperty("prompt_tokens_details") PromptTokensDetails promptTokensDetails) { // @formatter:on
+
+ public Usage(Integer completionTokens, Integer promptTokens, Integer totalTokens) {
+ this(completionTokens, promptTokens, totalTokens, null);
+ }
+
+ /**
+ * Details about the prompt tokens used.
+ *
+ * @param cachedTokens Number of tokens in the prompt that were cached.
+ */
+ @JsonInclude(Include.NON_NULL)
+ @JsonIgnoreProperties(ignoreUnknown = true)
+ public record PromptTokensDetails(// @formatter:off
+ @JsonProperty("cached_tokens") Integer cachedTokens) { // @formatter:on
+ }
}
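
For reference, a minimal sketch of how the extended Usage record round-trips through Jackson (the JSON payload and the token counts below are illustrative, not taken from a real API response):

import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;

class UsageDeserializationSketch {

	public static void main(String[] args) throws Exception {
		// Illustrative payload; GLM-4.5 and later models may include prompt_tokens_details.
		String json = """
				{"completion_tokens": 10, "prompt_tokens": 100, "total_tokens": 110,
				 "prompt_tokens_details": {"cached_tokens": 64}}""";

		ZhiPuAiApi.Usage usage = new ObjectMapper().readValue(json, ZhiPuAiApi.Usage.class);
		System.out.println(usage.promptTokensDetails().cachedTokens()); // prints 64

		// Payloads without the new field still deserialize, and the overloaded
		// constructor keeps existing three-argument callers compiling; the detail is null.
		ZhiPuAiApi.Usage legacy = new ZhiPuAiApi.Usage(10, 100, 110);
		System.out.println(legacy.promptTokensDetails()); // prints null
	}

}
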
diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java
index e06abdf5d3c..2c1bc90f705 100644
--- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java
+++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java
@@ -40,7 +40,6 @@
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
-import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
@@ -85,13 +84,22 @@ class ZhiPuAiChatModelIT {
@Value("classpath:/prompts/system-message.st")
private Resource systemResource;
+ /**
+ * Default chat options to use for the tests.
+ *
+ * glm-4-flash is a free model, so it is used by default in these tests.
+ */
+ private static final ZhiPuAiChatOptions DEFAULT_CHAT_OPTIONS = ZhiPuAiChatOptions.builder()
+ .model(ZhiPuAiApi.ChatModel.GLM_4_Flash.getValue())
+ .build();
+
@Test
void roleTest() {
UserMessage userMessage = new UserMessage(
"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
- Prompt prompt = new Prompt(List.of(userMessage, systemMessage), ChatOptions.builder().build());
+ Prompt prompt = new Prompt(List.of(userMessage, systemMessage), DEFAULT_CHAT_OPTIONS);
ChatResponse response = this.chatModel.call(prompt);
assertThat(response.getResults()).hasSize(1);
assertThat(response.getResults().get(0).getOutput().getText()).contains("Blackbeard");
@@ -104,7 +112,7 @@ void streamRoleTest() {
"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
- Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
+ Prompt prompt = new Prompt(List.of(userMessage, systemMessage), DEFAULT_CHAT_OPTIONS);
Flux<ChatResponse> flux = this.streamingChatModel.stream(prompt);
List<ChatResponse> responses = flux.collectList().block();
@@ -135,7 +143,7 @@ void listOutputConverter() {
.template(template)
.variables(Map.of("subject", "ice cream flavors", "format", format))
.build();
- Prompt prompt = new Prompt(promptTemplate.createMessage(), ChatOptions.builder().build());
+ Prompt prompt = new Prompt(promptTemplate.createMessage(), DEFAULT_CHAT_OPTIONS);
Generation generation = this.chatModel.call(prompt).getResult();
List<String> list = outputConverter.convert(generation.getOutput().getText());
@@ -157,8 +165,9 @@ void mapOutputConverter() {
.variables(Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format",
format))
.build();
- Prompt prompt = new Prompt(promptTemplate.createMessage(), ChatOptions.builder().build());
- Generation generation = this.chatModel.call(prompt).getResult();
+ Prompt prompt = new Prompt(promptTemplate.createMessage(), DEFAULT_CHAT_OPTIONS);
+ ChatResponse chatResponse = this.chatModel.call(prompt);
+ Generation generation = chatResponse.getResult();
Map<String, Object> result = outputConverter.convert(generation.getOutput().getText());
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
@@ -179,7 +188,7 @@ void beanOutputConverter() {
.template(template)
.variables(Map.of("format", format))
.build();
- Prompt prompt = new Prompt(promptTemplate.createMessage(), ChatOptions.builder().build());
+ Prompt prompt = new Prompt(promptTemplate.createMessage(), DEFAULT_CHAT_OPTIONS);
Generation generation = this.chatModel.call(prompt).getResult();
ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getText());
@@ -198,7 +207,7 @@ void beanOutputConverterRecords() {
.template(template)
.variables(Map.of("format", format))
.build();
- Prompt prompt = new Prompt(promptTemplate.createMessage(), ChatOptions.builder().build());
+ Prompt prompt = new Prompt(promptTemplate.createMessage(), DEFAULT_CHAT_OPTIONS);
Generation generation = this.chatModel.call(prompt).getResult();
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getText());
@@ -221,7 +230,7 @@ void beanStreamOutputConverterRecords() {
.template(template)
.variables(Map.of("format", format))
.build();
- Prompt prompt = new Prompt(promptTemplate.createMessage());
+ Prompt prompt = new Prompt(promptTemplate.createMessage(), DEFAULT_CHAT_OPTIONS);
String generationTextFromStream = Objects
.requireNonNull(this.streamingChatModel.stream(prompt).collectList().block())
@@ -253,7 +262,10 @@ void jsonObjectResponseFormatOutputConverterRecords() {
.variables(Map.of("format", format))
.build();
Prompt prompt = new Prompt(promptTemplate.createMessage(),
- ZhiPuAiChatOptions.builder().responseFormat(ChatCompletionRequest.ResponseFormat.jsonObject()).build());
+ ZhiPuAiChatOptions.builder()
+ .model(ZhiPuAiApi.ChatModel.GLM_4_Flash.getValue())
+ .responseFormat(ChatCompletionRequest.ResponseFormat.jsonObject())
+ .build());
String generationTextFromStream = Objects
.requireNonNull(this.streamingChatModel.stream(prompt).collectList().block())
@@ -281,7 +293,7 @@ void functionCallTest() {
List<Message> messages = new ArrayList<>(List.of(userMessage));
var promptOptions = ZhiPuAiChatOptions.builder()
- .model(ZhiPuAiApi.ChatModel.GLM_4.getValue())
+ .model(ZhiPuAiApi.ChatModel.GLM_4_Flash.getValue())
.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
.description("Get the weather in location")
.inputType(MockWeatherService.Request.class)
@@ -306,7 +318,7 @@ void streamFunctionCallTest() {
List<Message> messages = new ArrayList<>(List.of(userMessage));
var promptOptions = ZhiPuAiChatOptions.builder()
- .model(ZhiPuAiApi.ChatModel.GLM_4.getValue())
+ .model(ZhiPuAiApi.ChatModel.GLM_4_Flash.getValue())
.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
.description("Get the weather in location")
.inputType(MockWeatherService.Request.class)
@@ -332,8 +344,7 @@ void streamFunctionCallTest() {
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "glm-4.5-flash" })
void enabledThinkingTest(String modelName) {
- UserMessage userMessage = new UserMessage(
- "Are there an infinite number of prime numbers such that n mod 4 == 3?");
+ UserMessage userMessage = new UserMessage("9.11 and 9.8, which is greater?");
var promptOptions = ZhiPuAiChatOptions.builder()
.model(modelName)
@@ -344,14 +355,16 @@ void enabledThinkingTest(String modelName) {
ChatResponse response = this.chatModel.call(new Prompt(List.of(userMessage), promptOptions));
logger.info("Response: {}", response);
- for (Generation generation : response.getResults()) {
- AssistantMessage message = generation.getOutput();
- assertThat(message).isInstanceOf(ZhiPuAiAssistantMessage.class);
- assertThat(message.getText()).isNotBlank();
- assertThat(((ZhiPuAiAssistantMessage) message).getReasoningContent()).isNotBlank();
- }
+ Generation generation = response.getResult();
+ AssistantMessage message = generation.getOutput();
+ assertThat(message).isInstanceOf(ZhiPuAiAssistantMessage.class);
+ assertThat(message.getText()).isNotBlank();
+ assertThat(((ZhiPuAiAssistantMessage) message).getReasoningContent()).isNotBlank();
+
+ ZhiPuAiApi.Usage nativeUsage = (ZhiPuAiApi.Usage) response.getMetadata().getUsage().getNativeUsage();
+ assertThat(nativeUsage.promptTokensDetails()).isNotNull();
}
@ParameterizedTest(name = "{0} : {displayName} ")
@@ -382,8 +395,7 @@ void disabledThinkingTest(String modelName) {
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "glm-4.5-flash" })
void streamAndEnableThinkingTest(String modelName) {
- UserMessage userMessage = new UserMessage(
- "Are there an infinite number of prime numbers such that n mod 4 == 3?");
+ UserMessage userMessage = new UserMessage("9.11 and 9.8, which is greater?");
var promptOptions = ZhiPuAiChatOptions.builder()
.model(modelName)
@@ -408,6 +420,7 @@ void streamAndEnableThinkingTest(String modelName) {
}
return message.getText();
})
+ .filter(StringUtils::hasText)
.collect(Collectors.joining());
logger.info("reasoningContent: {}", reasoningContent);
@@ -420,7 +433,7 @@ void streamAndEnableThinkingTest(String modelName) {
}
@ParameterizedTest(name = "{0} : {displayName} ")
- @ValueSource(strings = { "glm-4v" })
+ @ValueSource(strings = { "glm-4v-flash" })
void multiModalityEmbeddedImage(String modelName) throws IOException {
var imageData = new ClassPathResource("/test.png");
@@ -461,7 +474,7 @@ void reasonerMultiModalityEmbeddedImageThinkingModel(String modelName) throws IO
}
@ParameterizedTest(name = "{0} : {displayName} ")
- @ValueSource(strings = { "glm-4v", "glm-4.1v-thinking-flash" })
+ @ValueSource(strings = { "glm-4v-flash", "glm-4.1v-thinking-flash" })
void multiModalityImageUrl(String modelName) throws IOException {
var userMessage = UserMessage.builder()
@@ -505,7 +518,7 @@ void reasonerMultiModalityImageUrl(String modelName) throws IOException {
}
@ParameterizedTest(name = "{0} : {displayName} ")
- @ValueSource(strings = { "glm-4v" })
+ @ValueSource(strings = { "glm-4v-flash" })
void streamingMultiModalityImageUrl(String modelName) throws IOException {
var userMessage = UserMessage.builder()
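
End-to-end, the new metadata can be read off a ChatResponse the same way the updated enabledThinkingTest does. A minimal sketch, assuming a configured ZhiPuAiChatModel (the helper name and prompt are illustrative; promptTokensDetails() is null for models that do not report it):

import java.util.List;

import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;

class CachedTokensSketch {

	static void printCachedTokens(ZhiPuAiChatModel chatModel) {
		var options = ZhiPuAiChatOptions.builder()
			.model(ZhiPuAiApi.ChatModel.GLM_4_Flash.getValue()) // free model, as in the tests
			.build();
		ChatResponse response = chatModel
			.call(new Prompt(List.of(new UserMessage("9.11 and 9.8, which is greater?")), options));

		// getNativeUsage() exposes the provider-specific record extended by this change.
		ZhiPuAiApi.Usage usage = (ZhiPuAiApi.Usage) response.getMetadata().getUsage().getNativeUsage();
		if (usage.promptTokensDetails() != null) {
			System.out.println("cached tokens: " + usage.promptTokensDetails().cachedTokens());
		}
	}

}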