Skip to content

Commit 7320cb0

Browse files
committed
Optimize MistralAiEmbeddingModel dimensions method
- Calculate default value only if necessary - Fix warnings in MistralAiEmbeddingModelTests unit tests Signed-off-by: Nicolas Krier <7557886+nicolaskrier@users.noreply.github.com>
1 parent 4532f64 commit 7320cb0

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,13 @@ public float[] embed(Document document) {
197197

198198
@Override
199199
public int dimensions() {
200-
return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), super.dimensions());
200+
var dimensions = KNOWN_EMBEDDING_DIMENSIONS.get(this.defaultOptions.getModel());
201+
202+
if (dimensions == null) {
203+
dimensions = super.dimensions();
204+
}
205+
206+
return dimensions;
201207
}
202208

203209
/**

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelTests.java

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616

1717
package org.springframework.ai.mistralai;
1818

19+
import java.util.Arrays;
1920
import java.util.List;
2021

22+
import io.micrometer.observation.ObservationRegistry;
2123
import org.junit.jupiter.api.Test;
2224
import org.mockito.Mockito;
2325

@@ -46,7 +48,7 @@ void testDimensionsForMistralEmbedModel() {
4648
.build();
4749

4850
MistralAiEmbeddingModel model = new MistralAiEmbeddingModel(mockApi, MetadataMode.EMBED, options,
49-
RetryUtils.DEFAULT_RETRY_TEMPLATE);
51+
RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP);
5052

5153
assertThat(model.dimensions()).isEqualTo(1024);
5254
}
@@ -60,7 +62,7 @@ void testDimensionsForCodestralEmbedModel() {
6062
.build();
6163

6264
MistralAiEmbeddingModel model = new MistralAiEmbeddingModel(mockApi, MetadataMode.EMBED, options,
63-
RetryUtils.DEFAULT_RETRY_TEMPLATE);
65+
RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP);
6466

6567
assertThat(model.dimensions()).isEqualTo(1536);
6668
}
@@ -73,7 +75,7 @@ void testDimensionsFallbackForUnknownModel() {
7375
MistralAiEmbeddingOptions options = MistralAiEmbeddingOptions.builder().withModel("unknown-model").build();
7476

7577
MistralAiEmbeddingModel model = new MistralAiEmbeddingModel(mockApi, MetadataMode.EMBED, options,
76-
RetryUtils.DEFAULT_RETRY_TEMPLATE);
78+
RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP);
7779

7880
// Should fall back to super.dimensions() which detects dimensions from the API
7981
// response
@@ -94,7 +96,7 @@ void testAllEmbeddingModelsHaveDimensionMapping() {
9496
.build();
9597

9698
MistralAiEmbeddingModel model = new MistralAiEmbeddingModel(mockApi, MetadataMode.EMBED, options,
97-
RetryUtils.DEFAULT_RETRY_TEMPLATE);
99+
RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP);
98100

99101
// Each model should have a valid dimension (not the fallback -1)
100102
assertThat(model.dimensions()).as("Model %s should have a dimension mapping", embeddingModel.getValue())
@@ -122,16 +124,13 @@ private MistralAiApi createMockApiWithEmbeddingResponse(int dimensions) {
122124

123125
// Create a mock embedding response with the specified dimensions
124126
float[] embedding = new float[dimensions];
125-
for (int i = 0; i < dimensions; i++) {
126-
embedding[i] = 0.1f;
127-
}
127+
Arrays.fill(embedding, 0.1f);
128128

129129
MistralAiApi.Embedding embeddingData = new MistralAiApi.Embedding(0, embedding, "embedding");
130130

131131
MistralAiApi.Usage usage = new MistralAiApi.Usage(10, 0, 10);
132132

133-
MistralAiApi.EmbeddingList embeddingList = new MistralAiApi.EmbeddingList("object", List.of(embeddingData),
134-
"model", usage);
133+
var embeddingList = new MistralAiApi.EmbeddingList<>("object", List.of(embeddingData), "model", usage);
135134

136135
when(mockApi.embeddings(any())).thenReturn(ResponseEntity.ok(embeddingList));
137136

0 commit comments

Comments
 (0)