Skip to content

Commit dd34dda

Browse files
committed
Refactor overridden EmbeddingModel dimensions method
- Provide a way to override known embedding dimensions - Make known embedding dimensions maps immutable Signed-off-by: Nicolas Krier <7557886+nicolaskrier@users.noreply.github.com>
1 parent 4532f64 commit dd34dda

File tree

5 files changed

+73
-25
lines changed

5 files changed

+73
-25
lines changed

models/spring-ai-google-genai-embedding/src/main/java/org/springframework/ai/google/genai/text/GoogleGenAiTextEmbeddingModel.java

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919
import java.util.ArrayList;
2020
import java.util.List;
2121
import java.util.Map;
22-
import java.util.stream.Collectors;
23-
import java.util.stream.Stream;
22+
import java.util.Objects;
2423

2524
import com.google.genai.Client;
2625
import com.google.genai.types.ContentEmbedding;
@@ -43,6 +42,7 @@
4342
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
4443
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
4544
import org.springframework.ai.google.genai.GoogleGenAiEmbeddingConnectionDetails;
45+
import org.springframework.ai.model.EmbeddingModelDescription;
4646
import org.springframework.ai.model.ModelOptionsUtils;
4747
import org.springframework.ai.observation.conventions.AiProvider;
4848
import org.springframework.ai.retry.RetryUtils;
@@ -64,10 +64,8 @@ public class GoogleGenAiTextEmbeddingModel extends AbstractEmbeddingModel {
6464

6565
private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
6666

67-
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = Stream
68-
.of(GoogleGenAiTextEmbeddingModelName.values())
69-
.collect(Collectors.toMap(GoogleGenAiTextEmbeddingModelName::getName,
70-
GoogleGenAiTextEmbeddingModelName::getDimensions));
67+
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = EmbeddingModelDescription
68+
.calculateKnownEmbeddingDimensions(GoogleGenAiTextEmbeddingModelName.class);
7169

7270
public final GoogleGenAiTextEmbeddingOptions defaultOptions;
7371

@@ -257,9 +255,14 @@ private DefaultUsage getDefaultUsage(Integer totalTokens) {
257255
return new DefaultUsage(0, 0, totalTokens);
258256
}
259257

258+
@Override
259+
public Map<String, Integer> knownEmbeddingDimensions() {
260+
return KNOWN_EMBEDDING_DIMENSIONS;
261+
}
262+
260263
@Override
261264
public int dimensions() {
262-
return KNOWN_EMBEDDING_DIMENSIONS.computeIfAbsent(this.defaultOptions.getModel(), model -> super.dimensions());
265+
return dimensions(this, Objects.requireNonNull(this.defaultOptions.getModel()));
263266
}
264267

265268
/**

models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -20,8 +20,6 @@
2020
import java.util.EnumMap;
2121
import java.util.List;
2222
import java.util.Map;
23-
import java.util.stream.Collectors;
24-
import java.util.stream.Stream;
2523

2624
import com.google.cloud.aiplatform.v1.EndpointName;
2725
import com.google.cloud.aiplatform.v1.PredictRequest;
@@ -43,6 +41,7 @@
4341
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
4442
import org.springframework.ai.embedding.EmbeddingResultMetadata;
4543
import org.springframework.ai.embedding.EmbeddingResultMetadata.ModalityType;
44+
import org.springframework.ai.model.EmbeddingModelDescription;
4645
import org.springframework.ai.model.ModelOptionsUtils;
4746
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
4847
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils;
@@ -75,10 +74,8 @@ public class VertexAiMultimodalEmbeddingModel implements DocumentEmbeddingModel
7574
private static final List<MimeType> SUPPORTED_IMAGE_MIME_SUB_TYPES = List.of(MimeTypeUtils.IMAGE_JPEG,
7675
MimeTypeUtils.IMAGE_GIF, MimeTypeUtils.IMAGE_PNG, MimeTypeUtils.parseMimeType("image/bmp"));
7776

78-
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = Stream
79-
.of(VertexAiMultimodalEmbeddingModelName.values())
80-
.collect(Collectors.toMap(VertexAiMultimodalEmbeddingModelName::getName,
81-
VertexAiMultimodalEmbeddingModelName::getDimensions));
77+
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = EmbeddingModelDescription
78+
.calculateKnownEmbeddingDimensions(VertexAiMultimodalEmbeddingModelName.class);
8279

8380
public final VertexAiMultimodalEmbeddingOptions defaultOptions;
8481

models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
import java.util.ArrayList;
2121
import java.util.List;
2222
import java.util.Map;
23-
import java.util.stream.Collectors;
24-
import java.util.stream.Stream;
23+
import java.util.Objects;
2524

2625
import com.google.cloud.aiplatform.v1.EndpointName;
2726
import com.google.cloud.aiplatform.v1.PredictRequest;
@@ -43,6 +42,7 @@
4342
import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
4443
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
4544
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
45+
import org.springframework.ai.model.EmbeddingModelDescription;
4646
import org.springframework.ai.model.ModelOptionsUtils;
4747
import org.springframework.ai.observation.conventions.AiProvider;
4848
import org.springframework.ai.retry.RetryUtils;
@@ -67,10 +67,8 @@ public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel {
6767

6868
private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
6969

70-
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = Stream
71-
.of(VertexAiTextEmbeddingModelName.values())
72-
.collect(Collectors.toMap(VertexAiTextEmbeddingModelName::getName,
73-
VertexAiTextEmbeddingModelName::getDimensions));
70+
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = EmbeddingModelDescription
71+
.calculateKnownEmbeddingDimensions(VertexAiTextEmbeddingModelName.class);
7472

7573
public final VertexAiTextEmbeddingOptions defaultOptions;
7674

@@ -242,9 +240,14 @@ private DefaultUsage getDefaultUsage(Integer totalTokens) {
242240
return new DefaultUsage(0, 0, totalTokens);
243241
}
244242

243+
@Override
244+
public Map<String, Integer> knownEmbeddingDimensions() {
245+
return KNOWN_EMBEDDING_DIMENSIONS;
246+
}
247+
245248
@Override
246249
public int dimensions() {
247-
return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), super.dimensions());
250+
return dimensions(this, Objects.requireNonNull(this.defaultOptions.getModel()));
248251
}
249252

250253
/**

spring-ai-model/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingModel.java

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,24 @@ public static int dimensions(EmbeddingModel embeddingModel, String modelName, St
7171
}
7272
}
7373

74+
/**
75+
* Return the dimension of the requested embedding generative name. Uses the embedding
76+
* model to retrieve its default dimensions if the generative name is unknown.
77+
* @param embeddingModel Embedding model client to determine its known embedding
78+
* dimensions and default dimensions.
79+
* @param modelName Embedding generative name to retrieve the dimensions for.
80+
* @return Returns the embedding dimensions for the model name.
81+
*/
82+
public static int dimensions(AbstractEmbeddingModel embeddingModel, String modelName) {
83+
var dimensions = embeddingModel.knownEmbeddingDimensions().get(modelName);
84+
85+
if (dimensions == null) {
86+
dimensions = embeddingModel.defaultDimensions();
87+
}
88+
89+
return dimensions;
90+
}
91+
7492
private static Map<String, Integer> loadKnownModelDimensions() {
7593
try {
7694
var resource = EMBEDDING_MODEL_DIMENSIONS_PROPERTIES;
@@ -82,21 +100,35 @@ private static Map<String, Integer> loadKnownModelDimensions() {
82100
}
83101
return properties.entrySet()
84102
.stream()
85-
.collect(Collectors.toMap(e -> e.getKey().toString(), e -> Integer.parseInt(e.getValue().toString())));
103+
.collect(Collectors.collectingAndThen(
104+
Collectors.toMap(e -> e.getKey().toString(), e -> Integer.parseInt(e.getValue().toString())),
105+
Map::copyOf));
86106
}
87107
catch (IOException e) {
88108
throw new RuntimeException(e);
89109
}
90110
}
91111

92-
@Override
93-
public int dimensions() {
112+
private int defaultDimensions() {
94113
if (this.embeddingDimensions.get() < 0) {
95114
this.embeddingDimensions.set(dimensions(this, "Test", "Hello World"));
96115
}
97116
return this.embeddingDimensions.get();
98117
}
99118

119+
/**
120+
* Retrieve all the known embedding dimensions.
121+
* @return The map containing the known embedding dimensions by model name
122+
*/
123+
public Map<String, Integer> knownEmbeddingDimensions() {
124+
return KNOWN_EMBEDDING_DIMENSIONS;
125+
}
126+
127+
@Override
128+
public int dimensions() {
129+
return defaultDimensions();
130+
}
131+
100132
static class Hints implements RuntimeHintsRegistrar {
101133

102134
@Override

spring-ai-model/src/main/java/org/springframework/ai/model/EmbeddingModelDescription.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,15 +16,28 @@
1616

1717
package org.springframework.ai.model;
1818

19+
import java.util.Map;
20+
import java.util.stream.Collectors;
21+
import java.util.stream.Stream;
22+
1923
/**
2024
* Description of an embedding model.
2125
*
2226
* @author Christian Tzolov
27+
* @author Nicolas Krier
2328
*/
2429
public interface EmbeddingModelDescription extends ModelDescription {
2530

2631
default int getDimensions() {
2732
return -1;
2833
}
2934

35+
static <E extends Enum<E> & EmbeddingModelDescription> Map<String, Integer> calculateKnownEmbeddingDimensions(
36+
Class<E> embeddingModelClass) {
37+
return Stream.of(embeddingModelClass.getEnumConstants())
38+
.collect(Collectors.collectingAndThen(
39+
Collectors.toMap(ModelDescription::getName, EmbeddingModelDescription::getDimensions),
40+
Map::copyOf));
41+
}
42+
3043
}

0 commit comments

Comments
 (0)