Skip to content

Commit 87917c7

Browse files
committed
Make known embedding dimensions maps immutable
Signed-off-by: Nicolas Krier <7557886+nicolaskrier@users.noreply.github.com>
1 parent 4532f64 commit 87917c7

File tree

5 files changed

+41
-23
lines changed

5 files changed

+41
-23
lines changed

models/spring-ai-google-genai-embedding/src/main/java/org/springframework/ai/google/genai/text/GoogleGenAiTextEmbeddingModel.java

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
import java.util.ArrayList;
2020
import java.util.List;
2121
import java.util.Map;
22-
import java.util.stream.Collectors;
23-
import java.util.stream.Stream;
2422

2523
import com.google.genai.Client;
2624
import com.google.genai.types.ContentEmbedding;
@@ -43,6 +41,7 @@
4341
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
4442
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
4543
import org.springframework.ai.google.genai.GoogleGenAiEmbeddingConnectionDetails;
44+
import org.springframework.ai.model.EmbeddingModelDescription;
4645
import org.springframework.ai.model.ModelOptionsUtils;
4746
import org.springframework.ai.observation.conventions.AiProvider;
4847
import org.springframework.ai.retry.RetryUtils;
@@ -64,10 +63,8 @@ public class GoogleGenAiTextEmbeddingModel extends AbstractEmbeddingModel {
6463

6564
private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
6665

67-
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = Stream
68-
.of(GoogleGenAiTextEmbeddingModelName.values())
69-
.collect(Collectors.toMap(GoogleGenAiTextEmbeddingModelName::getName,
70-
GoogleGenAiTextEmbeddingModelName::getDimensions));
66+
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = EmbeddingModelDescription
67+
.calculateKnownEmbeddingDimensions(GoogleGenAiTextEmbeddingModelName.class);
7168

7269
public final GoogleGenAiTextEmbeddingOptions defaultOptions;
7370

@@ -259,7 +256,13 @@ private DefaultUsage getDefaultUsage(Integer totalTokens) {
259256

260257
@Override
261258
public int dimensions() {
262-
return KNOWN_EMBEDDING_DIMENSIONS.computeIfAbsent(this.defaultOptions.getModel(), model -> super.dimensions());
259+
var dimensions = KNOWN_EMBEDDING_DIMENSIONS.get(this.defaultOptions.getModel());
260+
261+
if (dimensions == null) {
262+
dimensions = super.dimensions();
263+
}
264+
265+
return dimensions;
263266
}
264267

265268
/**

models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -20,8 +20,6 @@
2020
import java.util.EnumMap;
2121
import java.util.List;
2222
import java.util.Map;
23-
import java.util.stream.Collectors;
24-
import java.util.stream.Stream;
2523

2624
import com.google.cloud.aiplatform.v1.EndpointName;
2725
import com.google.cloud.aiplatform.v1.PredictRequest;
@@ -43,6 +41,7 @@
4341
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
4442
import org.springframework.ai.embedding.EmbeddingResultMetadata;
4543
import org.springframework.ai.embedding.EmbeddingResultMetadata.ModalityType;
44+
import org.springframework.ai.model.EmbeddingModelDescription;
4645
import org.springframework.ai.model.ModelOptionsUtils;
4746
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
4847
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils;
@@ -75,10 +74,8 @@ public class VertexAiMultimodalEmbeddingModel implements DocumentEmbeddingModel
7574
private static final List<MimeType> SUPPORTED_IMAGE_MIME_SUB_TYPES = List.of(MimeTypeUtils.IMAGE_JPEG,
7675
MimeTypeUtils.IMAGE_GIF, MimeTypeUtils.IMAGE_PNG, MimeTypeUtils.parseMimeType("image/bmp"));
7776

78-
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = Stream
79-
.of(VertexAiMultimodalEmbeddingModelName.values())
80-
.collect(Collectors.toMap(VertexAiMultimodalEmbeddingModelName::getName,
81-
VertexAiMultimodalEmbeddingModelName::getDimensions));
77+
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = EmbeddingModelDescription
78+
.calculateKnownEmbeddingDimensions(VertexAiMultimodalEmbeddingModelName.class);
8279

8380
public final VertexAiMultimodalEmbeddingOptions defaultOptions;
8481

models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
import java.util.ArrayList;
2121
import java.util.List;
2222
import java.util.Map;
23-
import java.util.stream.Collectors;
24-
import java.util.stream.Stream;
2523

2624
import com.google.cloud.aiplatform.v1.EndpointName;
2725
import com.google.cloud.aiplatform.v1.PredictRequest;
@@ -43,6 +41,7 @@
4341
import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
4442
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
4543
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
44+
import org.springframework.ai.model.EmbeddingModelDescription;
4645
import org.springframework.ai.model.ModelOptionsUtils;
4746
import org.springframework.ai.observation.conventions.AiProvider;
4847
import org.springframework.ai.retry.RetryUtils;
@@ -67,10 +66,8 @@ public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel {
6766

6867
private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
6968

70-
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = Stream
71-
.of(VertexAiTextEmbeddingModelName.values())
72-
.collect(Collectors.toMap(VertexAiTextEmbeddingModelName::getName,
73-
VertexAiTextEmbeddingModelName::getDimensions));
69+
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = EmbeddingModelDescription
70+
.calculateKnownEmbeddingDimensions(VertexAiTextEmbeddingModelName.class);
7471

7572
public final VertexAiTextEmbeddingOptions defaultOptions;
7673

@@ -244,7 +241,13 @@ private DefaultUsage getDefaultUsage(Integer totalTokens) {
244241

245242
@Override
246243
public int dimensions() {
247-
return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), super.dimensions());
244+
var dimensions = KNOWN_EMBEDDING_DIMENSIONS.get(this.defaultOptions.getModel());
245+
246+
if (dimensions == null) {
247+
dimensions = super.dimensions();
248+
}
249+
250+
return dimensions;
248251
}
249252

250253
/**

spring-ai-model/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingModel.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ private static Map<String, Integer> loadKnownModelDimensions() {
8282
}
8383
return properties.entrySet()
8484
.stream()
85-
.collect(Collectors.toMap(e -> e.getKey().toString(), e -> Integer.parseInt(e.getValue().toString())));
85+
.collect(Collectors.collectingAndThen(
86+
Collectors.toMap(e -> e.getKey().toString(), e -> Integer.parseInt(e.getValue().toString())),
87+
Map::copyOf));
8688
}
8789
catch (IOException e) {
8890
throw new RuntimeException(e);

spring-ai-model/src/main/java/org/springframework/ai/model/EmbeddingModelDescription.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,15 +16,28 @@
1616

1717
package org.springframework.ai.model;
1818

19+
import java.util.Map;
20+
import java.util.stream.Collectors;
21+
import java.util.stream.Stream;
22+
1923
/**
2024
* Description of an embedding model.
2125
*
2226
* @author Christian Tzolov
27+
* @author Nicolas Krier
2328
*/
2429
public interface EmbeddingModelDescription extends ModelDescription {
2530

2631
default int getDimensions() {
2732
return -1;
2833
}
2934

35+
static <E extends Enum<E> & EmbeddingModelDescription> Map<String, Integer> calculateKnownEmbeddingDimensions(
36+
Class<E> embeddingModelClass) {
37+
return Stream.of(embeddingModelClass.getEnumConstants())
38+
.collect(Collectors.collectingAndThen(
39+
Collectors.toMap(ModelDescription::getName, EmbeddingModelDescription::getDimensions),
40+
Map::copyOf));
41+
}
42+
3043
}

0 commit comments

Comments
 (0)