Skip to content

Commit 86eb8ad

Browse files
committed
Make known embedding dimensions maps immutable
Signed-off-by: Nicolas Krier <7557886+nicolaskrier@users.noreply.github.com>
1 parent 4532f64 commit 86eb8ad

File tree

4 files changed

+23
-9
lines changed

4 files changed

+23
-9
lines changed

models/spring-ai-google-genai-embedding/src/main/java/org/springframework/ai/google/genai/text/GoogleGenAiTextEmbeddingModel.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ public class GoogleGenAiTextEmbeddingModel extends AbstractEmbeddingModel {
6666

6767
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = Stream
6868
.of(GoogleGenAiTextEmbeddingModelName.values())
69-
.collect(Collectors.toMap(GoogleGenAiTextEmbeddingModelName::getName,
70-
GoogleGenAiTextEmbeddingModelName::getDimensions));
69+
.collect(Collectors.collectingAndThen(Collectors.toMap(GoogleGenAiTextEmbeddingModelName::getName,
70+
GoogleGenAiTextEmbeddingModelName::getDimensions), Map::copyOf));
7171

7272
public final GoogleGenAiTextEmbeddingOptions defaultOptions;
7373

@@ -259,7 +259,13 @@ private DefaultUsage getDefaultUsage(Integer totalTokens) {
259259

260260
@Override
261261
public int dimensions() {
262-
return KNOWN_EMBEDDING_DIMENSIONS.computeIfAbsent(this.defaultOptions.getModel(), model -> super.dimensions());
262+
var dimensions = KNOWN_EMBEDDING_DIMENSIONS.get(this.defaultOptions.getModel());
263+
264+
if (dimensions == null) {
265+
dimensions = super.dimensions();
266+
}
267+
268+
return dimensions;
263269
}
264270

265271
/**

models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ public class VertexAiMultimodalEmbeddingModel implements DocumentEmbeddingModel
7777

7878
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = Stream
7979
.of(VertexAiMultimodalEmbeddingModelName.values())
80-
.collect(Collectors.toMap(VertexAiMultimodalEmbeddingModelName::getName,
81-
VertexAiMultimodalEmbeddingModelName::getDimensions));
80+
.collect(Collectors.collectingAndThen(Collectors.toMap(VertexAiMultimodalEmbeddingModelName::getName,
81+
VertexAiMultimodalEmbeddingModelName::getDimensions), Map::copyOf));
8282

8383
public final VertexAiMultimodalEmbeddingOptions defaultOptions;
8484

models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel {
6969

7070
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = Stream
7171
.of(VertexAiTextEmbeddingModelName.values())
72-
.collect(Collectors.toMap(VertexAiTextEmbeddingModelName::getName,
73-
VertexAiTextEmbeddingModelName::getDimensions));
72+
.collect(Collectors.collectingAndThen(Collectors.toMap(VertexAiTextEmbeddingModelName::getName,
73+
VertexAiTextEmbeddingModelName::getDimensions), Map::copyOf));
7474

7575
public final VertexAiTextEmbeddingOptions defaultOptions;
7676

@@ -244,7 +244,13 @@ private DefaultUsage getDefaultUsage(Integer totalTokens) {
244244

245245
@Override
246246
public int dimensions() {
247-
return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), super.dimensions());
247+
var dimensions = KNOWN_EMBEDDING_DIMENSIONS.get(this.defaultOptions.getModel());
248+
249+
if (dimensions == null) {
250+
dimensions = super.dimensions();
251+
}
252+
253+
return dimensions;
248254
}
249255

250256
/**

spring-ai-model/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingModel.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ private static Map<String, Integer> loadKnownModelDimensions() {
8282
}
8383
return properties.entrySet()
8484
.stream()
85-
.collect(Collectors.toMap(e -> e.getKey().toString(), e -> Integer.parseInt(e.getValue().toString())));
85+
.collect(Collectors.collectingAndThen(
86+
Collectors.toMap(e -> e.getKey().toString(), e -> Integer.parseInt(e.getValue().toString())),
87+
Map::copyOf));
8688
}
8789
catch (IOException e) {
8890
throw new RuntimeException(e);

0 commit comments

Comments
 (0)