From 259cffac9eb5971a5d030f7b06fb729aed87e4ad Mon Sep 17 00:00:00 2001 From: Kuntal Maity Date: Mon, 27 Oct 2025 22:22:18 +0530 Subject: [PATCH 1/2] image: Consolidate ImageOptions and ImageResponse API, add ImageResponseFormat (#326) Signed-off-by: Kuntal Maity --- .../azure/openai/AzureOpenAiImageModel.java | 2 +- .../azure/openai/AzureOpenAiImageOptions.java | 9 +- .../ai/openai/OpenAiImageOptions.java | 9 +- .../ai/openai/OpenAiImageOptionsTests.java | 16 ++-- .../image/OpenAiImageModelObservationIT.java | 3 +- .../api/StabilityAiImageOptions.java | 9 +- .../StabilityAiImageOptionsTests.java | 9 +- .../ai/zhipuai/ZhiPuAiImageOptions.java | 3 +- .../modules/ROOT/pages/api/imageclient.adoc | 6 +- .../org/springframework/ai/image/Image.java | 17 ++++ .../ai/image/ImageOptions.java | 2 +- .../ai/image/ImageOptionsBuilder.java | 13 ++- .../ai/image/ImageResponse.java | 19 +++++ .../ai/image/ImageResponseFormat.java | 69 ++++++++++++++++ ...efaultImageModelObservationConvention.java | 6 +- .../ai/image/ImageResponseTests.java | 82 +++++++++++++++++++ ...tImageModelObservationConventionTests.java | 5 +- 17 files changed, 243 insertions(+), 36 deletions(-) create mode 100644 spring-ai-model/src/main/java/org/springframework/ai/image/ImageResponseFormat.java create mode 100644 spring-ai-model/src/test/java/org/springframework/ai/image/ImageResponseTests.java diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java index 88fe6ae6e51..3d73ea79074 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java @@ -181,7 +181,7 @@ private ImageGenerationOptions toOpenAiImageOptions(ImagePrompt prompt) { if (runtimeImageOptions.getResponseFormat() != null) { // b64_json or url imageGenerationOptions.setResponseFormat( - ImageGenerationResponseFormat.fromString(runtimeImageOptions.getResponseFormat())); + ImageGenerationResponseFormat.fromString(runtimeImageOptions.getResponseFormat().getValue())); } if (runtimeImageOptions.getWidth() != null && runtimeImageOptions.getHeight() != null) { imageGenerationOptions.setSize( diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java index 10e0d13f47e..d8d818cd61a 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.image.ImageOptions; +import org.springframework.ai.image.ImageResponseFormat; /** * The configuration information for a image generation request. @@ -81,7 +82,7 @@ public class AzureOpenAiImageOptions implements ImageOptions { * b64_json. */ @JsonProperty("response_format") - private String responseFormat; + private ImageResponseFormat responseFormat; /** * The size of the generated images. Must be one of 256x256, 512x512, or 1024x1024 for @@ -149,11 +150,11 @@ public void setHeight(Integer height) { } @Override - public String getResponseFormat() { + public ImageResponseFormat getResponseFormat() { return this.responseFormat; } - public void setResponseFormat(String responseFormat) { + public void setResponseFormat(ImageResponseFormat responseFormat) { this.responseFormat = responseFormat; } @@ -279,7 +280,7 @@ public Builder deploymentName(String deploymentName) { return this; } - public Builder responseFormat(String responseFormat) { + public Builder responseFormat(ImageResponseFormat responseFormat) { this.options.setResponseFormat(responseFormat); return this; } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageOptions.java index 3b294a5b02b..73adc3e1107 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageOptions.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.image.ImageOptions; +import org.springframework.ai.image.ImageResponseFormat; /** * OpenAI Image API options. OpenAiImageOptions.java @@ -79,7 +80,7 @@ public class OpenAiImageOptions implements ImageOptions { * b64_json. */ @JsonProperty("response_format") - private String responseFormat; + private ImageResponseFormat responseFormat; /** * The size of the generated images. Must be one of 256x256, 512x512, or 1024x1024 for @@ -158,11 +159,11 @@ public void setQuality(String quality) { } @Override - public String getResponseFormat() { + public ImageResponseFormat getResponseFormat() { return this.responseFormat; } - public void setResponseFormat(String responseFormat) { + public void setResponseFormat(ImageResponseFormat responseFormat) { this.responseFormat = responseFormat; } @@ -326,7 +327,7 @@ public Builder quality(String quality) { return this; } - public Builder responseFormat(String responseFormat) { + public Builder responseFormat(ImageResponseFormat responseFormat) { this.options.setResponseFormat(responseFormat); return this; } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiImageOptionsTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiImageOptionsTests.java index faa266ebbeb..1e9baf5cccd 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiImageOptionsTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiImageOptionsTests.java @@ -18,6 +18,8 @@ import org.junit.jupiter.api.Test; +import org.springframework.ai.image.ImageResponseFormat; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -34,7 +36,7 @@ void testBuilderWithAllFields() { .N(2) .model("dall-e-3") .quality("hd") - .responseFormat("url") + .responseFormat(ImageResponseFormat.URL) .width(1024) .height(1024) .style("vivid") @@ -44,7 +46,7 @@ void testBuilderWithAllFields() { assertThat(options.getN()).isEqualTo(2); assertThat(options.getModel()).isEqualTo("dall-e-3"); assertThat(options.getQuality()).isEqualTo("hd"); - assertThat(options.getResponseFormat()).isEqualTo("url"); + assertThat(options.getResponseFormat()).isEqualTo(ImageResponseFormat.URL); assertThat(options.getWidth()).isEqualTo(1024); assertThat(options.getHeight()).isEqualTo(1024); assertThat(options.getSize()).isEqualTo("1024x1024"); @@ -58,7 +60,7 @@ void testCopy() { .N(3) .model("dall-e-3") .quality("standard") - .responseFormat("b64_json") + .responseFormat(ImageResponseFormat.B64_JSON) .width(1792) .height(1024) .style("natural") @@ -99,7 +101,7 @@ void testSetters() { options.setN(4); options.setModel("dall-e-2"); options.setQuality("standard"); - options.setResponseFormat("url"); + options.setResponseFormat(ImageResponseFormat.URL); options.setWidth(512); options.setHeight(512); options.setStyle("vivid"); @@ -108,7 +110,7 @@ void testSetters() { assertThat(options.getN()).isEqualTo(4); assertThat(options.getModel()).isEqualTo("dall-e-2"); assertThat(options.getQuality()).isEqualTo("standard"); - assertThat(options.getResponseFormat()).isEqualTo("url"); + assertThat(options.getResponseFormat()).isEqualTo(ImageResponseFormat.URL); assertThat(options.getWidth()).isEqualTo(512); assertThat(options.getHeight()).isEqualTo(512); assertThat(options.getSize()).isEqualTo("512x512"); @@ -212,7 +214,7 @@ void testFluentApiPattern() { .N(1) .model("dall-e-3") .quality("hd") - .responseFormat("url") + .responseFormat(ImageResponseFormat.URL) .width(1024) .height(1024) .style("vivid") @@ -222,7 +224,7 @@ void testFluentApiPattern() { assertThat(options.getN()).isEqualTo(1); assertThat(options.getModel()).isEqualTo("dall-e-3"); assertThat(options.getQuality()).isEqualTo("hd"); - assertThat(options.getResponseFormat()).isEqualTo("url"); + assertThat(options.getResponseFormat()).isEqualTo(ImageResponseFormat.URL); assertThat(options.getWidth()).isEqualTo(1024); assertThat(options.getHeight()).isEqualTo(1024); assertThat(options.getSize()).isEqualTo("1024x1024"); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelObservationIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelObservationIT.java index 37dc7abcdba..e9244918529 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelObservationIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelObservationIT.java @@ -23,6 +23,7 @@ import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.image.ImageResponseFormat; import org.springframework.ai.image.observation.DefaultImageModelObservationConvention; import org.springframework.ai.model.SimpleApiKey; import org.springframework.ai.observation.conventions.AiOperationType; @@ -61,7 +62,7 @@ void observationForImageOperation() { .model(OpenAiImageApi.ImageModel.DALL_E_3.getValue()) .height(1024) .width(1024) - .responseFormat("url") + .responseFormat(ImageResponseFormat.URL) .style("natural") .build(); diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java index cc4e29b6a1c..b8bf336bad0 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java @@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.image.ImageOptions; +import org.springframework.ai.image.ImageResponseFormat; import org.springframework.ai.stabilityai.StyleEnum; /** @@ -122,7 +123,7 @@ public class StabilityAiImageOptions implements ImageOptions { * accept header. Must be "application/json" or "image/png" */ @JsonProperty("response_format") - private String responseFormat; + private ImageResponseFormat responseFormat; /** * The strictness level of the diffusion process adherence to the prompt text. @@ -328,11 +329,11 @@ public void setHeight(Integer height) { } @Override - public String getResponseFormat() { + public ImageResponseFormat getResponseFormat() { return this.responseFormat; } - public void setResponseFormat(String responseFormat) { + public void setResponseFormat(ImageResponseFormat responseFormat) { this.responseFormat = responseFormat; } @@ -455,7 +456,7 @@ public Builder height(Integer height) { return this; } - public Builder responseFormat(String responseFormat) { + public Builder responseFormat(ImageResponseFormat responseFormat) { this.options.setResponseFormat(responseFormat); return this; } diff --git a/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageOptionsTests.java b/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageOptionsTests.java index 854f022f2f9..5f627f5cbe9 100644 --- a/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageOptionsTests.java +++ b/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageOptionsTests.java @@ -19,6 +19,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.image.ImageOptions; +import org.springframework.ai.image.ImageResponseFormat; import org.springframework.ai.stabilityai.api.StabilityAiApi; import org.springframework.ai.stabilityai.api.StabilityAiImageOptions; @@ -37,7 +38,7 @@ void shouldPreferRuntimeOptionsOverDefaultOptions() { .model("default-model") .width(512) .height(512) - .responseFormat("image/png") + .responseFormat(ImageResponseFormat.IMAGE_PNG) .cfgScale(7.0f) .clipGuidancePreset("FAST_BLUE") .sampler("DDIM") @@ -52,7 +53,7 @@ void shouldPreferRuntimeOptionsOverDefaultOptions() { .model("runtime-model") .width(1024) .height(768) - .responseFormat("application/json") + .responseFormat(ImageResponseFormat.APPLICATION_JSON) .cfgScale(14.0f) .clipGuidancePreset("FAST_GREEN") .sampler("DDPM") @@ -71,7 +72,7 @@ void shouldPreferRuntimeOptionsOverDefaultOptions() { assertThat(options.getModel()).isEqualTo("runtime-model"); assertThat(options.getWidth()).isEqualTo(1024); assertThat(options.getHeight()).isEqualTo(768); - assertThat(options.getResponseFormat()).isEqualTo("application/json"); + assertThat(options.getResponseFormat()).isEqualTo(ImageResponseFormat.APPLICATION_JSON); assertThat(options.getCfgScale()).isEqualTo(14.0f); assertThat(options.getClipGuidancePreset()).isEqualTo("FAST_GREEN"); assertThat(options.getSampler()).isEqualTo("DDPM"); @@ -136,7 +137,7 @@ public Integer getHeight() { } @Override - public String getResponseFormat() { + public ImageResponseFormat getResponseFormat() { return null; } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java index 116b7293830..53a3f65d3ab 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java @@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.image.ImageOptions; +import org.springframework.ai.image.ImageResponseFormat; import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi; /** @@ -96,7 +97,7 @@ public Integer getHeight() { @Override @JsonIgnore - public String getResponseFormat() { + public ImageResponseFormat getResponseFormat() { return null; } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/imageclient.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/imageclient.adoc index bd8d3a001bd..fdc63875880 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/imageclient.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/imageclient.adoc @@ -89,7 +89,7 @@ public interface ImageOptions extends ModelOptions { Integer getHeight(); - String getResponseFormat(); // openai - url or base64 : stability ai byte[] or base64 +ImageResponseFormat getResponseFormat(); // openai - url or base64 : stability ai byte[] or base64 } ---- @@ -112,6 +112,10 @@ public class ImageResponse implements ModelResponse { private final List imageGenerations; + Optional getResultAsBytes(); + + List getResultsAsBytes(); + @Override public ImageGeneration getResult() { // get the first result diff --git a/spring-ai-model/src/main/java/org/springframework/ai/image/Image.java b/spring-ai-model/src/main/java/org/springframework/ai/image/Image.java index bf1f683a16a..775b6c26f61 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/image/Image.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/image/Image.java @@ -16,7 +16,13 @@ package org.springframework.ai.image; +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.util.Base64; import java.util.Objects; +import java.util.Optional; + +import org.springframework.util.StringUtils; public class Image { @@ -72,4 +78,15 @@ public int hashCode() { return Objects.hash(this.url, this.b64Json); } + public Optional getB64JsonAsBytes() { + if (!StringUtils.hasText(this.b64Json)) { + return Optional.empty(); + } + return Optional.of(Base64.getDecoder().decode(this.b64Json)); + } + + public Optional getB64JsonAsInputStream() { + return getB64JsonAsBytes().map(ByteArrayInputStream::new); + } + } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptions.java index 435f6fc62df..1e8967d16e8 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptions.java @@ -38,7 +38,7 @@ public interface ImageOptions extends ModelOptions { Integer getHeight(); @Nullable - String getResponseFormat(); + ImageResponseFormat getResponseFormat(); @Nullable String getStyle(); diff --git a/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java index 693a4f00f9d..dbdaa8ade32 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java @@ -38,11 +38,16 @@ public ImageOptionsBuilder model(String model) { return this; } - public ImageOptionsBuilder responseFormat(String responseFormat) { + public ImageOptionsBuilder responseFormat(ImageResponseFormat responseFormat) { this.options.setResponseFormat(responseFormat); return this; } + public ImageOptionsBuilder responseFormat(String responseFormat) { + this.options.setResponseFormat(ImageResponseFormat.fromValue(responseFormat)); + return this; + } + public ImageOptionsBuilder width(Integer width) { this.options.setWidth(width); return this; @@ -72,7 +77,7 @@ private static class DefaultImageModelOptions implements ImageOptions { private Integer height; - private String responseFormat; + private ImageResponseFormat responseFormat; private String style; @@ -95,11 +100,11 @@ public void setModel(String model) { } @Override - public String getResponseFormat() { + public ImageResponseFormat getResponseFormat() { return this.responseFormat; } - public void setResponseFormat(String responseFormat) { + public void setResponseFormat(ImageResponseFormat responseFormat) { this.responseFormat = responseFormat; } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/image/ImageResponse.java b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageResponse.java index c4605d81890..45d9ed20c8f 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/image/ImageResponse.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageResponse.java @@ -18,6 +18,7 @@ import java.util.List; import java.util.Objects; +import java.util.Optional; import org.springframework.ai.model.ModelResponse; import org.springframework.util.CollectionUtils; @@ -91,6 +92,24 @@ public ImageResponseMetadata getMetadata() { return this.imageResponseMetadata; } + public Optional getResultAsBytes() { + ImageGeneration firstGeneration = getResult(); + if (firstGeneration == null || firstGeneration.getOutput() == null) { + return Optional.empty(); + } + return firstGeneration.getOutput().getB64JsonAsBytes().map(byte[]::clone); + } + + public List getResultsAsBytes() { + return this.imageGenerations.stream() + .map(ImageGeneration::getOutput) + .filter(Objects::nonNull) + .map(Image::getB64JsonAsBytes) + .flatMap(Optional::stream) + .map(byte[]::clone) + .toList(); + } + @Override public String toString() { return "ImageResponse [" + "imageResponseMetadata=" + this.imageResponseMetadata + ", imageGenerations=" diff --git a/spring-ai-model/src/main/java/org/springframework/ai/image/ImageResponseFormat.java b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageResponseFormat.java new file mode 100644 index 00000000000..d7dd260f509 --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageResponseFormat.java @@ -0,0 +1,69 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.image; + +import java.util.Arrays; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; + +import org.springframework.util.StringUtils; + +/** + * Common response formats supported by image generation providers. + * + * @author Kuntal Maity + */ +public enum ImageResponseFormat { + + URL("url"), + + B64_JSON("b64_json"), + + /** + * PNG responses typically returned by providers when requesting raw image bytes. + */ + IMAGE_PNG("image/png"), + + /** + * JSON responses containing additional metadata or base64 encoded payloads. + */ + APPLICATION_JSON("application/json"); + + private final String value; + + ImageResponseFormat(String value) { + this.value = value; + } + + @JsonValue + public String getValue() { + return this.value; + } + + @JsonCreator + public static ImageResponseFormat fromValue(String value) { + if (!StringUtils.hasText(value)) { + return null; + } + return Arrays.stream(values()) + .filter(format -> format.value.equalsIgnoreCase(value)) + .findFirst() + .orElseThrow(() -> new IllegalArgumentException("Unsupported image response format: " + value)); + } + +} diff --git a/spring-ai-model/src/main/java/org/springframework/ai/image/observation/DefaultImageModelObservationConvention.java b/spring-ai-model/src/main/java/org/springframework/ai/image/observation/DefaultImageModelObservationConvention.java index 7a2a1e86a3b..12501ae76b7 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/image/observation/DefaultImageModelObservationConvention.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/image/observation/DefaultImageModelObservationConvention.java @@ -19,6 +19,7 @@ import io.micrometer.common.KeyValue; import io.micrometer.common.KeyValues; +import org.springframework.ai.image.ImageResponseFormat; import org.springframework.util.StringUtils; /** @@ -84,10 +85,11 @@ public KeyValues getHighCardinalityKeyValues(ImageModelObservationContext contex // Request protected KeyValues requestImageFormat(KeyValues keyValues, ImageModelObservationContext context) { - if (StringUtils.hasText(context.getRequest().getOptions().getResponseFormat())) { + ImageResponseFormat responseFormat = context.getRequest().getOptions().getResponseFormat(); + if (responseFormat != null) { return keyValues.and( ImageModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_IMAGE_RESPONSE_FORMAT.asString(), - context.getRequest().getOptions().getResponseFormat()); + responseFormat.getValue()); } return keyValues; } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/image/ImageResponseTests.java b/spring-ai-model/src/test/java/org/springframework/ai/image/ImageResponseTests.java new file mode 100644 index 00000000000..c71551d5bdc --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/image/ImageResponseTests.java @@ -0,0 +1,82 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.image; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.List; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +class ImageResponseTests { + + @Test + void getResultAsBytesReturnsFirstDecodedImage() { + byte[] payload = "hello".getBytes(StandardCharsets.UTF_8); + String base64 = Base64.getEncoder().encodeToString(payload); + Image image = new Image("https://example.test/image.png", base64); + ImageResponse response = new ImageResponse(List.of(new ImageGeneration(image))); + + assertThat(response.getResultAsBytes()).hasValueSatisfying(bytes -> assertThat(bytes).isEqualTo(payload)); + } + + @Test + void getResultsAsBytesSkipsEntriesWithoutPayload() { + byte[] payload = new byte[] { 1, 2, 3 }; + String base64 = Base64.getEncoder().encodeToString(payload); + Image imageWithPayload = new Image(null, base64); + Image imageWithoutPayload = new Image("https://example.test/image.png", null); + ImageResponse response = new ImageResponse( + List.of(new ImageGeneration(imageWithPayload), new ImageGeneration(imageWithoutPayload))); + + assertThat(response.getResultsAsBytes()).hasSize(1) + .first() + .satisfies(bytes -> assertThat(bytes).isEqualTo(payload)); + } + + @Test + void imageProvidesOptionalStreamForBase64Payload() throws IOException { + byte[] payload = { 42, 43, 44 }; + String base64 = Base64.getEncoder().encodeToString(payload); + Image image = new Image(null, base64); + + assertThat(image.getB64JsonAsBytes()).contains(payload); + assertThat(image.getB64JsonAsInputStream()).hasValueSatisfying(stream -> { + try (stream) { + assertThat(stream.readAllBytes()).isEqualTo(payload); + } + catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + } + + @Test + void helpersReturnEmptyWhenPayloadMissing() { + Image image = new Image("https://example.test/image.png", null); + ImageResponse response = new ImageResponse(List.of(new ImageGeneration(image))); + + assertThat(image.getB64JsonAsBytes()).isEmpty(); + assertThat(image.getB64JsonAsInputStream()).isEmpty(); + assertThat(response.getResultAsBytes()).isEmpty(); + assertThat(response.getResultsAsBytes()).isEmpty(); + } + +} diff --git a/spring-ai-model/src/test/java/org/springframework/ai/image/observation/DefaultImageModelObservationConventionTests.java b/spring-ai-model/src/test/java/org/springframework/ai/image/observation/DefaultImageModelObservationConventionTests.java index a0c2c4a8305..b3a50e19206 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/image/observation/DefaultImageModelObservationConventionTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/image/observation/DefaultImageModelObservationConventionTests.java @@ -23,6 +23,7 @@ import org.springframework.ai.image.ImageOptions; import org.springframework.ai.image.ImageOptionsBuilder; import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponseFormat; import org.springframework.ai.observation.conventions.AiObservationAttributes; import static org.assertj.core.api.Assertions.assertThat; @@ -90,7 +91,7 @@ void shouldHaveHighCardinalityKeyValuesWhenDefined() { .height(1080) .width(1920) .style("sketch") - .responseFormat("base64") + .responseFormat(ImageResponseFormat.B64_JSON) .build(); ImageModelObservationContext observationContext = ImageModelObservationContext.builder() .imagePrompt(generateImagePrompt(imageOptions)) @@ -98,7 +99,7 @@ void shouldHaveHighCardinalityKeyValuesWhenDefined() { .build(); assertThat(this.observationConvention.getHighCardinalityKeyValues(observationContext)).contains( - KeyValue.of(AiObservationAttributes.REQUEST_IMAGE_RESPONSE_FORMAT.value(), "base64"), + KeyValue.of(AiObservationAttributes.REQUEST_IMAGE_RESPONSE_FORMAT.value(), "b64_json"), KeyValue.of(AiObservationAttributes.REQUEST_IMAGE_SIZE.value(), "1920x1080"), KeyValue.of(AiObservationAttributes.REQUEST_IMAGE_STYLE.value(), "sketch")); } From 2d607148688e8bfcf88fa47ad41b061b7cda189a Mon Sep 17 00:00:00 2001 From: Kuntal Maity Date: Mon, 27 Oct 2025 22:45:59 +0530 Subject: [PATCH 2/2] image: Consolidate ImageOptions and ImageResponse API, add ImageResponseFormat (#326) Signed-off-by: Kuntal Maity --- .../ai/azure/openai/AzureOpenAiImageModel.java | 2 +- .../ai/azure/openai/AzureOpenAiImageOptions.java | 15 ++++++++++++++- .../ai/openai/OpenAiImageOptions.java | 15 ++++++++++++++- .../ai/openai/OpenAiImageOptionsTests.java | 11 ++++++++--- .../stabilityai/api/StabilityAiImageOptions.java | 15 ++++++++++++++- .../stabilityai/StabilityAiImageOptionsTests.java | 5 +++-- .../ai/zhipuai/ZhiPuAiImageOptions.java | 3 +-- .../modules/ROOT/pages/api/imageclient.adoc | 4 +++- .../springframework/ai/image/ImageOptions.java | 11 ++++++++++- .../ai/image/ImageOptionsBuilder.java | 4 ++-- .../DefaultImageModelObservationConvention.java | 2 +- 11 files changed, 71 insertions(+), 16 deletions(-) diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java index 3d73ea79074..88fe6ae6e51 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java @@ -181,7 +181,7 @@ private ImageGenerationOptions toOpenAiImageOptions(ImagePrompt prompt) { if (runtimeImageOptions.getResponseFormat() != null) { // b64_json or url imageGenerationOptions.setResponseFormat( - ImageGenerationResponseFormat.fromString(runtimeImageOptions.getResponseFormat().getValue())); + ImageGenerationResponseFormat.fromString(runtimeImageOptions.getResponseFormat())); } if (runtimeImageOptions.getWidth() != null && runtimeImageOptions.getHeight() != null) { imageGenerationOptions.setSize( diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java index d8d818cd61a..30511d8488f 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java @@ -150,7 +150,11 @@ public void setHeight(Integer height) { } @Override - public ImageResponseFormat getResponseFormat() { + public String getResponseFormat() { + return (this.responseFormat != null) ? this.responseFormat.getValue() : null; + } + + public ImageResponseFormat getResponseFormatEnum() { return this.responseFormat; } @@ -158,6 +162,10 @@ public void setResponseFormat(ImageResponseFormat responseFormat) { this.responseFormat = responseFormat; } + public void setResponseFormat(String responseFormat) { + this.responseFormat = ImageResponseFormat.fromValue(responseFormat); + } + public String getSize() { if (this.size != null) { return this.size; @@ -285,6 +293,11 @@ public Builder responseFormat(ImageResponseFormat responseFormat) { return this; } + public Builder responseFormat(String responseFormat) { + this.options.setResponseFormat(responseFormat); + return this; + } + public Builder width(Integer width) { this.options.setWidth(width); return this; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageOptions.java index 73adc3e1107..bdebfd38dfb 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageOptions.java @@ -159,7 +159,11 @@ public void setQuality(String quality) { } @Override - public ImageResponseFormat getResponseFormat() { + public String getResponseFormat() { + return (this.responseFormat != null) ? this.responseFormat.getValue() : null; + } + + public ImageResponseFormat getResponseFormatEnum() { return this.responseFormat; } @@ -167,6 +171,10 @@ public void setResponseFormat(ImageResponseFormat responseFormat) { this.responseFormat = responseFormat; } + public void setResponseFormat(String responseFormat) { + this.responseFormat = ImageResponseFormat.fromValue(responseFormat); + } + @Override public Integer getWidth() { if (this.width != null) { @@ -332,6 +340,11 @@ public Builder responseFormat(ImageResponseFormat responseFormat) { return this; } + public Builder responseFormat(String responseFormat) { + this.options.setResponseFormat(responseFormat); + return this; + } + public Builder width(Integer width) { this.options.setWidth(width); return this; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiImageOptionsTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiImageOptionsTests.java index 1e9baf5cccd..b7506fbd8cf 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiImageOptionsTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiImageOptionsTests.java @@ -46,7 +46,8 @@ void testBuilderWithAllFields() { assertThat(options.getN()).isEqualTo(2); assertThat(options.getModel()).isEqualTo("dall-e-3"); assertThat(options.getQuality()).isEqualTo("hd"); - assertThat(options.getResponseFormat()).isEqualTo(ImageResponseFormat.URL); + assertThat(options.getResponseFormat()).isEqualTo("url"); + assertThat(options.getResponseFormatAsEnum()).isEqualTo(ImageResponseFormat.URL); assertThat(options.getWidth()).isEqualTo(1024); assertThat(options.getHeight()).isEqualTo(1024); assertThat(options.getSize()).isEqualTo("1024x1024"); @@ -74,6 +75,7 @@ void testCopy() { assertThat(copied.getModel()).isEqualTo(original.getModel()); assertThat(copied.getQuality()).isEqualTo(original.getQuality()); assertThat(copied.getResponseFormat()).isEqualTo(original.getResponseFormat()); + assertThat(copied.getResponseFormatAsEnum()).isEqualTo(original.getResponseFormatAsEnum()); assertThat(copied.getWidth()).isEqualTo(original.getWidth()); assertThat(copied.getHeight()).isEqualTo(original.getHeight()); assertThat(copied.getSize()).isEqualTo(original.getSize()); @@ -87,6 +89,7 @@ void testCopy() { assertThat(copiedViaMethod.getModel()).isEqualTo(original.getModel()); assertThat(copiedViaMethod.getQuality()).isEqualTo(original.getQuality()); assertThat(copiedViaMethod.getResponseFormat()).isEqualTo(original.getResponseFormat()); + assertThat(copiedViaMethod.getResponseFormatAsEnum()).isEqualTo(original.getResponseFormatAsEnum()); assertThat(copiedViaMethod.getWidth()).isEqualTo(original.getWidth()); assertThat(copiedViaMethod.getHeight()).isEqualTo(original.getHeight()); assertThat(copiedViaMethod.getSize()).isEqualTo(original.getSize()); @@ -110,7 +113,8 @@ void testSetters() { assertThat(options.getN()).isEqualTo(4); assertThat(options.getModel()).isEqualTo("dall-e-2"); assertThat(options.getQuality()).isEqualTo("standard"); - assertThat(options.getResponseFormat()).isEqualTo(ImageResponseFormat.URL); + assertThat(options.getResponseFormat()).isEqualTo("url"); + assertThat(options.getResponseFormatAsEnum()).isEqualTo(ImageResponseFormat.URL); assertThat(options.getWidth()).isEqualTo(512); assertThat(options.getHeight()).isEqualTo(512); assertThat(options.getSize()).isEqualTo("512x512"); @@ -224,7 +228,8 @@ void testFluentApiPattern() { assertThat(options.getN()).isEqualTo(1); assertThat(options.getModel()).isEqualTo("dall-e-3"); assertThat(options.getQuality()).isEqualTo("hd"); - assertThat(options.getResponseFormat()).isEqualTo(ImageResponseFormat.URL); + assertThat(options.getResponseFormat()).isEqualTo("url"); + assertThat(options.getResponseFormatAsEnum()).isEqualTo(ImageResponseFormat.URL); assertThat(options.getWidth()).isEqualTo(1024); assertThat(options.getHeight()).isEqualTo(1024); assertThat(options.getSize()).isEqualTo("1024x1024"); diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java index b8bf336bad0..aa259390abc 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java @@ -329,7 +329,11 @@ public void setHeight(Integer height) { } @Override - public ImageResponseFormat getResponseFormat() { + public String getResponseFormat() { + return (this.responseFormat != null) ? this.responseFormat.getValue() : null; + } + + public ImageResponseFormat getResponseFormatEnum() { return this.responseFormat; } @@ -337,6 +341,10 @@ public void setResponseFormat(ImageResponseFormat responseFormat) { this.responseFormat = responseFormat; } + public void setResponseFormat(String responseFormat) { + this.responseFormat = ImageResponseFormat.fromValue(responseFormat); + } + public Float getCfgScale() { return this.cfgScale; } @@ -461,6 +469,11 @@ public Builder responseFormat(ImageResponseFormat responseFormat) { return this; } + public Builder responseFormat(String responseFormat) { + this.options.setResponseFormat(responseFormat); + return this; + } + public Builder cfgScale(Float cfgScale) { this.options.setCfgScale(cfgScale); return this; diff --git a/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageOptionsTests.java b/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageOptionsTests.java index 5f627f5cbe9..402de777b15 100644 --- a/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageOptionsTests.java +++ b/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageOptionsTests.java @@ -72,7 +72,8 @@ void shouldPreferRuntimeOptionsOverDefaultOptions() { assertThat(options.getModel()).isEqualTo("runtime-model"); assertThat(options.getWidth()).isEqualTo(1024); assertThat(options.getHeight()).isEqualTo(768); - assertThat(options.getResponseFormat()).isEqualTo(ImageResponseFormat.APPLICATION_JSON); + assertThat(options.getResponseFormat()).isEqualTo("application/json"); + assertThat(options.getResponseFormatAsEnum()).isEqualTo(ImageResponseFormat.APPLICATION_JSON); assertThat(options.getCfgScale()).isEqualTo(14.0f); assertThat(options.getClipGuidancePreset()).isEqualTo("FAST_GREEN"); assertThat(options.getSampler()).isEqualTo("DDPM"); @@ -137,7 +138,7 @@ public Integer getHeight() { } @Override - public ImageResponseFormat getResponseFormat() { + public String getResponseFormat() { return null; } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java index 53a3f65d3ab..116b7293830 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java @@ -23,7 +23,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.image.ImageOptions; -import org.springframework.ai.image.ImageResponseFormat; import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi; /** @@ -97,7 +96,7 @@ public Integer getHeight() { @Override @JsonIgnore - public ImageResponseFormat getResponseFormat() { + public String getResponseFormat() { return null; } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/imageclient.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/imageclient.adoc index fdc63875880..1e47ff1a7be 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/imageclient.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/imageclient.adoc @@ -89,7 +89,9 @@ public interface ImageOptions extends ModelOptions { Integer getHeight(); -ImageResponseFormat getResponseFormat(); // openai - url or base64 : stability ai byte[] or base64 + String getResponseFormat(); // openai - url or base64 : stability ai byte[] or base64 + + default ImageResponseFormat getResponseFormatAsEnum(); // convenience conversion helper } ---- diff --git a/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptions.java index 1e8967d16e8..aed2ee7ef2a 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptions.java @@ -18,6 +18,7 @@ import org.springframework.ai.model.ModelOptions; import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; /** * ImageOptions represent the common options, portable across different image generation @@ -38,7 +39,15 @@ public interface ImageOptions extends ModelOptions { Integer getHeight(); @Nullable - ImageResponseFormat getResponseFormat(); + String getResponseFormat(); + + default @Nullable ImageResponseFormat getResponseFormatAsEnum() { + String responseFormat = getResponseFormat(); + if (!StringUtils.hasText(responseFormat)) { + return null; + } + return ImageResponseFormat.fromValue(responseFormat); + } @Nullable String getStyle(); diff --git a/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java index dbdaa8ade32..c3c21b196fb 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java @@ -100,8 +100,8 @@ public void setModel(String model) { } @Override - public ImageResponseFormat getResponseFormat() { - return this.responseFormat; + public String getResponseFormat() { + return (this.responseFormat != null) ? this.responseFormat.getValue() : null; } public void setResponseFormat(ImageResponseFormat responseFormat) { diff --git a/spring-ai-model/src/main/java/org/springframework/ai/image/observation/DefaultImageModelObservationConvention.java b/spring-ai-model/src/main/java/org/springframework/ai/image/observation/DefaultImageModelObservationConvention.java index 12501ae76b7..cfe5ee07cba 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/image/observation/DefaultImageModelObservationConvention.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/image/observation/DefaultImageModelObservationConvention.java @@ -85,7 +85,7 @@ public KeyValues getHighCardinalityKeyValues(ImageModelObservationContext contex // Request protected KeyValues requestImageFormat(KeyValues keyValues, ImageModelObservationContext context) { - ImageResponseFormat responseFormat = context.getRequest().getOptions().getResponseFormat(); + ImageResponseFormat responseFormat = context.getRequest().getOptions().getResponseFormatAsEnum(); if (responseFormat != null) { return keyValues.and( ImageModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_IMAGE_RESPONSE_FORMAT.asString(),