From 62c4738ec80071f73716df8d2a20330e4a7c79da Mon Sep 17 00:00:00 2001 From: Nicolas Krier <7557886+nicolaskrier@users.noreply.github.com> Date: Wed, 5 Nov 2025 19:08:02 +0100 Subject: [PATCH] Optimize MistralAiEmbeddingModel dimensions method - Calculate and cache values for unknown models only if necessary - Make known embedding dimensions a mutable map attribute - Fix warnings in MistralAiEmbeddingModelTests unit tests Signed-off-by: Nicolas Krier <7557886+nicolaskrier@users.noreply.github.com> --- .../ai/mistralai/MistralAiEmbeddingModel.java | 20 +++++++++++++------ .../MistralAiEmbeddingModelTests.java | 17 ++++++++-------- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java index 8650fca10b7..ac327b1cdc4 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java @@ -16,6 +16,7 @@ package org.springframework.ai.mistralai; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -56,16 +57,14 @@ public class MistralAiEmbeddingModel extends AbstractEmbeddingModel { private static final Logger logger = LoggerFactory.getLogger(MistralAiEmbeddingModel.class); + private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); + /** * Known embedding dimensions for Mistral AI models. Maps model names to their * respective embedding vector dimensions. This allows the dimensions() method to * return the correct value without making an API call. */ - private static final Map KNOWN_EMBEDDING_DIMENSIONS = Map.of( - MistralAiApi.EmbeddingModel.EMBED.getValue(), 1024, MistralAiApi.EmbeddingModel.CODESTRAL_EMBED.getValue(), - 1536); - - private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); + private final Map knownEmbeddingDimensions = createKnownEmbeddingDimensions(); private final MistralAiEmbeddingOptions defaultOptions; @@ -85,6 +84,14 @@ public class MistralAiEmbeddingModel extends AbstractEmbeddingModel { */ private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + private static Map createKnownEmbeddingDimensions() { + Map knownEmbeddingDimensions = new HashMap<>(); + knownEmbeddingDimensions.put(MistralAiApi.EmbeddingModel.EMBED.getValue(), 1024); + knownEmbeddingDimensions.put(MistralAiApi.EmbeddingModel.CODESTRAL_EMBED.getValue(), 1536); + + return knownEmbeddingDimensions; + } + @Deprecated public MistralAiEmbeddingModel(MistralAiApi mistralAiApi) { this(mistralAiApi, MetadataMode.EMBED); @@ -197,7 +204,8 @@ public float[] embed(Document document) { @Override public int dimensions() { - return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), super.dimensions()); + return this.knownEmbeddingDimensions.computeIfAbsent(this.defaultOptions.getModel(), + model -> super.dimensions()); } /** diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelTests.java index 1ac71e276e0..19fbf377643 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelTests.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelTests.java @@ -16,8 +16,10 @@ package org.springframework.ai.mistralai; +import java.util.Arrays; import java.util.List; +import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.Test; import org.mockito.Mockito; @@ -46,7 +48,7 @@ void testDimensionsForMistralEmbedModel() { .build(); MistralAiEmbeddingModel model = new MistralAiEmbeddingModel(mockApi, MetadataMode.EMBED, options, - RetryUtils.DEFAULT_RETRY_TEMPLATE); + RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP); assertThat(model.dimensions()).isEqualTo(1024); } @@ -60,7 +62,7 @@ void testDimensionsForCodestralEmbedModel() { .build(); MistralAiEmbeddingModel model = new MistralAiEmbeddingModel(mockApi, MetadataMode.EMBED, options, - RetryUtils.DEFAULT_RETRY_TEMPLATE); + RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP); assertThat(model.dimensions()).isEqualTo(1536); } @@ -73,7 +75,7 @@ void testDimensionsFallbackForUnknownModel() { MistralAiEmbeddingOptions options = MistralAiEmbeddingOptions.builder().withModel("unknown-model").build(); MistralAiEmbeddingModel model = new MistralAiEmbeddingModel(mockApi, MetadataMode.EMBED, options, - RetryUtils.DEFAULT_RETRY_TEMPLATE); + RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP); // Should fall back to super.dimensions() which detects dimensions from the API // response @@ -94,7 +96,7 @@ void testAllEmbeddingModelsHaveDimensionMapping() { .build(); MistralAiEmbeddingModel model = new MistralAiEmbeddingModel(mockApi, MetadataMode.EMBED, options, - RetryUtils.DEFAULT_RETRY_TEMPLATE); + RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP); // Each model should have a valid dimension (not the fallback -1) assertThat(model.dimensions()).as("Model %s should have a dimension mapping", embeddingModel.getValue()) @@ -122,16 +124,13 @@ private MistralAiApi createMockApiWithEmbeddingResponse(int dimensions) { // Create a mock embedding response with the specified dimensions float[] embedding = new float[dimensions]; - for (int i = 0; i < dimensions; i++) { - embedding[i] = 0.1f; - } + Arrays.fill(embedding, 0.1f); MistralAiApi.Embedding embeddingData = new MistralAiApi.Embedding(0, embedding, "embedding"); MistralAiApi.Usage usage = new MistralAiApi.Usage(10, 0, 10); - MistralAiApi.EmbeddingList embeddingList = new MistralAiApi.EmbeddingList("object", List.of(embeddingData), - "model", usage); + var embeddingList = new MistralAiApi.EmbeddingList<>("object", List.of(embeddingData), "model", usage); when(mockApi.embeddings(any())).thenReturn(ResponseEntity.ok(embeddingList));