From 82a57d63e38a211525bc1b07862d5669e3bb669a Mon Sep 17 00:00:00 2001 From: Daniel La Rocque Date: Mon, 27 Oct 2025 15:58:36 -0400 Subject: [PATCH 1/8] feat(ai): server prompt templates --- .changeset/metal-ties-cry.md | 5 + common/api-review/ai.api.md | 50 +++ docs-devsite/_toc.yaml | 6 + docs-devsite/ai.googleaibackend.md | 46 +++ docs-devsite/ai.md | 53 +++ docs-devsite/ai.templatechatsession.md | 154 ++++++++ docs-devsite/ai.templategenerativemodel.md | 150 ++++++++ docs-devsite/ai.templateimagenmodel.md | 100 +++++ docs-devsite/ai.vertexaibackend.md | 46 +++ packages/ai/integration/constants.ts | 2 +- .../ai/integration/prompt-templates.test.ts | 65 ++++ packages/ai/src/api.test.ts | 16 +- packages/ai/src/api.ts | 45 ++- packages/ai/src/backend.test.ts | 26 +- packages/ai/src/backend.ts | 28 +- .../methods/chrome-adapter-browser.test.ts | 6 +- packages/ai/src/methods/chrome-adapter.ts | 3 +- packages/ai/src/methods/count-tokens.test.ts | 60 +-- packages/ai/src/methods/count-tokens.ts | 14 +- .../ai/src/methods/generate-content.test.ts | 233 +++++++++--- packages/ai/src/methods/generate-content.ts | 75 +++- .../src/methods/template-chat-session.test.ts | 359 ++++++++++++++++++ .../ai/src/methods/template-chat-session.ts | 230 +++++++++++ packages/ai/src/models/ai-model.test.ts | 106 +----- packages/ai/src/models/ai-model.ts | 61 +-- .../ai/src/models/generative-model.test.ts | 95 +++-- packages/ai/src/models/imagen-model.test.ts | 28 +- packages/ai/src/models/imagen-model.ts | 28 +- .../models/template-generative-model.test.ts | 117 ++++++ .../src/models/template-generative-model.ts | 118 ++++++ .../src/models/template-imagen-model.test.ts | 139 +++++++ .../ai/src/models/template-imagen-model.ts | 81 ++++ packages/ai/src/models/utils.test.ts | 142 +++++++ packages/ai/src/models/utils.ts | 71 ++++ packages/ai/src/requests/request.test.ts | 266 +++++++------ packages/ai/src/requests/request.ts | 152 ++++---- 36 files changed, 2645 insertions(+), 531 deletions(-) 
create mode 100644 .changeset/metal-ties-cry.md create mode 100644 docs-devsite/ai.templatechatsession.md create mode 100644 docs-devsite/ai.templategenerativemodel.md create mode 100644 docs-devsite/ai.templateimagenmodel.md create mode 100644 packages/ai/integration/prompt-templates.test.ts create mode 100644 packages/ai/src/methods/template-chat-session.test.ts create mode 100644 packages/ai/src/methods/template-chat-session.ts create mode 100644 packages/ai/src/models/template-generative-model.test.ts create mode 100644 packages/ai/src/models/template-generative-model.ts create mode 100644 packages/ai/src/models/template-imagen-model.test.ts create mode 100644 packages/ai/src/models/template-imagen-model.ts create mode 100644 packages/ai/src/models/utils.test.ts create mode 100644 packages/ai/src/models/utils.ts diff --git a/.changeset/metal-ties-cry.md b/.changeset/metal-ties-cry.md new file mode 100644 index 00000000000..f6c28a73b61 --- /dev/null +++ b/.changeset/metal-ties-cry.md @@ -0,0 +1,5 @@ +--- +'@firebase/ai': minor +--- + +Add support for Server Prompt Templates. 
diff --git a/common/api-review/ai.api.md b/common/api-review/ai.api.md index 48347b8d65e..ca49fd95aa5 100644 --- a/common/api-review/ai.api.md +++ b/common/api-review/ai.api.md @@ -96,6 +96,10 @@ export interface AudioConversationController { export abstract class Backend { protected constructor(type: BackendType); readonly backendType: BackendType; + // @internal (undocumented) + abstract _getModelPath(project: string, model: string): string; + // @internal (undocumented) + abstract _getTemplatePath(project: string, templateId: string): string; } // @public @@ -561,9 +565,19 @@ export function getImagenModel(ai: AI, modelParams: ImagenModelParams, requestOp // @beta export function getLiveGenerativeModel(ai: AI, modelParams: LiveModelParams): LiveGenerativeModel; +// @beta +export function getTemplateGenerativeModel(ai: AI, requestOptions?: RequestOptions): TemplateGenerativeModel; + +// @beta +export function getTemplateImagenModel(ai: AI, requestOptions?: RequestOptions): TemplateImagenModel; + // @public export class GoogleAIBackend extends Backend { constructor(); + // (undocumented) + _getModelPath(project: string, model: string): string; + // (undocumented) + _getTemplatePath(project: string, templateId: string): string; } // Warning: (ae-internal-missing-underscore) The name "GoogleAICitationMetadata" should be prefixed with an underscore because the declaration is marked as @internal @@ -1304,6 +1318,38 @@ export class StringSchema extends Schema { toJSON(): SchemaRequest; } +// @beta +export class TemplateChatSession { + constructor(_apiSettings: ApiSettings, templateId: string, _history?: Content[], requestOptions?: RequestOptions | undefined); + getHistory(): Promise; + // (undocumented) + requestOptions?: RequestOptions | undefined; + sendMessage(request: string | Array, inputs?: object): Promise; + sendMessageStream(request: string | Array, inputs?: object): Promise; + // (undocumented) + templateId: string; +} + +// @beta +export class 
TemplateGenerativeModel { + constructor(ai: AI, requestOptions?: RequestOptions); + // @internal (undocumented) + _apiSettings: ApiSettings; + generateContent(templateId: string, templateVariables: object): Promise; + generateContentStream(templateId: string, templateVariables: object): Promise; + requestOptions?: RequestOptions; + startChat(templateId: string, history?: Content[]): TemplateChatSession; +} + +// @beta +export class TemplateImagenModel { + constructor(ai: AI, requestOptions?: RequestOptions); + // @internal (undocumented) + _apiSettings: ApiSettings; + generateImages(templateId: string, templateVariables: object): Promise>; + requestOptions?: RequestOptions; +} + // @public export interface TextPart { // (undocumented) @@ -1397,6 +1443,10 @@ export interface UsageMetadata { // @public export class VertexAIBackend extends Backend { constructor(location?: string); + // (undocumented) + _getModelPath(project: string, model: string): string; + // (undocumented) + _getTemplatePath(project: string, templateId: string): string; readonly location: string; } diff --git a/docs-devsite/_toc.yaml b/docs-devsite/_toc.yaml index 04d65f6c333..2004df9b39b 100644 --- a/docs-devsite/_toc.yaml +++ b/docs-devsite/_toc.yaml @@ -196,6 +196,12 @@ toc: path: /docs/reference/js/ai.startchatparams.md - title: StringSchema path: /docs/reference/js/ai.stringschema.md + - title: TemplateChatSession + path: /docs/reference/js/ai.templatechatsession.md + - title: TemplateGenerativeModel + path: /docs/reference/js/ai.templategenerativemodel.md + - title: TemplateImagenModel + path: /docs/reference/js/ai.templateimagenmodel.md - title: TextPart path: /docs/reference/js/ai.textpart.md - title: ThinkingConfig diff --git a/docs-devsite/ai.googleaibackend.md b/docs-devsite/ai.googleaibackend.md index 7ccf8834a0a..68d6724762a 100644 --- a/docs-devsite/ai.googleaibackend.md +++ b/docs-devsite/ai.googleaibackend.md @@ -27,6 +27,13 @@ export declare class GoogleAIBackend extends Backend | 
--- | --- | --- | | [(constructor)()](./ai.googleaibackend.md#googleaibackendconstructor) | | Creates a configuration object for the Gemini Developer API backend. | +## Methods + +| Method | Modifiers | Description | +| --- | --- | --- | +| [\_getModelPath(project, model)](./ai.googleaibackend.md#googleaibackend_getmodelpath) | | | +| [\_getTemplatePath(project, templateId)](./ai.googleaibackend.md#googleaibackend_gettemplatepath) | | | + ## GoogleAIBackend.(constructor) Creates a configuration object for the Gemini Developer API backend. @@ -36,3 +43,42 @@ Creates a configuration object for the Gemini Developer API backend. ```typescript constructor(); ``` + +## GoogleAIBackend.\_getModelPath() + +Signature: + +```typescript +_getModelPath(project: string, model: string): string; +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| project | string | | +| model | string | | + +Returns: + +string + +## GoogleAIBackend.\_getTemplatePath() + +Signature: + +```typescript +_getTemplatePath(project: string, templateId: string): string; +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| project | string | | +| templateId | string | | + +Returns: + +string + diff --git a/docs-devsite/ai.md b/docs-devsite/ai.md index fabdbc5cc55..ab25706ff21 100644 --- a/docs-devsite/ai.md +++ b/docs-devsite/ai.md @@ -22,6 +22,8 @@ The Firebase AI Web SDK. | [getGenerativeModel(ai, modelParams, requestOptions)](./ai.md#getgenerativemodel_c63f46a) | Returns a [GenerativeModel](./ai.generativemodel.md#generativemodel_class) class with methods for inference and other functionality. | | [getImagenModel(ai, modelParams, requestOptions)](./ai.md#getimagenmodel_e1f6645) | Returns an [ImagenModel](./ai.imagenmodel.md#imagenmodel_class) class with methods for using Imagen.Only Imagen 3 models (named imagen-3.0-*) are supported. 
| | [getLiveGenerativeModel(ai, modelParams)](./ai.md#getlivegenerativemodel_f2099ac) | (Public Preview) Returns a [LiveGenerativeModel](./ai.livegenerativemodel.md#livegenerativemodel_class) class for real-time, bidirectional communication.The Live API is only supported in modern browser windows and Node >= 22. | +| [getTemplateGenerativeModel(ai, requestOptions)](./ai.md#gettemplategenerativemodel_9476bbc) | (Public Preview) Returns a [TemplateGenerativeModel](./ai.templategenerativemodel.md#templategenerativemodel_class) class for executing server-side templates. | +| [getTemplateImagenModel(ai, requestOptions)](./ai.md#gettemplateimagenmodel_9476bbc) | (Public Preview) Returns a [TemplateImagenModel](./ai.templateimagenmodel.md#templateimagenmodel_class) class for executing server-side Imagen templates. | | function(liveSession, ...) | | [startAudioConversation(liveSession, options)](./ai.md#startaudioconversation_01c8e7f) | (Public Preview) Starts a real-time, bidirectional audio conversation with the model. This helper function manages the complexities of microphone access, audio recording, playback, and interruptions. | @@ -47,6 +49,9 @@ The Firebase AI Web SDK. | [ObjectSchema](./ai.objectschema.md#objectschema_class) | Schema class for "object" types. The properties param must be a map of Schema objects. | | [Schema](./ai.schema.md#schema_class) | Parent class encompassing all Schema types, with static methods that allow building specific Schema types. This class can be converted with JSON.stringify() into a JSON string accepted by Vertex AI REST endpoints. (This string conversion is automatically done when calling SDK methods.) | | [StringSchema](./ai.stringschema.md#stringschema_class) | Schema class for "string" types. Can be used with or without enum values. 
| +| [TemplateChatSession](./ai.templatechatsession.md#templatechatsession_class) | (Public Preview) A chat session that enables sending chat messages and stores the history of sent and received messages so far.This session is for multi-turn chats using a server-side template. It should be instantiated with [TemplateGenerativeModel.startChat()](./ai.templategenerativemodel.md#templategenerativemodelstartchat). | +| [TemplateGenerativeModel](./ai.templategenerativemodel.md#templategenerativemodel_class) | (Public Preview) [GenerativeModel](./ai.generativemodel.md#generativemodel_class) APIs that execute on a server-side template.This class should only be instantiated with [getTemplateGenerativeModel()](./ai.md#gettemplategenerativemodel_9476bbc). | +| [TemplateImagenModel](./ai.templateimagenmodel.md#templateimagenmodel_class) | (Public Preview) Class for Imagen model APIs that execute on a server-side template.This class should only be instantiated with [getTemplateImagenModel()](./ai.md#gettemplateimagenmodel_9476bbc). | | [VertexAIBackend](./ai.vertexaibackend.md#vertexaibackend_class) | Configuration class for the Vertex AI Gemini API.Use this with [AIOptions](./ai.aioptions.md#aioptions_interface) when initializing the AI service via [getAI()](./ai.md#getai_a94a413) to specify the Vertex AI Gemini API as the backend. | ## Interfaces @@ -339,6 +344,54 @@ export declare function getLiveGenerativeModel(ai: AI, modelParams: LiveModelPar If the `apiKey` or `projectId` fields are missing in your Firebase config. +### getTemplateGenerativeModel(ai, requestOptions) {:#gettemplategenerativemodel_9476bbc} + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +Returns a [TemplateGenerativeModel](./ai.templategenerativemodel.md#templategenerativemodel_class) class for executing server-side templates. 
+ +Signature: + +```typescript +export declare function getTemplateGenerativeModel(ai: AI, requestOptions?: RequestOptions): TemplateGenerativeModel; +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| ai | [AI](./ai.ai.md#ai_interface) | An [AI](./ai.ai.md#ai_interface) instance. | +| requestOptions | [RequestOptions](./ai.requestoptions.md#requestoptions_interface) | Additional options to use when making requests. | + +Returns: + +[TemplateGenerativeModel](./ai.templategenerativemodel.md#templategenerativemodel_class) + +### getTemplateImagenModel(ai, requestOptions) {:#gettemplateimagenmodel_9476bbc} + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +Returns a [TemplateImagenModel](./ai.templateimagenmodel.md#templateimagenmodel_class) class for executing server-side Imagen templates. + +Signature: + +```typescript +export declare function getTemplateImagenModel(ai: AI, requestOptions?: RequestOptions): TemplateImagenModel; +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| ai | [AI](./ai.ai.md#ai_interface) | An [AI](./ai.ai.md#ai_interface) instance. | +| requestOptions | [RequestOptions](./ai.requestoptions.md#requestoptions_interface) | Additional options to use when making requests. | + +Returns: + +[TemplateImagenModel](./ai.templateimagenmodel.md#templateimagenmodel_class) + ## function(liveSession, ...) ### startAudioConversation(liveSession, options) {:#startaudioconversation_01c8e7f} diff --git a/docs-devsite/ai.templatechatsession.md b/docs-devsite/ai.templatechatsession.md new file mode 100644 index 00000000000..41f5e71d97a --- /dev/null +++ b/docs-devsite/ai.templatechatsession.md @@ -0,0 +1,154 @@ +Project: /docs/reference/js/_project.yaml +Book: /docs/reference/_book.yaml +page_type: reference + +{% comment %} +DO NOT EDIT THIS FILE! 
+This is generated by the JS SDK team, and any local changes will be +overwritten. Changes should be made in the source code at +https://github.com/firebase/firebase-js-sdk +{% endcomment %} + +# TemplateChatSession class +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +A chat session that enables sending chat messages and stores the history of sent and received messages so far. + +This session is for multi-turn chats using a server-side template. It should be instantiated with [TemplateGenerativeModel.startChat()](./ai.templategenerativemodel.md#templategenerativemodelstartchat). + +Signature: + +```typescript +export declare class TemplateChatSession +``` + +## Constructors + +| Constructor | Modifiers | Description | +| --- | --- | --- | +| [(constructor)(\_apiSettings, templateId, \_history, requestOptions)](./ai.templatechatsession.md#templatechatsessionconstructor) | | (Public Preview) Constructs a new instance of the TemplateChatSession class | + +## Properties + +| Property | Modifiers | Type | Description | +| --- | --- | --- | --- | +| [requestOptions](./ai.templatechatsession.md#templatechatsessionrequestoptions) | | [RequestOptions](./ai.requestoptions.md#requestoptions_interface) \| undefined | (Public Preview) | +| [templateId](./ai.templatechatsession.md#templatechatsessiontemplateid) | | string | (Public Preview) | + +## Methods + +| Method | Modifiers | Description | +| --- | --- | --- | +| [getHistory()](./ai.templatechatsession.md#templatechatsessiongethistory) | | (Public Preview) Gets the chat history so far. Blocked prompts are not added to history. Neither blocked candidates nor the prompts that generated them are added to history. 
| +| [sendMessage(request, inputs)](./ai.templatechatsession.md#templatechatsessionsendmessage) | | (Public Preview) Sends a chat message and receives a non-streaming [GenerateContentResult](./ai.generatecontentresult.md#generatecontentresult_interface). | +| [sendMessageStream(request, inputs)](./ai.templatechatsession.md#templatechatsessionsendmessagestream) | | (Public Preview) Sends a chat message and receives the response as a [GenerateContentStreamResult](./ai.generatecontentstreamresult.md#generatecontentstreamresult_interface) containing an iterable stream and a response promise. | + +## TemplateChatSession.(constructor) + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + + Constructs a new instance of the `TemplateChatSession` class + +Signature: + +```typescript +constructor(_apiSettings: ApiSettings, templateId: string, _history?: Content[], requestOptions?: RequestOptions | undefined); +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| \_apiSettings | ApiSettings | | +| templateId | string | | +| \_history | [Content](./ai.content.md#content_interface)\[\] | | +| requestOptions | [RequestOptions](./ai.requestoptions.md#requestoptions_interface) \| undefined | | + +## TemplateChatSession.requestOptions + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +Signature: + +```typescript +requestOptions?: RequestOptions | undefined; +``` + +## TemplateChatSession.templateId + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. 
+> + +Signature: + +```typescript +templateId: string; +``` + +## TemplateChatSession.getHistory() + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +Gets the chat history so far. Blocked prompts are not added to history. Neither blocked candidates nor the prompts that generated them are added to history. + +Signature: + +```typescript +getHistory(): Promise; +``` +Returns: + +Promise<[Content](./ai.content.md#content_interface)\[\]> + +## TemplateChatSession.sendMessage() + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +Sends a chat message and receives a non-streaming [GenerateContentResult](./ai.generatecontentresult.md#generatecontentresult_interface). + +Signature: + +```typescript +sendMessage(request: string | Array, inputs?: object): Promise; +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| request | string \| Array<string \| [Part](./ai.md#part)> | The user message to store in the history | +| inputs | object | A key-value map of variables to populate the template with. This should likely include the user message. | + +Returns: + +Promise<[GenerateContentResult](./ai.generatecontentresult.md#generatecontentresult_interface)> + +## TemplateChatSession.sendMessageStream() + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +Sends a chat message and receives the response as a [GenerateContentStreamResult](./ai.generatecontentstreamresult.md#generatecontentstreamresult_interface) containing an iterable stream and a response promise. 
+ +Signature: + +```typescript +sendMessageStream(request: string | Array, inputs?: object): Promise; +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| request | string \| Array<string \| [Part](./ai.md#part)> | The message to send to the model. | +| inputs | object | A key-value map of variables to populate the template with. | + +Returns: + +Promise<[GenerateContentStreamResult](./ai.generatecontentstreamresult.md#generatecontentstreamresult_interface)> + diff --git a/docs-devsite/ai.templategenerativemodel.md b/docs-devsite/ai.templategenerativemodel.md new file mode 100644 index 00000000000..d7ab0955f2f --- /dev/null +++ b/docs-devsite/ai.templategenerativemodel.md @@ -0,0 +1,150 @@ +Project: /docs/reference/js/_project.yaml +Book: /docs/reference/_book.yaml +page_type: reference + +{% comment %} +DO NOT EDIT THIS FILE! +This is generated by the JS SDK team, and any local changes will be +overwritten. Changes should be made in the source code at +https://github.com/firebase/firebase-js-sdk +{% endcomment %} + +# TemplateGenerativeModel class +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +[GenerativeModel](./ai.generativemodel.md#generativemodel_class) APIs that execute on a server-side template. + +This class should only be instantiated with [getTemplateGenerativeModel()](./ai.md#gettemplategenerativemodel_9476bbc). 
+ +Signature: + +```typescript +export declare class TemplateGenerativeModel +``` + +## Constructors + +| Constructor | Modifiers | Description | +| --- | --- | --- | +| [(constructor)(ai, requestOptions)](./ai.templategenerativemodel.md#templategenerativemodelconstructor) | | (Public Preview) Constructs a new instance of the TemplateGenerativeModel class | + +## Properties + +| Property | Modifiers | Type | Description | +| --- | --- | --- | --- | +| [requestOptions](./ai.templategenerativemodel.md#templategenerativemodelrequestoptions) | | [RequestOptions](./ai.requestoptions.md#requestoptions_interface) | (Public Preview) Additional options to use when making requests. | + +## Methods + +| Method | Modifiers | Description | +| --- | --- | --- | +| [generateContent(templateId, templateVariables)](./ai.templategenerativemodel.md#templategenerativemodelgeneratecontent) | | (Public Preview) Makes a single non-streaming call to the model and returns an object containing a single [GenerateContentResponse](./ai.generatecontentresponse.md#generatecontentresponse_interface). | +| [generateContentStream(templateId, templateVariables)](./ai.templategenerativemodel.md#templategenerativemodelgeneratecontentstream) | | (Public Preview) Makes a single streaming call to the model and returns an object containing an iterable stream that iterates over all chunks in the streaming response as well as a promise that returns the final aggregated response. | +| [startChat(templateId, history)](./ai.templategenerativemodel.md#templategenerativemodelstartchat) | | (Public Preview) Gets a new [TemplateChatSession](./ai.templatechatsession.md#templatechatsession_class) instance which can be used for multi-turn chats. | + +## TemplateGenerativeModel.(constructor) + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. 
+> + + Constructs a new instance of the `TemplateGenerativeModel` class + +Signature: + +```typescript +constructor(ai: AI, requestOptions?: RequestOptions); +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| ai | [AI](./ai.ai.md#ai_interface) | | +| requestOptions | [RequestOptions](./ai.requestoptions.md#requestoptions_interface) | | + +## TemplateGenerativeModel.requestOptions + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +Additional options to use when making requests. + +Signature: + +```typescript +requestOptions?: RequestOptions; +``` + +## TemplateGenerativeModel.generateContent() + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +Makes a single non-streaming call to the model and returns an object containing a single [GenerateContentResponse](./ai.generatecontentresponse.md#generatecontentresponse_interface). + +Signature: + +```typescript +generateContent(templateId: string, templateVariables: object): Promise; +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| templateId | string | The ID of the server-side template to execute. | +| templateVariables | object | A key-value map of variables to populate the template with. | + +Returns: + +Promise<[GenerateContentResult](./ai.generatecontentresult.md#generatecontentresult_interface)> + +## TemplateGenerativeModel.generateContentStream() + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +Makes a single streaming call to the model and returns an object containing an iterable stream that iterates over all chunks in the streaming response as well as a promise that returns the final aggregated response. 
+ +Signature: + +```typescript +generateContentStream(templateId: string, templateVariables: object): Promise; +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| templateId | string | The ID of the server-side template to execute. | +| templateVariables | object | A key-value map of variables to populate the template with. | + +Returns: + +Promise<[GenerateContentStreamResult](./ai.generatecontentstreamresult.md#generatecontentstreamresult_interface)> + +## TemplateGenerativeModel.startChat() + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +Gets a new [TemplateChatSession](./ai.templatechatsession.md#templatechatsession_class) instance which can be used for multi-turn chats. + +Signature: + +```typescript +startChat(templateId: string, history?: Content[]): TemplateChatSession; +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| templateId | string | The ID of the server-side template to execute. | +| history | [Content](./ai.content.md#content_interface)\[\] | An array of [Content](./ai.content.md#content_interface) objects to initialize the chat history with. | + +Returns: + +[TemplateChatSession](./ai.templatechatsession.md#templatechatsession_class) + diff --git a/docs-devsite/ai.templateimagenmodel.md b/docs-devsite/ai.templateimagenmodel.md new file mode 100644 index 00000000000..2d86071993f --- /dev/null +++ b/docs-devsite/ai.templateimagenmodel.md @@ -0,0 +1,100 @@ +Project: /docs/reference/js/_project.yaml +Book: /docs/reference/_book.yaml +page_type: reference + +{% comment %} +DO NOT EDIT THIS FILE! +This is generated by the JS SDK team, and any local changes will be +overwritten. 
Changes should be made in the source code at +https://github.com/firebase/firebase-js-sdk +{% endcomment %} + +# TemplateImagenModel class +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +Class for Imagen model APIs that execute on a server-side template. + +This class should only be instantiated with [getTemplateImagenModel()](./ai.md#gettemplateimagenmodel_9476bbc). + +Signature: + +```typescript +export declare class TemplateImagenModel +``` + +## Constructors + +| Constructor | Modifiers | Description | +| --- | --- | --- | +| [(constructor)(ai, requestOptions)](./ai.templateimagenmodel.md#templateimagenmodelconstructor) | | (Public Preview) Constructs a new instance of the TemplateImagenModel class | + +## Properties + +| Property | Modifiers | Type | Description | +| --- | --- | --- | --- | +| [requestOptions](./ai.templateimagenmodel.md#templateimagenmodelrequestoptions) | | [RequestOptions](./ai.requestoptions.md#requestoptions_interface) | (Public Preview) Additional options to use when making requests. | + +## Methods + +| Method | Modifiers | Description | +| --- | --- | --- | +| [generateImages(templateId, templateVariables)](./ai.templateimagenmodel.md#templateimagenmodelgenerateimages) | | (Public Preview) Makes a single call to the model and returns an object containing a single [ImagenGenerationResponse](./ai.imagengenerationresponse.md#imagengenerationresponse_interface). | + +## TemplateImagenModel.(constructor) + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. 
+> + + Constructs a new instance of the `TemplateImagenModel` class + +Signature: + +```typescript +constructor(ai: AI, requestOptions?: RequestOptions); +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| ai | [AI](./ai.ai.md#ai_interface) | | +| requestOptions | [RequestOptions](./ai.requestoptions.md#requestoptions_interface) | | + +## TemplateImagenModel.requestOptions + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +Additional options to use when making requests. + +Signature: + +```typescript +requestOptions?: RequestOptions; +``` + +## TemplateImagenModel.generateImages() + +> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. +> + +Makes a single call to the model and returns an object containing a single [ImagenGenerationResponse](./ai.imagengenerationresponse.md#imagengenerationresponse_interface). + +Signature: + +```typescript +generateImages(templateId: string, templateVariables: object): Promise>; +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| templateId | string | The ID of the server-side template to execute. | +| templateVariables | object | A key-value map of variables to populate the template with. 
| + +Returns: + +Promise<[ImagenGenerationResponse](./ai.imagengenerationresponse.md#imagengenerationresponse_interface)<[ImagenInlineImage](./ai.imageninlineimage.md#imageninlineimage_interface)>> + diff --git a/docs-devsite/ai.vertexaibackend.md b/docs-devsite/ai.vertexaibackend.md index 88424b75c45..e2e7fae1839 100644 --- a/docs-devsite/ai.vertexaibackend.md +++ b/docs-devsite/ai.vertexaibackend.md @@ -33,6 +33,13 @@ export declare class VertexAIBackend extends Backend | --- | --- | --- | --- | | [location](./ai.vertexaibackend.md#vertexaibackendlocation) | | string | The region identifier. See [Vertex AI locations](https://firebase.google.com/docs/vertex-ai/locations#available-locations) for a list of supported locations. | +## Methods + +| Method | Modifiers | Description | +| --- | --- | --- | +| [\_getModelPath(project, model)](./ai.vertexaibackend.md#vertexaibackend_getmodelpath) | | | +| [\_getTemplatePath(project, templateId)](./ai.vertexaibackend.md#vertexaibackend_gettemplatepath) | | | + ## VertexAIBackend.(constructor) Creates a configuration object for the Vertex AI backend. @@ -58,3 +65,42 @@ The region identifier. 
See [Vertex AI locations](https://firebase.google.com/doc ```typescript readonly location: string; ``` + +## VertexAIBackend.\_getModelPath() + +Signature: + +```typescript +_getModelPath(project: string, model: string): string; +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| project | string | | +| model | string | | + +Returns: + +string + +## VertexAIBackend.\_getTemplatePath() + +Signature: + +```typescript +_getTemplatePath(project: string, templateId: string): string; +``` + +#### Parameters + +| Parameter | Type | Description | +| --- | --- | --- | +| project | string | | +| templateId | string | | + +Returns: + +string + diff --git a/packages/ai/integration/constants.ts b/packages/ai/integration/constants.ts index f4a74e75039..99a65f31c54 100644 --- a/packages/ai/integration/constants.ts +++ b/packages/ai/integration/constants.ts @@ -44,7 +44,7 @@ function formatConfigAsString(config: { ai: AI; model: string }): string { const backends: readonly Backend[] = [ new GoogleAIBackend(), - new VertexAIBackend() + new VertexAIBackend('global') ]; const backendNames: Map = new Map([ diff --git a/packages/ai/integration/prompt-templates.test.ts b/packages/ai/integration/prompt-templates.test.ts new file mode 100644 index 00000000000..4898bdf1c1f --- /dev/null +++ b/packages/ai/integration/prompt-templates.test.ts @@ -0,0 +1,65 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { expect } from 'chai'; +import { + BackendType, + getTemplateGenerativeModel, + getTemplateImagenModel +} from '../src'; +import { testConfigs } from './constants'; + +const templateBackendSuffix = ( + backendType: BackendType +): 'googleai' | 'vertexai' => + backendType === BackendType.GOOGLE_AI ? 'googleai' : 'vertexai'; + +describe('Prompt templates', function () { + this.timeout(20_000); + testConfigs.forEach(testConfig => { + describe(`${testConfig.toString()}`, () => { + describe('Generative Model', () => { + it('successfully generates content', async () => { + const model = getTemplateGenerativeModel(testConfig.ai, { + baseUrl: 'https://staging-firebasevertexai.sandbox.googleapis.com' + }); + const { response } = await model.generateContent( + `sassy-greeting-${templateBackendSuffix( + testConfig.ai.backend.backendType + )}`, + { name: 'John' } + ); + expect(response.text()).to.contain('John'); // Template asks to address directly by name + }); + }); + describe('Imagen model', async () => { + it('successfully generates images', async () => { + const model = getTemplateImagenModel(testConfig.ai, { + baseUrl: 'https://staging-firebasevertexai.sandbox.googleapis.com' + }); + const { images } = await model.generateImages( + `portrait-${templateBackendSuffix( + testConfig.ai.backend.backendType + )}`, + { animal: 'Rhino' } + ); + expect(images.length).to.equal(2); // We ask for two images in the prompt template + }); + }); + }); + }); +}); diff --git a/packages/ai/src/api.test.ts b/packages/ai/src/api.test.ts index 65ecbbdcba8..e28868e388f 100644 --- a/packages/ai/src/api.test.ts +++ b/packages/ai/src/api.test.ts @@ -22,7 +22,11 @@ import { LiveGenerativeModel, getGenerativeModel, getImagenModel, - getLiveGenerativeModel + getLiveGenerativeModel, + getTemplateGenerativeModel, + TemplateGenerativeModel, + getTemplateImagenModel, + 
TemplateImagenModel } from './api'; import { expect } from 'chai'; import { AI } from './public-types'; @@ -281,4 +285,14 @@ describe('Top level API', () => { 'publishers/google/models/my-model' ); }); + it('getTemplateGenerativeModel gets a TemplateGenerativeModel', () => { + expect(getTemplateGenerativeModel(fakeAI)).to.be.an.instanceOf( + TemplateGenerativeModel + ); + }); + it('getImagenModel gets a TemplateImagenModel', () => { + expect(getTemplateImagenModel(fakeAI)).to.be.an.instanceOf( + TemplateImagenModel + ); + }); }); diff --git a/packages/ai/src/api.ts b/packages/ai/src/api.ts index 6e56aea793c..4f9201a52fd 100644 --- a/packages/ai/src/api.ts +++ b/packages/ai/src/api.ts @@ -39,12 +39,23 @@ import { import { encodeInstanceIdentifier } from './helpers'; import { GoogleAIBackend } from './backend'; import { WebSocketHandlerImpl } from './websocket'; +import { TemplateGenerativeModel } from './models/template-generative-model'; +import { TemplateImagenModel } from './models/template-imagen-model'; export { ChatSession } from './methods/chat-session'; +export { TemplateChatSession } from './methods/template-chat-session'; export { LiveSession } from './methods/live-session'; export * from './requests/schema-builder'; export { ImagenImageFormat } from './requests/imagen-image-format'; -export { AIModel, GenerativeModel, LiveGenerativeModel, ImagenModel, AIError }; +export { + AIModel, + GenerativeModel, + LiveGenerativeModel, + ImagenModel, + TemplateGenerativeModel, + TemplateImagenModel, + AIError +}; export { Backend, VertexAIBackend, GoogleAIBackend } from './backend'; export { startAudioConversation, @@ -202,3 +213,35 @@ export function getLiveGenerativeModel( const webSocketHandler = new WebSocketHandlerImpl(); return new LiveGenerativeModel(ai, modelParams, webSocketHandler); } + +/** + * Returns a {@link TemplateGenerativeModel} class for executing server-side + * templates. + * + * @param ai - An {@link AI} instance. 
+ * @param requestOptions - Additional options to use when making requests. + * + * @beta + */ +export function getTemplateGenerativeModel( + ai: AI, + requestOptions?: RequestOptions +): TemplateGenerativeModel { + return new TemplateGenerativeModel(ai, requestOptions); +} + +/** + * Returns a {@link TemplateImagenModel} class for executing server-side + * Imagen templates. + * + * @param ai - An {@link AI} instance. + * @param requestOptions - Additional options to use when making requests. + * + * @beta + */ +export function getTemplateImagenModel( + ai: AI, + requestOptions?: RequestOptions +): TemplateImagenModel { + return new TemplateImagenModel(ai, requestOptions); +} diff --git a/packages/ai/src/backend.test.ts b/packages/ai/src/backend.test.ts index 0c6609277e3..46d6507a499 100644 --- a/packages/ai/src/backend.test.ts +++ b/packages/ai/src/backend.test.ts @@ -18,7 +18,7 @@ import { expect } from 'chai'; import { GoogleAIBackend, VertexAIBackend } from './backend'; import { BackendType } from './public-types'; -import { DEFAULT_LOCATION } from './constants'; +import { DEFAULT_API_VERSION, DEFAULT_LOCATION } from './constants'; describe('Backend', () => { describe('GoogleAIBackend', () => { @@ -26,6 +26,18 @@ describe('Backend', () => { const backend = new GoogleAIBackend(); expect(backend.backendType).to.equal(BackendType.GOOGLE_AI); }); + it('getModelPath', () => { + const backend = new GoogleAIBackend(); + expect(backend._getModelPath('my-project', 'model-name')).to.equal( + `/${DEFAULT_API_VERSION}/projects/my-project/model-name` + ); + }); + it('getTemplatePath', () => { + const backend = new GoogleAIBackend(); + expect(backend._getTemplatePath('my-project', 'template-id')).to.equal( + `/${DEFAULT_API_VERSION}/projects/my-project/templates/template-id` + ); + }); }); describe('VertexAIBackend', () => { it('set backendType to VERTEX_AI', () => { @@ -48,5 +60,17 @@ describe('Backend', () => { expect(backend.backendType).to.equal(BackendType.VERTEX_AI); 
expect(backend.location).to.equal(DEFAULT_LOCATION); }); + it('getModelPath', () => { + const backend = new VertexAIBackend(); + expect(backend._getModelPath('my-project', 'model-name')).to.equal( + `/${DEFAULT_API_VERSION}/projects/my-project/locations/${backend.location}/model-name` + ); + }); + it('getTemplatePath', () => { + const backend = new VertexAIBackend(); + expect(backend._getTemplatePath('my-project', 'template-id')).to.equal( + `/${DEFAULT_API_VERSION}/projects/my-project/locations/${backend.location}/templates/template-id` + ); + }); }); }); diff --git a/packages/ai/src/backend.ts b/packages/ai/src/backend.ts index 7209828122b..21852b3608d 100644 --- a/packages/ai/src/backend.ts +++ b/packages/ai/src/backend.ts @@ -15,7 +15,7 @@ * limitations under the License. */ -import { DEFAULT_LOCATION } from './constants'; +import { DEFAULT_API_VERSION, DEFAULT_LOCATION } from './constants'; import { BackendType } from './public-types'; /** @@ -39,6 +39,16 @@ export abstract class Backend { protected constructor(type: BackendType) { this.backendType = type; } + + /** + * @internal + */ + abstract _getModelPath(project: string, model: string): string; + + /** + * @internal + */ + abstract _getTemplatePath(project: string, templateId: string): string; } /** @@ -56,6 +66,14 @@ export class GoogleAIBackend extends Backend { constructor() { super(BackendType.GOOGLE_AI); } + + _getModelPath(project: string, model: string): string { + return `/${DEFAULT_API_VERSION}/projects/${project}/${model}`; + } + + _getTemplatePath(project: string, templateId: string): string { + return `/${DEFAULT_API_VERSION}/projects/${project}/templates/${templateId}`; + } } /** @@ -89,4 +107,12 @@ export class VertexAIBackend extends Backend { this.location = location; } } + + _getModelPath(project: string, model: string): string { + return `/${DEFAULT_API_VERSION}/projects/${project}/locations/${this.location}/${model}`; + } + + _getTemplatePath(project: string, templateId: string): string 
{ + return `/${DEFAULT_API_VERSION}/projects/${project}/locations/${this.location}/templates/${templateId}`; + } } diff --git a/packages/ai/src/methods/chrome-adapter-browser.test.ts b/packages/ai/src/methods/chrome-adapter-browser.test.ts index e37a08bf1a9..1b851c3083c 100644 --- a/packages/ai/src/methods/chrome-adapter-browser.test.ts +++ b/packages/ai/src/methods/chrome-adapter-browser.test.ts @@ -29,6 +29,7 @@ import { import { match, stub } from 'sinon'; import { GenerateContentRequest, AIErrorCode, InferenceMode } from '../types'; import { Schema } from '../api'; +import { isNode } from '@firebase/util'; use(sinonChai); use(chaiAsPromised); @@ -53,6 +54,9 @@ async function toStringArray( } describe('ChromeAdapter', () => { + if (isNode()) { + return; + } describe('constructor', () => { it('sets image as expected input type by default', async () => { const languageModelProvider = { @@ -833,7 +837,7 @@ describe('chromeAdapterFactory', () => { const fakeLanguageModel = {} as LanguageModel; const adapter = chromeAdapterFactory( InferenceMode.PREFER_ON_DEVICE, - { LanguageModel: fakeLanguageModel } as Window, + { LanguageModel: fakeLanguageModel } as any, { createOptions: {} } ); expect(adapter?.languageModelProvider).to.equal(fakeLanguageModel); diff --git a/packages/ai/src/methods/chrome-adapter.ts b/packages/ai/src/methods/chrome-adapter.ts index 839276814bb..709084638c5 100644 --- a/packages/ai/src/methods/chrome-adapter.ts +++ b/packages/ai/src/methods/chrome-adapter.ts @@ -400,7 +400,8 @@ export function chromeAdapterFactory( // Do not initialize a ChromeAdapter if we are not in hybrid mode. 
if (typeof window !== 'undefined' && mode) { return new ChromeAdapterImpl( - (window as Window).LanguageModel as LanguageModel, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (window as any).LanguageModel as LanguageModel, mode, params ); diff --git a/packages/ai/src/methods/count-tokens.test.ts b/packages/ai/src/methods/count-tokens.test.ts index aabf06a841a..80da197790d 100644 --- a/packages/ai/src/methods/count-tokens.test.ts +++ b/packages/ai/src/methods/count-tokens.test.ts @@ -77,16 +77,17 @@ describe('countTokens()', () => { fakeChromeAdapter ); expect(result.totalTokens).to.equal(6); - expect(result.totalBillableCharacters).to.equal(16); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.COUNT_TOKENS, - fakeApiSettings, - false, + { + model: 'model', + task: Task.COUNT_TOKENS, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, match((value: string) => { return value.includes('contents'); - }), - undefined + }) ); }); it('total tokens with modality details', async () => { @@ -104,18 +105,17 @@ describe('countTokens()', () => { fakeChromeAdapter ); expect(result.totalTokens).to.equal(1837); - expect(result.totalBillableCharacters).to.equal(117); - expect(result.promptTokensDetails?.[0].modality).to.equal('IMAGE'); - expect(result.promptTokensDetails?.[0].tokenCount).to.equal(1806); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.COUNT_TOKENS, - fakeApiSettings, - false, + { + model: 'model', + task: Task.COUNT_TOKENS, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, match((value: string) => { return value.includes('contents'); - }), - undefined + }) ); }); it('total tokens no billable characters', async () => { @@ -135,14 +135,16 @@ describe('countTokens()', () => { expect(result.totalTokens).to.equal(258); expect(result).to.not.have.property('totalBillableCharacters'); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.COUNT_TOKENS, - 
fakeApiSettings, - false, + { + model: 'model', + task: Task.COUNT_TOKENS, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, match((value: string) => { return value.includes('contents'); - }), - undefined + }) ); }); it('model not found', async () => { @@ -187,12 +189,14 @@ describe('countTokens()', () => { ); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.COUNT_TOKENS, - fakeGoogleAIApiSettings, - false, - JSON.stringify(mapCountTokensRequest(fakeRequestParams, 'model')), - undefined + { + model: 'model', + task: Task.COUNT_TOKENS, + apiSettings: fakeGoogleAIApiSettings, + stream: false, + requestOptions: undefined + }, + JSON.stringify(mapCountTokensRequest(fakeRequestParams, 'model')) ); }); }); diff --git a/packages/ai/src/methods/count-tokens.ts b/packages/ai/src/methods/count-tokens.ts index ecd86a82912..2481c864353 100644 --- a/packages/ai/src/methods/count-tokens.ts +++ b/packages/ai/src/methods/count-tokens.ts @@ -44,12 +44,14 @@ export async function countTokensOnCloud( body = JSON.stringify(params); } const response = await makeRequest( - model, - Task.COUNT_TOKENS, - apiSettings, - false, - body, - requestOptions + { + model, + task: Task.COUNT_TOKENS, + apiSettings, + stream: false, + requestOptions + }, + body ); return response.json(); } diff --git a/packages/ai/src/methods/generate-content.test.ts b/packages/ai/src/methods/generate-content.test.ts index 40dc7c7b36e..c125bf2f914 100644 --- a/packages/ai/src/methods/generate-content.test.ts +++ b/packages/ai/src/methods/generate-content.test.ts @@ -19,9 +19,16 @@ import { expect, use } from 'chai'; import Sinon, { match, restore, stub } from 'sinon'; import sinonChai from 'sinon-chai'; import chaiAsPromised from 'chai-as-promised'; -import { getMockResponse } from '../../test-utils/mock-response'; +import { + getMockResponse, + getMockResponseStreaming +} from '../../test-utils/mock-response'; import * as request from '../requests/request'; -import { 
generateContent } from './generate-content'; +import { + generateContent, + templateGenerateContent, + templateGenerateContentStream +} from './generate-content'; import { AIErrorCode, GenerateContentRequest, @@ -110,12 +117,14 @@ describe('generateContent()', () => { ); expect(result.response.text()).to.include('Mountain View, California'); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - JSON.stringify(fakeRequestParams), - undefined + { + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, + JSON.stringify(fakeRequestParams) ); }); it('long response', async () => { @@ -134,11 +143,14 @@ describe('generateContent()', () => { expect(result.response.text()).to.include('Use Freshly Ground Coffee'); expect(result.response.text()).to.include('30 minutes of brewing'); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - match.any + { + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, + JSON.stringify(fakeRequestParams) ); }); it('long response with token details', async () => { @@ -169,11 +181,14 @@ describe('generateContent()', () => { result.response.usageMetadata?.candidatesTokensDetails?.[0].tokenCount ).to.equal(76); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - match.any + { + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, + JSON.stringify(fakeRequestParams) ); }); it('citations', async () => { @@ -196,11 +211,14 @@ describe('generateContent()', () => { result.response.candidates?.[0].citationMetadata?.citations.length ).to.equal(3); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - match.any + { + 
model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, + JSON.stringify(fakeRequestParams) ); }); it('google search grounding', async () => { @@ -243,11 +261,14 @@ describe('generateContent()', () => { .undefined; expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - match.any + { + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, + JSON.stringify(fakeRequestParams) ); it('url context', async () => { @@ -293,10 +314,12 @@ describe('generateContent()', () => { .be.undefined; expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, + { + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false + }, match.any ); }); @@ -335,11 +358,14 @@ describe('generateContent()', () => { ); expect(result.response.text).to.throw('SAFETY'); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - match.any + { + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, + JSON.stringify(fakeRequestParams) ); }); it('finishReason safety', async () => { @@ -357,11 +383,14 @@ describe('generateContent()', () => { ); expect(result.response.text).to.throw('SAFETY'); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - match.any + { + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, + JSON.stringify(fakeRequestParams) ); }); it('empty content', async () => { @@ -379,11 +408,14 @@ describe('generateContent()', () => { ); expect(result.response.text()).to.equal(''); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - 
fakeApiSettings, - false, - match.any + { + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, + JSON.stringify(fakeRequestParams) ); }); it('empty part', async () => { @@ -417,11 +449,14 @@ describe('generateContent()', () => { ); expect(result.response.text()).to.include('Some text'); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - match.any + { + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined + }, + JSON.stringify(fakeRequestParams) ); }); it('image rejected (400)', async () => { @@ -509,12 +544,14 @@ describe('generateContent()', () => { ); expect(makeRequestStub).to.be.calledWith( - 'model', - Task.GENERATE_CONTENT, - fakeGoogleAIApiSettings, - false, - JSON.stringify(mapGenerateContentRequest(fakeGoogleAIRequestParams)), - undefined + { + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeGoogleAIApiSettings, + stream: false, + requestOptions: match.any + }, + JSON.stringify(mapGenerateContentRequest(fakeGoogleAIRequestParams)) ); }); }); @@ -540,3 +577,83 @@ describe('generateContent()', () => { expect(generateContentStub).to.be.calledWith(fakeRequestParams); }); }); + +describe('templateGenerateContent', () => { + afterEach(() => { + restore(); + }); + it('should call makeRequest with correct parameters and process the response', async () => { + const mockResponse = getMockResponse( + 'vertexAI', + 'unary-success-basic-reply-short.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + const templateId = 'my-template'; + const templateParams = { name: 'world' }; + const requestOptions = { timeout: 5000 }; + + const result = await templateGenerateContent( + fakeApiSettings, + templateId, + templateParams, + requestOptions + ); + + 
expect(makeRequestStub).to.have.been.calledOnceWith( + { + task: 'templateGenerateContent', + templateId, + apiSettings: fakeApiSettings, + stream: false, + requestOptions + }, + JSON.stringify(templateParams) + ); + expect(result.response.text()).to.include('Mountain View, California'); + }); +}); + +describe('templateGenerateContentStream', () => { + afterEach(() => { + restore(); + }); + it('should call makeRequest with correct parameters for streaming', async () => { + const mockResponse = getMockResponseStreaming( + 'vertexAI', + 'streaming-success-basic-reply-short.txt' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + const templateId = 'my-stream-template'; + const templateParams = { name: 'streaming world' }; + const requestOptions = { timeout: 10000 }; + + const result = await templateGenerateContentStream( + fakeApiSettings, + templateId, + templateParams, + requestOptions + ); + + expect(makeRequestStub).to.have.been.calledOnceWith( + { + task: 'templateStreamGenerateContent', + templateId, + apiSettings: fakeApiSettings, + stream: true, + requestOptions + }, + JSON.stringify(templateParams) + ); + + // Verify the stream processing part + for await (const item of result.stream) { + expect(item.text()).to.not.be.empty; + } + const response = await result.response; + expect(response.text()).to.include('Cheyenne'); + }); +}); diff --git a/packages/ai/src/methods/generate-content.ts b/packages/ai/src/methods/generate-content.ts index a2fb29e20d1..cffc4b48413 100644 --- a/packages/ai/src/methods/generate-content.ts +++ b/packages/ai/src/methods/generate-content.ts @@ -41,12 +41,14 @@ async function generateContentStreamOnCloud( params = GoogleAIMapper.mapGenerateContentRequest(params); } return makeRequest( - model, - Task.STREAM_GENERATE_CONTENT, - apiSettings, - /* stream */ true, - JSON.stringify(params), - requestOptions + { + task: Task.STREAM_GENERATE_CONTENT, + model, + apiSettings, + stream: true, 
+ requestOptions + }, + JSON.stringify(params) ); } @@ -77,15 +79,64 @@ async function generateContentOnCloud( params = GoogleAIMapper.mapGenerateContentRequest(params); } return makeRequest( - model, - Task.GENERATE_CONTENT, - apiSettings, - /* stream */ false, - JSON.stringify(params), - requestOptions + { + model, + task: Task.GENERATE_CONTENT, + apiSettings, + stream: false, + requestOptions + }, + JSON.stringify(params) ); } +export async function templateGenerateContent( + apiSettings: ApiSettings, + templateId: string, + templateParams: object, + requestOptions?: RequestOptions +): Promise { + const response = await makeRequest( + { + task: 'templateGenerateContent', + templateId, + apiSettings, + stream: false, + requestOptions + }, + JSON.stringify(templateParams) + ); + const generateContentResponse = await processGenerateContentResponse( + response, + apiSettings + ); + const enhancedResponse = createEnhancedContentResponse( + generateContentResponse + ); + return { + response: enhancedResponse + }; +} + +export async function templateGenerateContentStream( + apiSettings: ApiSettings, + templateId: string, + templateParams: object, + requestOptions?: RequestOptions +): Promise { + const response = await makeRequest( + { + task: 'templateStreamGenerateContent', + templateId, + apiSettings, + stream: true, + requestOptions + }, + JSON.stringify(templateParams) + ); + return processStream(response, apiSettings); +} + export async function generateContent( apiSettings: ApiSettings, model: string, diff --git a/packages/ai/src/methods/template-chat-session.test.ts b/packages/ai/src/methods/template-chat-session.test.ts new file mode 100644 index 00000000000..79161143eba --- /dev/null +++ b/packages/ai/src/methods/template-chat-session.test.ts @@ -0,0 +1,359 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { use, expect } from 'chai'; +import sinonChai from 'sinon-chai'; +import chaiAsPromised from 'chai-as-promised'; +import { restore, stub } from 'sinon'; +import { VertexAIBackend } from '../backend'; +import * as generateContentMethods from './generate-content'; +import { TemplateChatSession } from './template-chat-session'; +import { ApiSettings } from '../types/internal'; +import { GenerateContentResult, Part, Role } from '../types'; + +use(sinonChai); +use(chaiAsPromised); + +const fakeApiSettings: ApiSettings = { + apiKey: 'key', + project: 'my-project', + appId: 'my-appid', + location: 'us-central1', + backend: new VertexAIBackend() +}; + +const TEMPLATE_ID = 'my-chat-template'; + +const FAKE_MODEL_RESPONSE_1 = { + response: { + candidates: [ + { + index: 0, + content: { + role: 'model' as Role, + parts: [{ text: 'Response 1' }] + } + } + ] + } +}; + +const FAKE_MODEL_RESPONSE_2 = { + response: { + candidates: [ + { + index: 0, + content: { + role: 'model' as Role, + parts: [{ text: 'Response 2' }] + } + } + ] + } +}; + +describe('TemplateChatSession', () => { + let templateGenerateContentStub: sinon.SinonStub; + let templateGenerateContentStreamStub: sinon.SinonStub; + + beforeEach(() => { + templateGenerateContentStub = stub( + generateContentMethods, + 'templateGenerateContent' + ); + templateGenerateContentStreamStub = stub( + generateContentMethods, + 'templateGenerateContentStream' + ); + }); + + afterEach(() => { + restore(); + }); + + describe('history and state management', () => { + it('should update history correctly 
after a single successful call', async () => { + templateGenerateContentStub.resolves( + FAKE_MODEL_RESPONSE_1 as GenerateContentResult + ); + const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); + await chat.sendMessage('Request 1'); + const history = await chat.getHistory(); + expect(history).to.have.lengthOf(2); + expect(history[0].role).to.equal('user'); + expect(history[0].parts[0].text).to.equal('Request 1'); + expect(history[1].role).to.equal('model'); + expect(history[1].parts[0].text).to.equal('Response 1'); + }); + + it('should maintain history over multiple turns', async () => { + templateGenerateContentStub + .onFirstCall() + .resolves(FAKE_MODEL_RESPONSE_1 as GenerateContentResult); + templateGenerateContentStub + .onSecondCall() + .resolves(FAKE_MODEL_RESPONSE_2 as GenerateContentResult); + + const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); + await chat.sendMessage('Request 1'); + await chat.sendMessage('Request 2'); + + const history = await chat.getHistory(); + expect(history).to.have.lengthOf(4); + expect(history[0].parts[0].text).to.equal('Request 1'); + expect(history[1].parts[0].text).to.equal('Response 1'); + expect(history[2].parts[0].text).to.equal('Request 2'); + expect(history[3].parts[0].text).to.equal('Response 2'); + }); + + it('should handle sequential calls to sendMessage and sendMessageStream', async () => { + templateGenerateContentStub.resolves( + FAKE_MODEL_RESPONSE_1 as GenerateContentResult + ); + templateGenerateContentStreamStub.resolves({ + stream: (async function* () {})(), + response: Promise.resolve(FAKE_MODEL_RESPONSE_2.response) + }); + + const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); + await chat.sendMessage('Request 1'); + await chat.sendMessageStream('Request 2'); + + const history = await chat.getHistory(); + expect(history).to.have.lengthOf(4); + expect(history[2].parts[0].text).to.equal('Request 2'); + expect(history[3].parts[0].text).to.equal('Response 2'); + 
}); + + it('should be able to be initialized with a history', async () => { + templateGenerateContentStub.resolves( + FAKE_MODEL_RESPONSE_2 as GenerateContentResult + ); + const initialHistory = [ + { role: 'user' as Role, parts: [{ text: 'Request 1' }] }, + FAKE_MODEL_RESPONSE_1.response.candidates[0].content + ]; + const chat = new TemplateChatSession( + fakeApiSettings, + TEMPLATE_ID, + initialHistory + ); + await chat.sendMessage('Request 2'); + const history = await chat.getHistory(); + expect(history).to.have.lengthOf(4); + expect(history[0].parts[0].text).to.equal('Request 1'); + expect(history[1].parts[0].text).to.equal('Response 1'); + expect(history[2].parts[0].text).to.equal('Request 2'); + expect(history[3].parts[0].text).to.equal('Response 2'); + }); + }); + + describe('error handling', () => { + it('templateGenerateContent errors should be catchable', async () => { + templateGenerateContentStub.rejects('failed'); + const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); + await expect(chat.sendMessage('Request 1')).to.be.rejected; + }); + + it('templateGenerateContentStream errors should be catchable', async () => { + templateGenerateContentStreamStub.rejects('failed'); + const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); + await expect(chat.sendMessageStream('Request 1')).to.be.rejected; + }); + + it('getHistory should fail if templateGenerateContent fails', async () => { + templateGenerateContentStub.rejects('failed'); + const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); + await expect(chat.sendMessage('Request 1')).to.be.rejected; + await expect(chat.getHistory()).to.be.rejected; + }); + + it('getHistory should fail if templateGenerateContentStream fails', async () => { + templateGenerateContentStreamStub.rejects('failed'); + const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); + await expect(chat.sendMessageStream('Request 1')).to.be.rejected; + }); + + it('should not update history if 
response has no candidates', async () => { + templateGenerateContentStub.resolves({ + response: {} + } as GenerateContentResult); + const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); + await chat.sendMessage('Request 1'); + const history = await chat.getHistory(); + expect(history).to.be.empty; + }); + }); + + describe('input variations for sendMessage', () => { + it('should handle request as a single string', async () => { + templateGenerateContentStub.resolves( + FAKE_MODEL_RESPONSE_1 as GenerateContentResult + ); + const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); + await chat.sendMessage('Just a string'); + const history = await chat.getHistory(); + expect(history[0].parts[0].text).to.equal('Just a string'); + }); + + it('should handle request as an array of strings', async () => { + templateGenerateContentStub.resolves( + FAKE_MODEL_RESPONSE_1 as GenerateContentResult + ); + const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); + await chat.sendMessage(['string 1', 'string 2']); + const history = await chat.getHistory(); + expect(history[0].parts).to.deep.equal([ + { text: 'string 1' }, + { text: 'string 2' } + ]); + }); + + it('should handle request as an array of Part objects', async () => { + templateGenerateContentStub.resolves( + FAKE_MODEL_RESPONSE_1 as GenerateContentResult + ); + const parts: Part[] = [{ text: 'part 1' }, { text: 'part 2' }]; + const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); + await chat.sendMessage(parts); + const history = await chat.getHistory(); + expect(history[0].parts).to.deep.equal(parts); + }); + + it('should pass inputs to templateGenerateContent', async () => { + templateGenerateContentStub.resolves( + FAKE_MODEL_RESPONSE_1 as GenerateContentResult + ); + const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); + const inputs = { someVar: 'someValue' }; + await chat.sendMessage('A request', inputs); + 
expect(templateGenerateContentStub).to.have.been.calledWith( + fakeApiSettings, + TEMPLATE_ID, + { + inputs: { ...inputs }, + history: [] + }, + undefined + ); + }); + + it('should pass requestOptions to templateGenerateContent', async () => { + templateGenerateContentStub.resolves( + FAKE_MODEL_RESPONSE_1 as GenerateContentResult + ); + const requestOptions = { timeout: 5000 }; + const chat = new TemplateChatSession( + fakeApiSettings, + TEMPLATE_ID, + [], + requestOptions + ); + await chat.sendMessage('A request'); + expect(templateGenerateContentStub).to.have.been.calledWith( + fakeApiSettings, + TEMPLATE_ID, + { + inputs: {}, + history: [] + }, + requestOptions + ); + }); + }); + + describe('input variations for sendMessageStream', () => { + it('should handle request as a single string', async () => { + templateGenerateContentStreamStub.resolves({ + stream: (async function* () {})(), + response: Promise.resolve(FAKE_MODEL_RESPONSE_1.response) + }); + const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); + await chat.sendMessageStream('Just a string'); + const history = await chat.getHistory(); + expect(history[0].parts[0].text).to.equal('Just a string'); + }); + + it('should handle request as an array of strings', async () => { + templateGenerateContentStreamStub.resolves({ + stream: (async function* () {})(), + response: Promise.resolve(FAKE_MODEL_RESPONSE_1.response) + }); + const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); + await chat.sendMessageStream(['string 1', 'string 2']); + const history = await chat.getHistory(); + expect(history[0].parts).to.deep.equal([ + { text: 'string 1' }, + { text: 'string 2' } + ]); + }); + + it('should handle request as an array of Part objects', async () => { + templateGenerateContentStreamStub.resolves({ + stream: (async function* () {})(), + response: Promise.resolve(FAKE_MODEL_RESPONSE_1.response) + }); + const parts: Part[] = [{ text: 'part 1' }, { text: 'part 2' }]; + const chat = new 
TemplateChatSession(fakeApiSettings, TEMPLATE_ID); + await chat.sendMessageStream(parts); + const history = await chat.getHistory(); + expect(history[0].parts).to.deep.equal(parts); + }); + + it('should pass inputs to templateGenerateContentStream', async () => { + templateGenerateContentStreamStub.resolves({ + stream: (async function* () {})(), + response: Promise.resolve(FAKE_MODEL_RESPONSE_1.response) + }); + const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); + const inputs = { someVar: 'someValue' }; + await chat.sendMessageStream('A request', inputs); + expect(templateGenerateContentStreamStub).to.have.been.calledWith( + fakeApiSettings, + TEMPLATE_ID, + { + inputs: { ...inputs }, + history: [] + } + ); + }); + + it('should pass requestOptions to templateGenerateContentStream', async () => { + templateGenerateContentStreamStub.resolves({ + stream: (async function* () {})(), + response: Promise.resolve(FAKE_MODEL_RESPONSE_1.response) + }); + const requestOptions = { timeout: 5000 }; + const chat = new TemplateChatSession( + fakeApiSettings, + TEMPLATE_ID, + [], + requestOptions + ); + await chat.sendMessageStream('A request'); + expect(templateGenerateContentStreamStub).to.have.been.calledWith( + fakeApiSettings, + TEMPLATE_ID, + { + inputs: {}, + history: [] + }, + requestOptions + ); + }); + }); +}); diff --git a/packages/ai/src/methods/template-chat-session.ts b/packages/ai/src/methods/template-chat-session.ts new file mode 100644 index 00000000000..8f646861de6 --- /dev/null +++ b/packages/ai/src/methods/template-chat-session.ts @@ -0,0 +1,230 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + Content, + GenerateContentResult, + GenerateContentStreamResult, + Part, + RequestOptions +} from '../types'; +import { formatNewContent } from '../requests/request-helpers'; +import { formatBlockErrorMessage } from '../requests/response-helpers'; +import { validateChatHistory } from './chat-session-helpers'; +import { + templateGenerateContent, + templateGenerateContentStream +} from './generate-content'; +import { ApiSettings } from '../types/internal'; +import { logger } from '../logger'; + +/** + * Do not log a message for this error. + */ +const SILENT_ERROR = 'SILENT_ERROR'; + +/** + * A chat session that enables sending chat messages and stores the history of + * sent and received messages so far. + * + * This session is for multi-turn chats using a server-side template. It should + * be instantiated with {@link TemplateGenerativeModel.startChat}. + * + * @beta + */ +export class TemplateChatSession { + private _sendPromise: Promise<void> = Promise.resolve(); + + /** + * @hideconstructor + */ + constructor( + private _apiSettings: ApiSettings, + public templateId: string, + private _history: Content[] = [], + public requestOptions?: RequestOptions + ) { + if (this._history) { + validateChatHistory(this._history); + } + } + + /** + * Gets the chat history so far. Blocked prompts are not added to history. + * Neither blocked candidates nor the prompts that generated them are added + * to history.
+ * + * @beta + */ + async getHistory(): Promise<Content[]> { + await this._sendPromise; + return this._history; + } + + /** + * Sends a chat message and receives a non-streaming + * {@link GenerateContentResult}. + * + * @param request - The user message to store in the history + * @param inputs - A key-value map of variables to populate the template + * with. This should likely include the user message. + * + * @beta + */ + async sendMessage( + request: string | Array<string | Part>, + inputs?: object + ): Promise<GenerateContentResult> { + await this._sendPromise; + let finalResult = {} as GenerateContentResult; + const variablesWithHistory = { + inputs: { + ...inputs + }, + history: [...this._history] + }; + // Add onto the chain. + this._sendPromise = this._sendPromise + .then(() => + templateGenerateContent( + this._apiSettings, + this.templateId, + variablesWithHistory, + this.requestOptions + ) + ) + .then(result => { + if ( + result.response.candidates && + result.response.candidates.length > 0 + ) { + // Important note: The user's message is *not* the actual message that was sent to + // the model, but the message that was passed as a parameter. + // Since the real message was the rendered server prompt template, there is no way + // to store the actual message in the client. + // It's the user's responsibility to ensure that the `message` that goes in the history + // is as close as possible to the rendered template if they want a realistic chat + // experience. + // The ideal case here is that the user defines a `message` variable in the `inputs` of + // the prompt template. The other parts of the message that the prompt template is hiding + // aren't relevant to the conversation history. For example, system instructions. + // In this case, the user would have the user's `message` that they pass as the first + // argument to this method, then *also* pass that in the `inputs`, so that it's actually + // part of the populated template that is sent to the model.
+ this._history.push(formatNewContent(request)); + const responseContent: Content = { + parts: result.response.candidates?.[0].content.parts || [], + // Response seems to come back without a role set. + role: result.response.candidates?.[0].content.role || 'model' + }; + this._history.push(responseContent); + } else { + const blockErrorMessage = formatBlockErrorMessage(result.response); + if (blockErrorMessage) { + logger.warn( + `sendMessage() was unsuccessful. ${blockErrorMessage}. Inspect response object for details.` + ); + } + } + finalResult = result as GenerateContentResult; + }); + await this._sendPromise; + return finalResult; + } + + /** + * Sends a chat message and receives the response as a + * {@link GenerateContentStreamResult} containing an iterable stream + * and a response promise. + * + * @param request - The message to send to the model. + * @param inputs - A key-value map of variables to populate the template + * with. + * + * @beta + */ + async sendMessageStream( + request: string | Array<string | Part>, + inputs?: object + ): Promise<GenerateContentStreamResult> { + await this._sendPromise; + const variablesWithHistory = { + inputs: { + ...inputs + }, + history: [...this._history] + }; + const streamPromise = templateGenerateContentStream( + this._apiSettings, + this.templateId, + variablesWithHistory, + this.requestOptions + ); + + // Add onto the chain. + this._sendPromise = this._sendPromise + .then(() => streamPromise) + // This must be handled to avoid unhandled rejection, but jump + // to the final catch block with a label to not log this error. + .catch(_ignored => { + throw new Error(SILENT_ERROR); + }) + .then(streamResult => streamResult.response) + .then(response => { + if (response.candidates && response.candidates.length > 0) { + // Important note: The user's message is *not* the actual message that was sent to + // the model, but the message that was passed as a parameter.
+ // Since the real message was the rendered server prompt template, there is no way + // to store the actual message in the client. + // It's the user's responsibility to ensure that the `message` that goes in the history + // is as close as possible to the rendered template if they want a realistic chat + // experience. + // The ideal case here is that the user defines a `message` variable in the `inputs` of + // the prompt template. The other parts of the message that the prompt template is hiding + // aren't relevant to the conversation history. For example, system instructions. + // In this case, the user would have the user's `message` that they pass as the first + // argument to this method, then *also* pass that in the `inputs`, so that it's actually + // part of the populated template that is sent to the model. + this._history.push(formatNewContent(request)); + const responseContent = { ...response.candidates[0].content }; + // Response seems to come back without a role set. + if (!responseContent.role) { + responseContent.role = 'model'; + } + this._history.push(responseContent); + } else { + const blockErrorMessage = formatBlockErrorMessage(response); + if (blockErrorMessage) { + logger.warn( + `sendMessageStream() was unsuccessful. ${blockErrorMessage}. Inspect response object for details.` + ); + } + } + }) + .catch(e => { + // Errors in streamPromise are already catchable by the user as + // streamPromise is returned. + // Avoid duplicating the error message in logs. + if (e.message !== SILENT_ERROR) { + // Users do not have access to _sendPromise to catch errors + // downstream from streamPromise, so they should not throw.
+ logger.error(e); + } + }); + return streamPromise; + } +} diff --git a/packages/ai/src/models/ai-model.test.ts b/packages/ai/src/models/ai-model.test.ts index 2e8f8998c58..4786adc8546 100644 --- a/packages/ai/src/models/ai-model.test.ts +++ b/packages/ai/src/models/ai-model.test.ts @@ -15,13 +15,10 @@ * limitations under the License. */ import { use, expect } from 'chai'; -import { AI, AIErrorCode } from '../public-types'; +import { AI } from '../public-types'; import sinonChai from 'sinon-chai'; -import { stub } from 'sinon'; import { AIModel } from './ai-model'; -import { AIError } from '../errors'; import { VertexAIBackend } from '../backend'; -import { AIService } from '../service'; use(sinonChai); @@ -69,105 +66,4 @@ describe('AIModel', () => { const testModel = new TestModel(fakeAI, 'tunedModels/my-model'); expect(testModel.model).to.equal('tunedModels/my-model'); }); - it('calls regular app check token when option is set', async () => { - const getTokenStub = stub().resolves(); - const getLimitedUseTokenStub = stub().resolves(); - const testModel = new TestModel( - //@ts-ignore - { - ...fakeAI, - options: { useLimitedUseAppCheckTokens: false }, - appCheck: { - getToken: getTokenStub, - getLimitedUseToken: getLimitedUseTokenStub - } - } as AIService, - 'models/my-model' - ); - if (testModel._apiSettings?.getAppCheckToken) { - await testModel._apiSettings.getAppCheckToken(); - } - expect(getTokenStub).to.be.called; - expect(getLimitedUseTokenStub).to.not.be.called; - getTokenStub.reset(); - getLimitedUseTokenStub.reset(); - }); - it('calls limited use token when option is set', async () => { - const getTokenStub = stub().resolves(); - const getLimitedUseTokenStub = stub().resolves(); - const testModel = new TestModel( - //@ts-ignore - { - ...fakeAI, - options: { useLimitedUseAppCheckTokens: true }, - appCheck: { - getToken: getTokenStub, - getLimitedUseToken: getLimitedUseTokenStub - } - } as AIService, - 'models/my-model' - ); - if 
(testModel._apiSettings?.getAppCheckToken) { - await testModel._apiSettings.getAppCheckToken(); - } - expect(getTokenStub).to.not.be.called; - expect(getLimitedUseTokenStub).to.be.called; - getTokenStub.reset(); - getLimitedUseTokenStub.reset(); - }); - it('throws if not passed an api key', () => { - const fakeAI: AI = { - app: { - name: 'DEFAULT', - automaticDataCollectionEnabled: true, - options: { - projectId: 'my-project' - } - }, - backend: new VertexAIBackend('us-central1'), - location: 'us-central1' - }; - try { - new TestModel(fakeAI, 'my-model'); - } catch (e) { - expect((e as AIError).code).to.equal(AIErrorCode.NO_API_KEY); - } - }); - it('throws if not passed a project ID', () => { - const fakeAI: AI = { - app: { - name: 'DEFAULT', - automaticDataCollectionEnabled: true, - options: { - apiKey: 'key' - } - }, - backend: new VertexAIBackend('us-central1'), - location: 'us-central1' - }; - try { - new TestModel(fakeAI, 'my-model'); - } catch (e) { - expect((e as AIError).code).to.equal(AIErrorCode.NO_PROJECT_ID); - } - }); - it('throws if not passed an app ID', () => { - const fakeAI: AI = { - app: { - name: 'DEFAULT', - automaticDataCollectionEnabled: true, - options: { - apiKey: 'key', - projectId: 'my-project' - } - }, - backend: new VertexAIBackend('us-central1'), - location: 'us-central1' - }; - try { - new TestModel(fakeAI, 'my-model'); - } catch (e) { - expect((e as AIError).code).to.equal(AIErrorCode.NO_APP_ID); - } - }); }); diff --git a/packages/ai/src/models/ai-model.ts b/packages/ai/src/models/ai-model.ts index 3fe202d5eb2..e2bc70319d8 100644 --- a/packages/ai/src/models/ai-model.ts +++ b/packages/ai/src/models/ai-model.ts @@ -15,11 +15,9 @@ * limitations under the License. 
*/ -import { AIError } from '../errors'; -import { AIErrorCode, AI, BackendType } from '../public-types'; -import { AIService } from '../service'; +import { AI, BackendType } from '../public-types'; import { ApiSettings } from '../types/internal'; -import { _isFirebaseServerApp } from '@firebase/app'; +import { initApiSettings } from './utils'; /** * Base class for Firebase AI model APIs. @@ -59,56 +57,11 @@ export abstract class AIModel { * @internal */ protected constructor(ai: AI, modelName: string) { - if (!ai.app?.options?.apiKey) { - throw new AIError( - AIErrorCode.NO_API_KEY, - `The "apiKey" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid API key.` - ); - } else if (!ai.app?.options?.projectId) { - throw new AIError( - AIErrorCode.NO_PROJECT_ID, - `The "projectId" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid project ID.` - ); - } else if (!ai.app?.options?.appId) { - throw new AIError( - AIErrorCode.NO_APP_ID, - `The "appId" field is empty in the local Firebase config. 
Firebase AI requires this field to contain a valid app ID.` - ); - } else { - this._apiSettings = { - apiKey: ai.app.options.apiKey, - project: ai.app.options.projectId, - appId: ai.app.options.appId, - automaticDataCollectionEnabled: ai.app.automaticDataCollectionEnabled, - location: ai.location, - backend: ai.backend - }; - - if (_isFirebaseServerApp(ai.app) && ai.app.settings.appCheckToken) { - const token = ai.app.settings.appCheckToken; - this._apiSettings.getAppCheckToken = () => { - return Promise.resolve({ token }); - }; - } else if ((ai as AIService).appCheck) { - if (ai.options?.useLimitedUseAppCheckTokens) { - this._apiSettings.getAppCheckToken = () => - (ai as AIService).appCheck!.getLimitedUseToken(); - } else { - this._apiSettings.getAppCheckToken = () => - (ai as AIService).appCheck!.getToken(); - } - } - - if ((ai as AIService).auth) { - this._apiSettings.getAuthToken = () => - (ai as AIService).auth!.getToken(); - } - - this.model = AIModel.normalizeModelName( - modelName, - this._apiSettings.backend.backendType - ); - } + this._apiSettings = initApiSettings(ai); + this.model = AIModel.normalizeModelName( + modelName, + this._apiSettings.backend.backendType + ); } /** diff --git a/packages/ai/src/models/generative-model.test.ts b/packages/ai/src/models/generative-model.test.ts index bcd78d746d4..f12eb7b75fc 100644 --- a/packages/ai/src/models/generative-model.test.ts +++ b/packages/ai/src/models/generative-model.test.ts @@ -97,10 +97,13 @@ describe('GenerativeModel', () => { ); await genModel.generateContent('hello'); expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - match.any, - false, + { + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: match.any, + stream: false, + requestOptions: {} + }, match((value: string) => { return ( value.includes('myfunc') && @@ -109,8 +112,7 @@ describe('GenerativeModel', () => { 
value.includes(FunctionCallingMode.NONE) && value.includes('be friendly') ); - }), - {} + }) ); restore(); }); @@ -134,14 +136,16 @@ describe('GenerativeModel', () => { ); await genModel.generateContent('hello'); expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - match.any, - false, + { + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: match.any, + stream: false, + requestOptions: {} + }, match((value: string) => { return value.includes('be friendly'); - }), - {} + }) ); restore(); }); @@ -195,10 +199,13 @@ describe('GenerativeModel', () => { systemInstruction: { role: 'system', parts: [{ text: 'be formal' }] } }); expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - match.any, - false, + { + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: match.any, + stream: false, + requestOptions: {} + }, match((value: string) => { return ( value.includes('otherfunc') && @@ -207,8 +214,7 @@ describe('GenerativeModel', () => { value.includes(FunctionCallingMode.AUTO) && value.includes('be formal') ); - }), - {} + }) ); restore(); }); @@ -286,10 +292,13 @@ describe('GenerativeModel', () => { ); await genModel.startChat().sendMessage('hello'); expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - match.any, - false, + { + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: match.any, + stream: false, + requestOptions: {} + }, match((value: string) => { return ( value.includes('myfunc') && @@ -299,8 +308,7 @@ describe('GenerativeModel', () => { value.includes('be friendly') && value.includes('topK') ); - }), - {} + }) ); restore(); }); @@ -324,14 +332,16 @@ describe('GenerativeModel', () => { ); await genModel.startChat().sendMessage('hello'); 
expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - match.any, - false, + { + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: match.any, + stream: false, + requestOptions: {} + }, match((value: string) => { return value.includes('be friendly'); - }), - {} + }) ); restore(); }); @@ -387,10 +397,13 @@ describe('GenerativeModel', () => { }) .sendMessage('hello'); expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - match.any, - false, + { + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: match.any, + stream: false, + requestOptions: {} + }, match((value: string) => { return ( value.includes('otherfunc') && @@ -401,8 +414,7 @@ describe('GenerativeModel', () => { value.includes('image/png') && !value.includes('image/jpeg') ); - }), - {} + }) ); restore(); }); @@ -422,10 +434,13 @@ describe('GenerativeModel', () => { ); await genModel.countTokens('hello'); expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.COUNT_TOKENS, - match.any, - false, + { + model: 'publishers/google/models/my-model', + task: request.Task.COUNT_TOKENS, + apiSettings: match.any, + stream: false, + requestOptions: undefined + }, match((value: string) => { return value.includes('hello'); }) diff --git a/packages/ai/src/models/imagen-model.test.ts b/packages/ai/src/models/imagen-model.test.ts index f4121e18f2d..68b6caca098 100644 --- a/packages/ai/src/models/imagen-model.test.ts +++ b/packages/ai/src/models/imagen-model.test.ts @@ -62,17 +62,19 @@ describe('ImagenModel', () => { const prompt = 'A photorealistic image of a toy boat at sea.'; await imagenModel.generateImages(prompt); expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.PREDICT, - match.any, - false, + { + model: 
'publishers/google/models/my-model', + task: request.Task.PREDICT, + apiSettings: match.any, + stream: false, + requestOptions: undefined + }, match((value: string) => { return ( value.includes(`"prompt":"${prompt}"`) && value.includes(`"sampleCount":1`) ); - }), - undefined + }) ); restore(); }); @@ -102,10 +104,13 @@ describe('ImagenModel', () => { const prompt = 'A photorealistic image of a toy boat at sea.'; await imagenModel.generateImages(prompt); expect(makeRequestStub).to.be.calledWith( - 'publishers/google/models/my-model', - request.Task.PREDICT, - match.any, - false, + { + model: 'publishers/google/models/my-model', + task: request.Task.PREDICT, + apiSettings: match.any, + stream: false, + requestOptions: undefined + }, match((value: string) => { return ( value.includes( @@ -130,8 +135,7 @@ describe('ImagenModel', () => { JSON.stringify(imagenModel.safetySettings?.personFilterLevel) ) ); - }), - undefined + }) ); restore(); }); diff --git a/packages/ai/src/models/imagen-model.ts b/packages/ai/src/models/imagen-model.ts index a41a03f25cf..820defb62a7 100644 --- a/packages/ai/src/models/imagen-model.ts +++ b/packages/ai/src/models/imagen-model.ts @@ -109,12 +109,14 @@ export class ImagenModel extends AIModel { ...this.safetySettings }); const response = await makeRequest( - this.model, - Task.PREDICT, - this._apiSettings, - /* stream */ false, - JSON.stringify(body), - this.requestOptions + { + task: Task.PREDICT, + model: this.model, + apiSettings: this._apiSettings, + stream: false, + requestOptions: this.requestOptions + }, + JSON.stringify(body) ); return handlePredictResponse(response); } @@ -148,12 +150,14 @@ export class ImagenModel extends AIModel { ...this.safetySettings }); const response = await makeRequest( - this.model, - Task.PREDICT, - this._apiSettings, - /* stream */ false, - JSON.stringify(body), - this.requestOptions + { + task: Task.PREDICT, + model: this.model, + apiSettings: this._apiSettings, + stream: false, + requestOptions: 
this.requestOptions + }, + JSON.stringify(body) ); return handlePredictResponse(response); } diff --git a/packages/ai/src/models/template-generative-model.test.ts b/packages/ai/src/models/template-generative-model.test.ts new file mode 100644 index 00000000000..301be8a2504 --- /dev/null +++ b/packages/ai/src/models/template-generative-model.test.ts @@ -0,0 +1,117 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { use, expect } from 'chai'; +import sinonChai from 'sinon-chai'; +import { restore, stub } from 'sinon'; +import { AI, Content } from '../public-types'; +import { VertexAIBackend } from '../backend'; +import { TemplateGenerativeModel } from './template-generative-model'; +import * as generateContentMethods from '../methods/generate-content'; +import { TemplateChatSession } from '../methods/template-chat-session'; + +use(sinonChai); + +const fakeAI: AI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + apiKey: 'key', + projectId: 'my-project', + appId: 'my-appid' + } + }, + backend: new VertexAIBackend('us-central1'), + location: 'us-central1' +}; + +const TEMPLATE_ID = 'my-template'; +const TEMPLATE_VARS = { a: 1, b: '2' }; + +describe('TemplateGenerativeModel', () => { + afterEach(() => { + restore(); + }); + + describe('constructor', () => { + it('should initialize _apiSettings correctly', () => { + const model = new 
TemplateGenerativeModel(fakeAI); + expect(model._apiSettings.apiKey).to.equal('key'); + expect(model._apiSettings.project).to.equal('my-project'); + expect(model._apiSettings.appId).to.equal('my-appid'); + }); + }); + + describe('generateContent', () => { + it('should call templateGenerateContent with correct parameters', async () => { + const templateGenerateContentStub = stub( + generateContentMethods, + 'templateGenerateContent' + ).resolves({} as any); + const model = new TemplateGenerativeModel(fakeAI, { timeout: 5000 }); + + await model.generateContent(TEMPLATE_ID, TEMPLATE_VARS); + + expect(templateGenerateContentStub).to.have.been.calledOnceWith( + model._apiSettings, + TEMPLATE_ID, + { inputs: TEMPLATE_VARS }, + { timeout: 5000 } + ); + }); + }); + + describe('generateContentStream', () => { + it('should call templateGenerateContentStream with correct parameters', async () => { + const templateGenerateContentStreamStub = stub( + generateContentMethods, + 'templateGenerateContentStream' + ).resolves({} as any); + const model = new TemplateGenerativeModel(fakeAI, { timeout: 5000 }); + + await model.generateContentStream(TEMPLATE_ID, TEMPLATE_VARS); + + expect(templateGenerateContentStreamStub).to.have.been.calledOnceWith( + model._apiSettings, + TEMPLATE_ID, + { inputs: TEMPLATE_VARS }, + { timeout: 5000 } + ); + }); + }); + + describe('startChat', () => { + it('should return a TemplateChatSession instance', () => { + const model = new TemplateGenerativeModel(fakeAI); + const chat = model.startChat(TEMPLATE_ID); + expect(chat).to.be.an.instanceOf(TemplateChatSession); + expect(chat.templateId).to.equal(TEMPLATE_ID); + }); + + it('should pass history and requestOptions to TemplateChatSession', () => { + const history: Content[] = [{ role: 'user', parts: [{ text: 'hi' }] }]; + const requestOptions = { timeout: 1000 }; + const model = new TemplateGenerativeModel(fakeAI, requestOptions); + const chat = model.startChat(TEMPLATE_ID, history); + + 
expect(chat.requestOptions).to.deep.equal(requestOptions); + // Private property, but we can check it for test purposes + expect((chat as any)._history).to.deep.equal(history); + }); + }); +}); diff --git a/packages/ai/src/models/template-generative-model.ts b/packages/ai/src/models/template-generative-model.ts new file mode 100644 index 00000000000..e2d3b71736a --- /dev/null +++ b/packages/ai/src/models/template-generative-model.ts @@ -0,0 +1,118 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + templateGenerateContent, + templateGenerateContentStream +} from '../methods/generate-content'; +import { GenerateContentResult, RequestOptions } from '../types'; +import { AI, Content, GenerateContentStreamResult } from '../public-types'; +import { ApiSettings } from '../types/internal'; +import { TemplateChatSession } from '../methods/template-chat-session'; +import { initApiSettings } from './utils'; + +/** + * {@link GenerativeModel} APIs that execute on a server-side template. + * + * This class should only be instantiated with {@link getTemplateGenerativeModel}. + * + * @beta + */ +export class TemplateGenerativeModel { + /** + * @internal + */ + _apiSettings: ApiSettings; + + /** + * Additional options to use when making requests. 
+ */ + requestOptions?: RequestOptions; + + /** + * @hideconstructor + */ + constructor(ai: AI, requestOptions?: RequestOptions) { + this.requestOptions = requestOptions || {}; + this._apiSettings = initApiSettings(ai); + } + + /** + * Makes a single non-streaming call to the model and returns an object + * containing a single {@link GenerateContentResponse}. + * + * @param templateId - The ID of the server-side template to execute. + * @param templateVariables - A key-value map of variables to populate the + * template with. + * + * @beta + */ + async generateContent( + templateId: string, + templateVariables: object + ): Promise<GenerateContentResult> { + return templateGenerateContent( + this._apiSettings, + templateId, + { inputs: templateVariables }, + this.requestOptions + ); + } + + /** + * Makes a single streaming call to the model and returns an object + * containing an iterable stream that iterates over all chunks in the + * streaming response as well as a promise that returns the final aggregated + * response. + * + * @param templateId - The ID of the server-side template to execute. + * @param templateVariables - A key-value map of variables to populate the + * template with. + * + * @beta + */ + async generateContentStream( + templateId: string, + templateVariables: object + ): Promise<GenerateContentStreamResult> { + return templateGenerateContentStream( + this._apiSettings, + templateId, + { inputs: templateVariables }, + this.requestOptions + ); + } + + /** + * Gets a new {@link TemplateChatSession} instance which can be used for + * multi-turn chats. + * + * @param templateId - The ID of the server-side template to execute. + * @param history - An array of {@link Content} objects to initialize the + * chat history with.
+ * + * @beta + */ + startChat(templateId: string, history?: Content[]): TemplateChatSession { + return new TemplateChatSession( + this._apiSettings, + templateId, + history, + this.requestOptions + ); + } +} diff --git a/packages/ai/src/models/template-imagen-model.test.ts b/packages/ai/src/models/template-imagen-model.test.ts new file mode 100644 index 00000000000..d53f0351dd2 --- /dev/null +++ b/packages/ai/src/models/template-imagen-model.test.ts @@ -0,0 +1,139 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import { use, expect } from 'chai'; +import sinonChai from 'sinon-chai'; +import chaiAsPromised from 'chai-as-promised'; +import { restore, stub } from 'sinon'; +import { AI } from '../public-types'; +import { VertexAIBackend } from '../backend'; +import { TemplateImagenModel } from './template-imagen-model'; +import { AIError } from '../errors'; +import * as request from '../requests/request'; + +use(sinonChai); +use(chaiAsPromised); + +const fakeAI: AI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + apiKey: 'key', + projectId: 'my-project', + appId: 'my-appid' + } + }, + backend: new VertexAIBackend('us-central1'), + location: 'us-central1' +}; + +const TEMPLATE_ID = 'my-imagen-template'; +const TEMPLATE_VARS = { a: 1, b: '2' }; + +describe('TemplateImagenModel', () => { + afterEach(() => { + restore(); + }); + + describe('constructor', () => { + it('should initialize _apiSettings correctly', () => { + const model = new TemplateImagenModel(fakeAI); + expect(model._apiSettings.apiKey).to.equal('key'); + expect(model._apiSettings.project).to.equal('my-project'); + expect(model._apiSettings.appId).to.equal('my-appid'); + }); + }); + + describe('generateImages', () => { + it('should call makeRequest with correct parameters', async () => { + const makeRequestStub = stub(request, 'makeRequest').resolves({ + json: () => + Promise.resolve({ + predictions: [ + { + bytesBase64Encoded: + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==', + mimeType: 'image/png' + } + ] + }) + } as Response); + const model = new TemplateImagenModel(fakeAI, { timeout: 5000 }); + + await model.generateImages(TEMPLATE_ID, TEMPLATE_VARS); + + expect(makeRequestStub).to.have.been.calledOnceWith( + { + task: 'templatePredict', + templateId: TEMPLATE_ID, + apiSettings: model._apiSettings, + stream: false, + requestOptions: { timeout: 5000 } + }, + JSON.stringify({ inputs: TEMPLATE_VARS }) + ); + }); + + 
it('should return the result of handlePredictResponse', async () => { + const mockPrediction = { + 'bytesBase64Encoded': + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==', + 'mimeType': 'image/png' + }; + stub(request, 'makeRequest').resolves({ + json: () => Promise.resolve({ predictions: [mockPrediction] }) + } as Response); + + const model = new TemplateImagenModel(fakeAI); + const result = await model.generateImages(TEMPLATE_ID, TEMPLATE_VARS); + + expect(result.images).to.deep.equal([mockPrediction]); + }); + + it('should throw an AIError if the prompt is blocked', async () => { + const error = new AIError('fetch-error', 'Request failed'); + stub(request, 'makeRequest').rejects(error); + + const model = new TemplateImagenModel(fakeAI); + await expect( + model.generateImages(TEMPLATE_ID, TEMPLATE_VARS) + ).to.be.rejectedWith(error); + }); + + it('should handle responses with filtered images', async () => { + const mockPrediction = { + bytesBase64Encoded: 'iVBOR...ggg==', + mimeType: 'image/png' + }; + const filteredReason = 'This image was filtered for safety reasons.'; + stub(request, 'makeRequest').resolves({ + json: () => + Promise.resolve({ + predictions: [mockPrediction, { raiFilteredReason: filteredReason }] + }) + } as Response); + + const model = new TemplateImagenModel(fakeAI); + const result = await model.generateImages(TEMPLATE_ID, TEMPLATE_VARS); + + expect(result.images).to.have.lengthOf(1); + expect(result.images[0]).to.deep.equal(mockPrediction); + expect(result.filteredReason).to.equal(filteredReason); + }); + }); +}); diff --git a/packages/ai/src/models/template-imagen-model.ts b/packages/ai/src/models/template-imagen-model.ts new file mode 100644 index 00000000000..ba1919f1140 --- /dev/null +++ b/packages/ai/src/models/template-imagen-model.ts @@ -0,0 +1,81 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not 
use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { RequestOptions } from '../types'; +import { + AI, + ImagenGenerationResponse, + ImagenInlineImage +} from '../public-types'; +import { ApiSettings } from '../types/internal'; +import { makeRequest } from '../requests/request'; +import { handlePredictResponse } from '../requests/response-helpers'; +import { initApiSettings } from './utils'; + +/** + * Class for Imagen model APIs that execute on a server-side template. + * + * This class should only be instantiated with {@link getTemplateImagenModel}. + * + * @beta + */ +export class TemplateImagenModel { + /** + * @internal + */ + _apiSettings: ApiSettings; + + /** + * Additional options to use when making requests. + */ + requestOptions?: RequestOptions; + + /** + * @hideconstructor + */ + constructor(ai: AI, requestOptions?: RequestOptions) { + this.requestOptions = requestOptions || {}; + this._apiSettings = initApiSettings(ai); + } + + /** + * Makes a single call to the model and returns an object containing a single + * {@link ImagenGenerationResponse}. + * + * @param templateId - The ID of the server-side template to execute. + * @param templateVariables - A key-value map of variables to populate the + * template with. 
+ * + * @beta + */ + async generateImages( + templateId: string, + templateVariables: object + ): Promise> { + const response = await makeRequest( + { + task: 'templatePredict', + templateId, + apiSettings: this._apiSettings, + stream: false, + requestOptions: this.requestOptions + }, + JSON.stringify({ inputs: templateVariables }) + ); + return handlePredictResponse(response); + } +} diff --git a/packages/ai/src/models/utils.test.ts b/packages/ai/src/models/utils.test.ts new file mode 100644 index 00000000000..42d19007275 --- /dev/null +++ b/packages/ai/src/models/utils.test.ts @@ -0,0 +1,142 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +import { use, expect } from 'chai'; +import { AI, AIErrorCode } from '../public-types'; +import sinonChai from 'sinon-chai'; +import { stub } from 'sinon'; +import { AIError } from '../errors'; +import { VertexAIBackend } from '../backend'; +import { AIService } from '../service'; +import { initApiSettings } from './utils'; + +use(sinonChai); + +const fakeAI: AI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + apiKey: 'key', + projectId: 'my-project', + appId: 'my-appid' + } + }, + backend: new VertexAIBackend('us-central1'), + location: 'us-central1' +}; + +describe('initApiSettings', () => { + it('calls regular app check token when option is set', async () => { + const getTokenStub = stub().resolves(); + const getLimitedUseTokenStub = stub().resolves(); + const apiSettings = initApiSettings( + //@ts-ignore + { + ...fakeAI, + options: { useLimitedUseAppCheckTokens: false }, + appCheck: { + getToken: getTokenStub, + getLimitedUseToken: getLimitedUseTokenStub + } + } as AIService + ); + if (apiSettings?.getAppCheckToken) { + await apiSettings.getAppCheckToken(); + } + expect(getTokenStub).to.be.called; + expect(getLimitedUseTokenStub).to.not.be.called; + getTokenStub.reset(); + getLimitedUseTokenStub.reset(); + }); + it('calls limited use token when option is set', async () => { + const getTokenStub = stub().resolves(); + const getLimitedUseTokenStub = stub().resolves(); + const apiSettings = initApiSettings( + //@ts-ignore + { + ...fakeAI, + options: { useLimitedUseAppCheckTokens: true }, + appCheck: { + getToken: getTokenStub, + getLimitedUseToken: getLimitedUseTokenStub + } + } as AIService + ); + if (apiSettings?.getAppCheckToken) { + await apiSettings.getAppCheckToken(); + } + expect(getTokenStub).to.not.be.called; + expect(getLimitedUseTokenStub).to.be.called; + getTokenStub.reset(); + getLimitedUseTokenStub.reset(); + }); + it('throws if not passed an api key', () => { + const fakeAI: AI = { + app: { + name: 
'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + projectId: 'my-project' + } + }, + backend: new VertexAIBackend('us-central1'), + location: 'us-central1' + }; + try { + initApiSettings(fakeAI); + } catch (e) { + expect((e as AIError).code).to.equal(AIErrorCode.NO_API_KEY); + } + }); + it('throws if not passed a project ID', () => { + const fakeAI: AI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + apiKey: 'key' + } + }, + backend: new VertexAIBackend('us-central1'), + location: 'us-central1' + }; + try { + initApiSettings(fakeAI); + } catch (e) { + expect((e as AIError).code).to.equal(AIErrorCode.NO_PROJECT_ID); + } + }); + it('throws if not passed an app ID', () => { + const fakeAI: AI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + apiKey: 'key', + projectId: 'my-project' + } + }, + backend: new VertexAIBackend('us-central1'), + location: 'us-central1' + }; + try { + initApiSettings(fakeAI); + } catch (e) { + expect((e as AIError).code).to.equal(AIErrorCode.NO_APP_ID); + } + }); +}); diff --git a/packages/ai/src/models/utils.ts b/packages/ai/src/models/utils.ts new file mode 100644 index 00000000000..ffe91d40278 --- /dev/null +++ b/packages/ai/src/models/utils.ts @@ -0,0 +1,71 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import { _isFirebaseServerApp } from '@firebase/app'; +import { AIError } from '../errors'; +import { AI, AIErrorCode } from '../public-types'; +import { AIService } from '../service'; +import { ApiSettings } from '../types/internal'; + +export function initApiSettings(ai: AI): ApiSettings { + if (!ai.app?.options?.apiKey) { + throw new AIError( + AIErrorCode.NO_API_KEY, + `The "apiKey" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid API key.` + ); + } else if (!ai.app?.options?.projectId) { + throw new AIError( + AIErrorCode.NO_PROJECT_ID, + `The "projectId" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid project ID.` + ); + } else if (!ai.app?.options?.appId) { + throw new AIError( + AIErrorCode.NO_APP_ID, + `The "appId" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid app ID.` + ); + } + + const apiSettings: ApiSettings = { + apiKey: ai.app.options.apiKey, + project: ai.app.options.projectId, + appId: ai.app.options.appId, + automaticDataCollectionEnabled: ai.app.automaticDataCollectionEnabled, + location: ai.location, + backend: ai.backend + }; + + if (_isFirebaseServerApp(ai.app) && ai.app.settings.appCheckToken) { + const token = ai.app.settings.appCheckToken; + apiSettings.getAppCheckToken = () => { + return Promise.resolve({ token }); + }; + } else if ((ai as AIService).appCheck) { + if (ai.options?.useLimitedUseAppCheckTokens) { + apiSettings.getAppCheckToken = () => + (ai as AIService).appCheck!.getLimitedUseToken(); + } else { + apiSettings.getAppCheckToken = () => + (ai as AIService).appCheck!.getToken(); + } + } + + if ((ai as AIService).auth) { + apiSettings.getAuthToken = () => (ai as AIService).auth!.getToken(); + } + + return apiSettings; +} diff --git a/packages/ai/src/requests/request.test.ts b/packages/ai/src/requests/request.test.ts index 0d162906fdc..bafb97bf855 100644 --- 
a/packages/ai/src/requests/request.test.ts +++ b/packages/ai/src/requests/request.test.ts @@ -19,7 +19,7 @@ import { expect, use } from 'chai'; import { match, restore, stub } from 'sinon'; import sinonChai from 'sinon-chai'; import chaiAsPromised from 'chai-as-promised'; -import { RequestUrl, Task, getHeaders, makeRequest } from './request'; +import { RequestURL, Task, getHeaders, makeRequest } from './request'; import { ApiSettings } from '../types/internal'; import { DEFAULT_API_VERSION } from '../constants'; import { AIErrorCode } from '../types'; @@ -42,65 +42,77 @@ describe('request methods', () => { afterEach(() => { restore(); }); - describe('RequestUrl', () => { + describe('RequestURL', () => { it('stream', async () => { - const url = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - true, - {} - ); + const url = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: true, + requestOptions: {} + }); expect(url.toString()).to.include('models/model-name:generateContent'); - expect(url.toString()).to.not.include(fakeApiSettings); expect(url.toString()).to.include('alt=sse'); }); it('non-stream', async () => { - const url = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - {} - ); + const url = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: {} + }); expect(url.toString()).to.include('models/model-name:generateContent'); expect(url.toString()).to.not.include(fakeApiSettings); expect(url.toString()).to.not.include('alt=sse'); }); it('default apiVersion', async () => { - const url = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - {} - ); + const url = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + 
requestOptions: {} + }); expect(url.toString()).to.include(DEFAULT_API_VERSION); }); it('custom baseUrl', async () => { - const url = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - { baseUrl: 'https://my.special.endpoint' } - ); + const url = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: { baseUrl: 'https://my.special.endpoint' } + }); expect(url.toString()).to.include('https://my.special.endpoint'); }); it('non-stream - tunedModels/', async () => { - const url = new RequestUrl( - 'tunedModels/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - {} - ); + const url = new RequestURL({ + model: 'tunedModels/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: {} + }); expect(url.toString()).to.include( 'tunedModels/model-name:generateContent' ); expect(url.toString()).to.not.include(fakeApiSettings); expect(url.toString()).to.not.include('alt=sse'); }); + it('prompt server template', async () => { + const url = new RequestURL({ + templateId: 'my-template', + task: 'templateGenerateContent', + apiSettings: fakeApiSettings, + stream: false, + requestOptions: {} + }); + expect(url.toString()).to.include( + 'templates/my-template:templateGenerateContent' + ); + expect(url.toString()).to.not.include(fakeApiSettings); + }); }); describe('getHeaders', () => { const fakeApiSettings: ApiSettings = { @@ -112,13 +124,13 @@ describe('request methods', () => { getAuthToken: () => Promise.resolve({ accessToken: 'authtoken' }), getAppCheckToken: () => Promise.resolve({ token: 'appchecktoken' }) }; - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - true, - {} - ); + const fakeUrl = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: true, + 
requestOptions: {} + }); it('adds client headers', async () => { const headers = await getHeaders(fakeUrl); expect(headers.get('x-goog-api-client')).to.match( @@ -140,13 +152,13 @@ describe('request methods', () => { getAuthToken: () => Promise.resolve({ accessToken: 'authtoken' }), getAppCheckToken: () => Promise.resolve({ token: 'appchecktoken' }) }; - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - true, - {} - ); + const fakeUrl = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: true, + requestOptions: {} + }); const headers = await getHeaders(fakeUrl); expect(headers.get('X-Firebase-Appid')).to.equal('my-appid'); }); @@ -165,13 +177,13 @@ describe('request methods', () => { getAuthToken: () => Promise.resolve({ accessToken: 'authtoken' }), getAppCheckToken: () => Promise.resolve({ token: 'appchecktoken' }) }; - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - true, - {} - ); + const fakeUrl = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: true, + requestOptions: {} + }); const headers = await getHeaders(fakeUrl); expect(headers.get('X-Firebase-Appid')).to.be.null; }); @@ -180,44 +192,44 @@ describe('request methods', () => { expect(headers.get('X-Firebase-AppCheck')).to.equal('appchecktoken'); }); it('ignores app check token header if no appcheck service', async () => { - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - { + const fakeUrl = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: { apiKey: 'key', project: 'myproject', appId: 'my-appid', location: 'moon', backend: new VertexAIBackend() }, - true, - {} - ); + stream: true, + requestOptions: {} + }); const headers = await getHeaders(fakeUrl); 
expect(headers.has('X-Firebase-AppCheck')).to.be.false; }); it('ignores app check token header if returned token was undefined', async () => { - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - { + const fakeUrl = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: { apiKey: 'key', project: 'myproject', location: 'moon', //@ts-ignore getAppCheckToken: () => Promise.resolve() }, - true, - {} - ); + stream: true, + requestOptions: {} + }); const headers = await getHeaders(fakeUrl); expect(headers.has('X-Firebase-AppCheck')).to.be.false; }); it('ignores app check token header if returned token had error', async () => { - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - { + const fakeUrl = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: { apiKey: 'key', project: 'myproject', appId: 'my-appid', @@ -226,9 +238,9 @@ describe('request methods', () => { getAppCheckToken: () => Promise.resolve({ token: 'dummytoken', error: Error('oops') }) }, - true, - {} - ); + stream: true, + requestOptions: {} + }); const warnStub = stub(console, 'warn'); const headers = await getHeaders(fakeUrl); expect(headers.get('X-Firebase-AppCheck')).to.equal('dummytoken'); @@ -242,36 +254,36 @@ describe('request methods', () => { expect(headers.get('Authorization')).to.equal('Firebase authtoken'); }); it('ignores auth token header if no auth service', async () => { - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - { + const fakeUrl = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: { apiKey: 'key', project: 'myproject', appId: 'my-appid', location: 'moon', backend: new VertexAIBackend() }, - true, - {} - ); + stream: true, + requestOptions: {} + }); const headers = await getHeaders(fakeUrl); expect(headers.has('Authorization')).to.be.false; }); it('ignores auth token 
header if returned token was undefined', async () => { - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - { + const fakeUrl = new RequestURL({ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: { apiKey: 'key', project: 'myproject', location: 'moon', //@ts-ignore getAppCheckToken: () => Promise.resolve() }, - true, - {} - ); + stream: true, + requestOptions: {} + }); const headers = await getHeaders(fakeUrl); expect(headers.has('Authorization')).to.be.false; }); @@ -282,10 +294,12 @@ describe('request methods', () => { ok: true } as Response); const response = await makeRequest( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false + }, '' ); expect(fetchStub).to.be.calledOnce; @@ -300,14 +314,16 @@ describe('request methods', () => { try { await makeRequest( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - '', { - timeout: 180000 - } + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: { + timeout: 180000 + } + }, + '' ); } catch (e) { expect((e as AIError).code).to.equal(AIErrorCode.FETCH_ERROR); @@ -328,10 +344,12 @@ describe('request methods', () => { } as Response); try { await makeRequest( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false + }, '' ); } catch (e) { @@ -353,10 +371,12 @@ describe('request methods', () => { } as Response); try { await makeRequest( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false + }, '' ); } catch (e) { @@ -391,10 +411,12 @@ describe('request methods', () => { } as 
Response); try { await makeRequest( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false + }, '' ); } catch (e) { @@ -420,10 +442,12 @@ describe('request methods', () => { ); try { await makeRequest( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false + }, '' ); } catch (e) { diff --git a/packages/ai/src/requests/request.ts b/packages/ai/src/requests/request.ts index 90195b4b788..69b84c79205 100644 --- a/packages/ai/src/requests/request.ts +++ b/packages/ai/src/requests/request.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2024 Google LLC + * Copyright 2025 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,14 +19,12 @@ import { ErrorDetails, RequestOptions, AIErrorCode } from '../types'; import { AIError } from '../errors'; import { ApiSettings } from '../types/internal'; import { - DEFAULT_API_VERSION, DEFAULT_DOMAIN, DEFAULT_FETCH_TIMEOUT_MS, LANGUAGE_TAG, PACKAGE_VERSION } from '../constants'; import { logger } from '../logger'; -import { GoogleAIBackend, VertexAIBackend } from '../backend'; import { BackendType } from '../public-types'; export enum Task { @@ -36,45 +34,71 @@ export enum Task { PREDICT = 'predict' } -export class RequestUrl { +export type ServerPromptTemplateTask = + | 'templateGenerateContent' + | 'templateStreamGenerateContent' + | 'templatePredict'; + +interface BaseRequestURLParams { + apiSettings: ApiSettings; + stream: boolean; + requestOptions?: RequestOptions; +} + +/** + * Parameters used to construct the URL of a request to use a model. 
+ */ +interface ModelRequestURLParams extends BaseRequestURLParams { + task: Task; + model: string; + templateId?: never; +} + +/** + * Parameters used to construct the URL of a request to use server side prompt templates. + */ +interface TemplateRequestURLParams extends BaseRequestURLParams { + task: ServerPromptTemplateTask; + templateId: string; + model?: never; +} + +export class RequestURL { constructor( - public model: string, - public task: Task, - public apiSettings: ApiSettings, - public stream: boolean, - public requestOptions?: RequestOptions + public readonly params: ModelRequestURLParams | TemplateRequestURLParams ) {} + toString(): string { const url = new URL(this.baseUrl); // Throws if the URL is invalid - url.pathname = `/${this.apiVersion}/${this.modelPath}:${this.task}`; + url.pathname = this.pathname; url.search = this.queryParams.toString(); return url.toString(); } - private get baseUrl(): string { - return this.requestOptions?.baseUrl || `https://${DEFAULT_DOMAIN}`; - } - - private get apiVersion(): string { - return DEFAULT_API_VERSION; // TODO: allow user-set options if that feature becomes available - } - - private get modelPath(): string { - if (this.apiSettings.backend instanceof GoogleAIBackend) { - return `projects/${this.apiSettings.project}/${this.model}`; - } else if (this.apiSettings.backend instanceof VertexAIBackend) { - return `projects/${this.apiSettings.project}/locations/${this.apiSettings.backend.location}/${this.model}`; + private get pathname(): string { + // We need to construct a different URL if the request is for server side prompt templates, + // since the URL patterns are different. Server side prompt templates expect a templateId + // instead of a model name. 
+ if (this.params.templateId) { + return `${this.params.apiSettings.backend._getTemplatePath( + this.params.apiSettings.project, + this.params.templateId + )}:${this.params.task}`; } else { - throw new AIError( - AIErrorCode.ERROR, - `Invalid backend: ${JSON.stringify(this.apiSettings.backend)}` - ); + return `${this.params.apiSettings.backend._getModelPath( + this.params.apiSettings.project, + (this.params as ModelRequestURLParams).model + )}:${this.params.task}`; } } + private get baseUrl(): string { + return this.params.requestOptions?.baseUrl ?? `https://${DEFAULT_DOMAIN}`; + } + private get queryParams(): URLSearchParams { const params = new URLSearchParams(); - if (this.stream) { + if (this.params.stream) { params.set('alt', 'sse'); } @@ -114,16 +138,16 @@ function getClientHeaders(): string { return loggingTags.join(' '); } -export async function getHeaders(url: RequestUrl): Promise { +export async function getHeaders(url: RequestURL): Promise { const headers = new Headers(); headers.append('Content-Type', 'application/json'); headers.append('x-goog-api-client', getClientHeaders()); - headers.append('x-goog-api-key', url.apiSettings.apiKey); - if (url.apiSettings.automaticDataCollectionEnabled) { - headers.append('X-Firebase-Appid', url.apiSettings.appId); + headers.append('x-goog-api-key', url.params.apiSettings.apiKey); + if (url.params.apiSettings.automaticDataCollectionEnabled) { + headers.append('X-Firebase-Appid', url.params.apiSettings.appId); } - if (url.apiSettings.getAppCheckToken) { - const appCheckToken = await url.apiSettings.getAppCheckToken(); + if (url.params.apiSettings.getAppCheckToken) { + const appCheckToken = await url.params.apiSettings.getAppCheckToken(); if (appCheckToken) { headers.append('X-Firebase-AppCheck', appCheckToken.token); if (appCheckToken.error) { @@ -134,8 +158,8 @@ export async function getHeaders(url: RequestUrl): Promise { } } - if (url.apiSettings.getAuthToken) { - const authToken = await 
url.apiSettings.getAuthToken(); + if (url.params.apiSettings.getAuthToken) { + const authToken = await url.params.apiSettings.getAuthToken(); if (authToken) { headers.append('Authorization', `Firebase ${authToken.accessToken}`); } @@ -144,55 +168,31 @@ export async function getHeaders(url: RequestUrl): Promise { return headers; } -export async function constructRequest( - model: string, - task: Task, - apiSettings: ApiSettings, - stream: boolean, - body: string, - requestOptions?: RequestOptions -): Promise<{ url: string; fetchOptions: RequestInit }> { - const url = new RequestUrl(model, task, apiSettings, stream, requestOptions); - return { - url: url.toString(), - fetchOptions: { - method: 'POST', - headers: await getHeaders(url), - body - } - }; -} - export async function makeRequest( - model: string, - task: Task, - apiSettings: ApiSettings, - stream: boolean, - body: string, - requestOptions?: RequestOptions + requestUrlParams: TemplateRequestURLParams | ModelRequestURLParams, + body: string ): Promise { - const url = new RequestUrl(model, task, apiSettings, stream, requestOptions); + const url = new RequestURL(requestUrlParams); let response; let fetchTimeoutId: string | number | NodeJS.Timeout | undefined; try { - const request = await constructRequest( - model, - task, - apiSettings, - stream, - body, - requestOptions - ); - // Timeout is 180s by default + const fetchOptions: RequestInit = { + method: 'POST', + headers: await getHeaders(url), + body + }; + + // Timeout is 180s by default. const timeoutMillis = - requestOptions?.timeout != null && requestOptions.timeout >= 0 - ? requestOptions.timeout + requestUrlParams.requestOptions?.timeout != null && + requestUrlParams.requestOptions.timeout >= 0 + ? 
requestUrlParams.requestOptions.timeout : DEFAULT_FETCH_TIMEOUT_MS; const abortController = new AbortController(); fetchTimeoutId = setTimeout(() => abortController.abort(), timeoutMillis); - request.fetchOptions.signal = abortController.signal; + fetchOptions.signal = abortController.signal; - response = await fetch(request.url, request.fetchOptions); + response = await fetch(url.toString(), fetchOptions); if (!response.ok) { let message = ''; let errorDetails; @@ -225,7 +225,7 @@ export async function makeRequest( `The Firebase AI SDK requires the Firebase AI ` + `API ('firebasevertexai.googleapis.com') to be enabled in your ` + `Firebase project. Enable this API by visiting the Firebase Console ` + - `at https://console.firebase.google.com/project/${url.apiSettings.project}/genai/ ` + + `at https://console.firebase.google.com/project/${url.params.apiSettings.project}/genai/ ` + `and clicking "Get started". If you enabled this API recently, ` + `wait a few minutes for the action to propagate to our systems and ` + `then retry.`, From 7ede7997264d3ce23221d5d94c64c274a606bef6 Mon Sep 17 00:00:00 2001 From: Daniel La Rocque Date: Tue, 28 Oct 2025 11:01:54 -0400 Subject: [PATCH 2/8] Update changeset --- .changeset/metal-ties-cry.md | 1 + 1 file changed, 1 insertion(+) diff --git a/.changeset/metal-ties-cry.md b/.changeset/metal-ties-cry.md index f6c28a73b61..ffff494e6d2 100644 --- a/.changeset/metal-ties-cry.md +++ b/.changeset/metal-ties-cry.md @@ -1,4 +1,5 @@ --- +'firebase': minor '@firebase/ai': minor --- From 40aa3cfb1f87efb11b9e60d16ec20f8ee7c83b81 Mon Sep 17 00:00:00 2001 From: Daniel La Rocque Date: Thu, 30 Oct 2025 10:26:18 -0400 Subject: [PATCH 3/8] revert node test fixes --- packages/ai/src/methods/chrome-adapter-browser.test.ts | 6 +----- packages/ai/src/methods/chrome-adapter.ts | 3 +-- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/packages/ai/src/methods/chrome-adapter-browser.test.ts 
b/packages/ai/src/methods/chrome-adapter-browser.test.ts index 1b851c3083c..e37a08bf1a9 100644 --- a/packages/ai/src/methods/chrome-adapter-browser.test.ts +++ b/packages/ai/src/methods/chrome-adapter-browser.test.ts @@ -29,7 +29,6 @@ import { import { match, stub } from 'sinon'; import { GenerateContentRequest, AIErrorCode, InferenceMode } from '../types'; import { Schema } from '../api'; -import { isNode } from '@firebase/util'; use(sinonChai); use(chaiAsPromised); @@ -54,9 +53,6 @@ async function toStringArray( } describe('ChromeAdapter', () => { - if (isNode()) { - return; - } describe('constructor', () => { it('sets image as expected input type by default', async () => { const languageModelProvider = { @@ -837,7 +833,7 @@ describe('chromeAdapterFactory', () => { const fakeLanguageModel = {} as LanguageModel; const adapter = chromeAdapterFactory( InferenceMode.PREFER_ON_DEVICE, - { LanguageModel: fakeLanguageModel } as any, + { LanguageModel: fakeLanguageModel } as Window, { createOptions: {} } ); expect(adapter?.languageModelProvider).to.equal(fakeLanguageModel); diff --git a/packages/ai/src/methods/chrome-adapter.ts b/packages/ai/src/methods/chrome-adapter.ts index 709084638c5..839276814bb 100644 --- a/packages/ai/src/methods/chrome-adapter.ts +++ b/packages/ai/src/methods/chrome-adapter.ts @@ -400,8 +400,7 @@ export function chromeAdapterFactory( // Do not initialize a ChromeAdapter if we are not in hybrid mode. 
if (typeof window !== 'undefined' && mode) { return new ChromeAdapterImpl( - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (window as any).LanguageModel as LanguageModel, + (window as Window).LanguageModel as LanguageModel, mode, params ); From 2a4b9a914eba023ad8dd9aa6554e90eeb43337a7 Mon Sep 17 00:00:00 2001 From: Daniel La Rocque Date: Thu, 30 Oct 2025 11:14:35 -0400 Subject: [PATCH 4/8] Review fixes --- packages/ai/integration/prompt-templates.test.ts | 5 +++-- packages/ai/src/constants.ts | 3 +++ packages/ai/src/methods/count-tokens.test.ts | 3 +++ packages/ai/src/methods/count-tokens.ts | 2 +- packages/ai/src/methods/generate-content.ts | 10 +++++++--- packages/ai/src/methods/template-chat-session.test.ts | 2 +- packages/ai/src/methods/template-chat-session.ts | 2 +- packages/ai/src/models/imagen-model.ts | 2 +- .../ai/src/models/template-generative-model.test.ts | 2 +- packages/ai/src/models/template-generative-model.ts | 2 +- packages/ai/src/models/template-imagen-model.test.ts | 2 +- packages/ai/src/models/template-imagen-model.ts | 6 +++--- packages/ai/src/models/utils.ts | 7 +++++++ packages/ai/src/requests/request.ts | 11 ++++++----- 14 files changed, 39 insertions(+), 20 deletions(-) diff --git a/packages/ai/integration/prompt-templates.test.ts b/packages/ai/integration/prompt-templates.test.ts index 4898bdf1c1f..3a7f9038561 100644 --- a/packages/ai/integration/prompt-templates.test.ts +++ b/packages/ai/integration/prompt-templates.test.ts @@ -22,6 +22,7 @@ import { getTemplateImagenModel } from '../src'; import { testConfigs } from './constants'; +import { STAGING_URL } from '../src/constants'; const templateBackendSuffix = ( backendType: BackendType @@ -35,7 +36,7 @@ describe('Prompt templates', function () { describe('Generative Model', () => { it('successfully generates content', async () => { const model = getTemplateGenerativeModel(testConfig.ai, { - baseUrl: 'https://staging-firebasevertexai.sandbox.googleapis.com' + baseUrl: 
STAGING_URL }); const { response } = await model.generateContent( `sassy-greeting-${templateBackendSuffix( @@ -49,7 +50,7 @@ describe('Prompt templates', function () { describe('Imagen model', async () => { it('successfully generates images', async () => { const model = getTemplateImagenModel(testConfig.ai, { - baseUrl: 'https://staging-firebasevertexai.sandbox.googleapis.com' + baseUrl: STAGING_URL }); const { images } = await model.generateImages( `portrait-${templateBackendSuffix( diff --git a/packages/ai/src/constants.ts b/packages/ai/src/constants.ts index 82482527f3b..0a6f7e91436 100644 --- a/packages/ai/src/constants.ts +++ b/packages/ai/src/constants.ts @@ -23,6 +23,9 @@ export const DEFAULT_LOCATION = 'us-central1'; export const DEFAULT_DOMAIN = 'firebasevertexai.googleapis.com'; +export const STAGING_URL = + 'https://staging-firebasevertexai.sandbox.googleapis.com'; + export const DEFAULT_API_VERSION = 'v1beta'; export const PACKAGE_VERSION = version; diff --git a/packages/ai/src/methods/count-tokens.test.ts b/packages/ai/src/methods/count-tokens.test.ts index c12b0989baf..b3ed7f7fa4d 100644 --- a/packages/ai/src/methods/count-tokens.test.ts +++ b/packages/ai/src/methods/count-tokens.test.ts @@ -99,6 +99,9 @@ describe('countTokens()', () => { fakeChromeAdapter ); expect(result.totalTokens).to.equal(1837); + expect(result.totalBillableCharacters).to.equal(117); + expect(result.promptTokensDetails?.[0].modality).to.equal('IMAGE'); + expect(result.promptTokensDetails?.[0].tokenCount).to.equal(1806); expect(makeRequestStub).to.be.calledWith( { model: 'model', diff --git a/packages/ai/src/methods/count-tokens.ts b/packages/ai/src/methods/count-tokens.ts index af496de4673..20c633ee703 100644 --- a/packages/ai/src/methods/count-tokens.ts +++ b/packages/ai/src/methods/count-tokens.ts @@ -23,7 +23,7 @@ import { RequestOptions, AIErrorCode } from '../types'; -import { Task, makeRequest } from '../requests/request'; +import { makeRequest, Task } from 
'../requests/request'; import { ApiSettings } from '../types/internal'; import * as GoogleAIMapper from '../googleai-mappers'; import { BackendType } from '../public-types'; diff --git a/packages/ai/src/methods/generate-content.ts b/packages/ai/src/methods/generate-content.ts index cffc4b48413..fc6eac15c74 100644 --- a/packages/ai/src/methods/generate-content.ts +++ b/packages/ai/src/methods/generate-content.ts @@ -22,7 +22,11 @@ import { GenerateContentStreamResult, RequestOptions } from '../types'; -import { Task, makeRequest } from '../requests/request'; +import { + makeRequest, + ServerPromptTemplateTask, + Task +} from '../requests/request'; import { createEnhancedContentResponse } from '../requests/response-helpers'; import { processStream } from '../requests/stream-reader'; import { ApiSettings } from '../types/internal'; @@ -98,7 +102,7 @@ export async function templateGenerateContent( ): Promise { const response = await makeRequest( { - task: 'templateGenerateContent', + task: ServerPromptTemplateTask.TEMPLATE_GENERATE_CONTENT, templateId, apiSettings, stream: false, @@ -126,7 +130,7 @@ export async function templateGenerateContentStream( ): Promise { const response = await makeRequest( { - task: 'templateStreamGenerateContent', + task: ServerPromptTemplateTask.TEMPLATE_STREAM_GENERATE_CONTENT, templateId, apiSettings, stream: true, diff --git a/packages/ai/src/methods/template-chat-session.test.ts b/packages/ai/src/methods/template-chat-session.test.ts index 79161143eba..2cbcb2f3ab1 100644 --- a/packages/ai/src/methods/template-chat-session.test.ts +++ b/packages/ai/src/methods/template-chat-session.test.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2024 Google LLC + * Copyright 2025 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
diff --git a/packages/ai/src/methods/template-chat-session.ts b/packages/ai/src/methods/template-chat-session.ts index 8f646861de6..7c8e4a21252 100644 --- a/packages/ai/src/methods/template-chat-session.ts +++ b/packages/ai/src/methods/template-chat-session.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2024 Google LLC + * Copyright 2025 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/packages/ai/src/models/imagen-model.ts b/packages/ai/src/models/imagen-model.ts index 820defb62a7..567333ee64f 100644 --- a/packages/ai/src/models/imagen-model.ts +++ b/packages/ai/src/models/imagen-model.ts @@ -16,7 +16,7 @@ */ import { AI } from '../public-types'; -import { Task, makeRequest } from '../requests/request'; +import { makeRequest, Task } from '../requests/request'; import { createPredictRequestBody } from '../requests/request-helpers'; import { handlePredictResponse } from '../requests/response-helpers'; import { diff --git a/packages/ai/src/models/template-generative-model.test.ts b/packages/ai/src/models/template-generative-model.test.ts index 301be8a2504..4198090eab6 100644 --- a/packages/ai/src/models/template-generative-model.test.ts +++ b/packages/ai/src/models/template-generative-model.test.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2024 Google LLC + * Copyright 2025 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
diff --git a/packages/ai/src/models/template-generative-model.ts b/packages/ai/src/models/template-generative-model.ts index e2d3b71736a..8a60728df2c 100644 --- a/packages/ai/src/models/template-generative-model.ts +++ b/packages/ai/src/models/template-generative-model.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2024 Google LLC + * Copyright 2025 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/packages/ai/src/models/template-imagen-model.test.ts b/packages/ai/src/models/template-imagen-model.test.ts index d53f0351dd2..c053753ea0f 100644 --- a/packages/ai/src/models/template-imagen-model.test.ts +++ b/packages/ai/src/models/template-imagen-model.test.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2024 Google LLC + * Copyright 2025 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/packages/ai/src/models/template-imagen-model.ts b/packages/ai/src/models/template-imagen-model.ts index ba1919f1140..34325c711b3 100644 --- a/packages/ai/src/models/template-imagen-model.ts +++ b/packages/ai/src/models/template-imagen-model.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2024 Google LLC + * Copyright 2025 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -22,7 +22,7 @@ import { ImagenInlineImage } from '../public-types'; import { ApiSettings } from '../types/internal'; -import { makeRequest } from '../requests/request'; +import { makeRequest, ServerPromptTemplateTask } from '../requests/request'; import { handlePredictResponse } from '../requests/response-helpers'; import { initApiSettings } from './utils'; @@ -68,7 +68,7 @@ export class TemplateImagenModel { ): Promise> { const response = await makeRequest( { - task: 'templatePredict', + task: ServerPromptTemplateTask.TEMPLATE_PREDICT, templateId, apiSettings: this._apiSettings, stream: false, diff --git a/packages/ai/src/models/utils.ts b/packages/ai/src/models/utils.ts index ffe91d40278..035ed3f734d 100644 --- a/packages/ai/src/models/utils.ts +++ b/packages/ai/src/models/utils.ts @@ -21,6 +21,13 @@ import { AI, AIErrorCode } from '../public-types'; import { AIService } from '../service'; import { ApiSettings } from '../types/internal'; +/** + * Initializes an {@link ApiSettings} object from an {@link AI} instance. + * + * If this is a Server App, the {@link ApiSettings} object's `getAppCheckToken()` will resolve + * with the `FirebaseServerAppSettings.appCheckToken`, instead of requiring that an App Check + * instance is initialized. 
+ */ export function initApiSettings(ai: AI): ApiSettings { if (!ai.app?.options?.apiKey) { throw new AIError( diff --git a/packages/ai/src/requests/request.ts b/packages/ai/src/requests/request.ts index 69b84c79205..7664765ab03 100644 --- a/packages/ai/src/requests/request.ts +++ b/packages/ai/src/requests/request.ts @@ -27,17 +27,18 @@ import { import { logger } from '../logger'; import { BackendType } from '../public-types'; -export enum Task { +export const enum Task { GENERATE_CONTENT = 'generateContent', STREAM_GENERATE_CONTENT = 'streamGenerateContent', COUNT_TOKENS = 'countTokens', PREDICT = 'predict' } -export type ServerPromptTemplateTask = - | 'templateGenerateContent' - | 'templateStreamGenerateContent' - | 'templatePredict'; +export const enum ServerPromptTemplateTask { + TEMPLATE_GENERATE_CONTENT = 'templateGenerateContent', + TEMPLATE_STREAM_GENERATE_CONTENT = 'templateStreamGenerateContent', + TEMPLATE_PREDICT = 'templatePredict' +} interface BaseRequestURLParams { apiSettings: ApiSettings; From 503fa59da95df77504f4c0d6ea769c84b442c959 Mon Sep 17 00:00:00 2001 From: Daniel La Rocque Date: Thu, 30 Oct 2025 11:16:40 -0400 Subject: [PATCH 5/8] Remove Chat --- common/api-review/ai.api.md | 13 - packages/ai/src/api.ts | 1 - .../src/methods/template-chat-session.test.ts | 359 ------------------ .../ai/src/methods/template-chat-session.ts | 230 ----------- .../models/template-generative-model.test.ts | 23 +- .../src/models/template-generative-model.ts | 22 +- 6 files changed, 2 insertions(+), 646 deletions(-) delete mode 100644 packages/ai/src/methods/template-chat-session.test.ts delete mode 100644 packages/ai/src/methods/template-chat-session.ts diff --git a/common/api-review/ai.api.md b/common/api-review/ai.api.md index 0edce6bac15..4057a85bb9a 100644 --- a/common/api-review/ai.api.md +++ b/common/api-review/ai.api.md @@ -1328,18 +1328,6 @@ export class StringSchema extends Schema { toJSON(): SchemaRequest; } -// @beta -export class TemplateChatSession 
{ - constructor(_apiSettings: ApiSettings, templateId: string, _history?: Content[], requestOptions?: RequestOptions | undefined); - getHistory(): Promise; - // (undocumented) - requestOptions?: RequestOptions | undefined; - sendMessage(request: string | Array, inputs?: object): Promise; - sendMessageStream(request: string | Array, inputs?: object): Promise; - // (undocumented) - templateId: string; -} - // @beta export class TemplateGenerativeModel { constructor(ai: AI, requestOptions?: RequestOptions); @@ -1348,7 +1336,6 @@ export class TemplateGenerativeModel { generateContent(templateId: string, templateVariables: object): Promise; generateContentStream(templateId: string, templateVariables: object): Promise; requestOptions?: RequestOptions; - startChat(templateId: string, history?: Content[]): TemplateChatSession; } // @beta diff --git a/packages/ai/src/api.ts b/packages/ai/src/api.ts index 4f9201a52fd..29614d88cec 100644 --- a/packages/ai/src/api.ts +++ b/packages/ai/src/api.ts @@ -43,7 +43,6 @@ import { TemplateGenerativeModel } from './models/template-generative-model'; import { TemplateImagenModel } from './models/template-imagen-model'; export { ChatSession } from './methods/chat-session'; -export { TemplateChatSession } from './methods/template-chat-session'; export { LiveSession } from './methods/live-session'; export * from './requests/schema-builder'; export { ImagenImageFormat } from './requests/imagen-image-format'; diff --git a/packages/ai/src/methods/template-chat-session.test.ts b/packages/ai/src/methods/template-chat-session.test.ts deleted file mode 100644 index 2cbcb2f3ab1..00000000000 --- a/packages/ai/src/methods/template-chat-session.test.ts +++ /dev/null @@ -1,359 +0,0 @@ -/** - * @license - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { use, expect } from 'chai'; -import sinonChai from 'sinon-chai'; -import chaiAsPromised from 'chai-as-promised'; -import { restore, stub } from 'sinon'; -import { VertexAIBackend } from '../backend'; -import * as generateContentMethods from './generate-content'; -import { TemplateChatSession } from './template-chat-session'; -import { ApiSettings } from '../types/internal'; -import { GenerateContentResult, Part, Role } from '../types'; - -use(sinonChai); -use(chaiAsPromised); - -const fakeApiSettings: ApiSettings = { - apiKey: 'key', - project: 'my-project', - appId: 'my-appid', - location: 'us-central1', - backend: new VertexAIBackend() -}; - -const TEMPLATE_ID = 'my-chat-template'; - -const FAKE_MODEL_RESPONSE_1 = { - response: { - candidates: [ - { - index: 0, - content: { - role: 'model' as Role, - parts: [{ text: 'Response 1' }] - } - } - ] - } -}; - -const FAKE_MODEL_RESPONSE_2 = { - response: { - candidates: [ - { - index: 0, - content: { - role: 'model' as Role, - parts: [{ text: 'Response 2' }] - } - } - ] - } -}; - -describe('TemplateChatSession', () => { - let templateGenerateContentStub: sinon.SinonStub; - let templateGenerateContentStreamStub: sinon.SinonStub; - - beforeEach(() => { - templateGenerateContentStub = stub( - generateContentMethods, - 'templateGenerateContent' - ); - templateGenerateContentStreamStub = stub( - generateContentMethods, - 'templateGenerateContentStream' - ); - }); - - afterEach(() => { - restore(); - }); - - describe('history and state management', () => { - it('should update history correctly 
after a single successful call', async () => { - templateGenerateContentStub.resolves( - FAKE_MODEL_RESPONSE_1 as GenerateContentResult - ); - const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); - await chat.sendMessage('Request 1'); - const history = await chat.getHistory(); - expect(history).to.have.lengthOf(2); - expect(history[0].role).to.equal('user'); - expect(history[0].parts[0].text).to.equal('Request 1'); - expect(history[1].role).to.equal('model'); - expect(history[1].parts[0].text).to.equal('Response 1'); - }); - - it('should maintain history over multiple turns', async () => { - templateGenerateContentStub - .onFirstCall() - .resolves(FAKE_MODEL_RESPONSE_1 as GenerateContentResult); - templateGenerateContentStub - .onSecondCall() - .resolves(FAKE_MODEL_RESPONSE_2 as GenerateContentResult); - - const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); - await chat.sendMessage('Request 1'); - await chat.sendMessage('Request 2'); - - const history = await chat.getHistory(); - expect(history).to.have.lengthOf(4); - expect(history[0].parts[0].text).to.equal('Request 1'); - expect(history[1].parts[0].text).to.equal('Response 1'); - expect(history[2].parts[0].text).to.equal('Request 2'); - expect(history[3].parts[0].text).to.equal('Response 2'); - }); - - it('should handle sequential calls to sendMessage and sendMessageStream', async () => { - templateGenerateContentStub.resolves( - FAKE_MODEL_RESPONSE_1 as GenerateContentResult - ); - templateGenerateContentStreamStub.resolves({ - stream: (async function* () {})(), - response: Promise.resolve(FAKE_MODEL_RESPONSE_2.response) - }); - - const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); - await chat.sendMessage('Request 1'); - await chat.sendMessageStream('Request 2'); - - const history = await chat.getHistory(); - expect(history).to.have.lengthOf(4); - expect(history[2].parts[0].text).to.equal('Request 2'); - expect(history[3].parts[0].text).to.equal('Response 2'); - 
}); - - it('should be able to be initialized with a history', async () => { - templateGenerateContentStub.resolves( - FAKE_MODEL_RESPONSE_2 as GenerateContentResult - ); - const initialHistory = [ - { role: 'user' as Role, parts: [{ text: 'Request 1' }] }, - FAKE_MODEL_RESPONSE_1.response.candidates[0].content - ]; - const chat = new TemplateChatSession( - fakeApiSettings, - TEMPLATE_ID, - initialHistory - ); - await chat.sendMessage('Request 2'); - const history = await chat.getHistory(); - expect(history).to.have.lengthOf(4); - expect(history[0].parts[0].text).to.equal('Request 1'); - expect(history[1].parts[0].text).to.equal('Response 1'); - expect(history[2].parts[0].text).to.equal('Request 2'); - expect(history[3].parts[0].text).to.equal('Response 2'); - }); - }); - - describe('error handling', () => { - it('templateGenerateContent errors should be catchable', async () => { - templateGenerateContentStub.rejects('failed'); - const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); - await expect(chat.sendMessage('Request 1')).to.be.rejected; - }); - - it('templateGenerateContentStream errors should be catchable', async () => { - templateGenerateContentStreamStub.rejects('failed'); - const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); - await expect(chat.sendMessageStream('Request 1')).to.be.rejected; - }); - - it('getHistory should fail if templateGenerateContent fails', async () => { - templateGenerateContentStub.rejects('failed'); - const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); - await expect(chat.sendMessage('Request 1')).to.be.rejected; - await expect(chat.getHistory()).to.be.rejected; - }); - - it('getHistory should fail if templateGenerateContentStream fails', async () => { - templateGenerateContentStreamStub.rejects('failed'); - const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); - await expect(chat.sendMessageStream('Request 1')).to.be.rejected; - }); - - it('should not update history if 
response has no candidates', async () => { - templateGenerateContentStub.resolves({ - response: {} - } as GenerateContentResult); - const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); - await chat.sendMessage('Request 1'); - const history = await chat.getHistory(); - expect(history).to.be.empty; - }); - }); - - describe('input variations for sendMessage', () => { - it('should handle request as a single string', async () => { - templateGenerateContentStub.resolves( - FAKE_MODEL_RESPONSE_1 as GenerateContentResult - ); - const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); - await chat.sendMessage('Just a string'); - const history = await chat.getHistory(); - expect(history[0].parts[0].text).to.equal('Just a string'); - }); - - it('should handle request as an array of strings', async () => { - templateGenerateContentStub.resolves( - FAKE_MODEL_RESPONSE_1 as GenerateContentResult - ); - const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); - await chat.sendMessage(['string 1', 'string 2']); - const history = await chat.getHistory(); - expect(history[0].parts).to.deep.equal([ - { text: 'string 1' }, - { text: 'string 2' } - ]); - }); - - it('should handle request as an array of Part objects', async () => { - templateGenerateContentStub.resolves( - FAKE_MODEL_RESPONSE_1 as GenerateContentResult - ); - const parts: Part[] = [{ text: 'part 1' }, { text: 'part 2' }]; - const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); - await chat.sendMessage(parts); - const history = await chat.getHistory(); - expect(history[0].parts).to.deep.equal(parts); - }); - - it('should pass inputs to templateGenerateContent', async () => { - templateGenerateContentStub.resolves( - FAKE_MODEL_RESPONSE_1 as GenerateContentResult - ); - const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); - const inputs = { someVar: 'someValue' }; - await chat.sendMessage('A request', inputs); - 
expect(templateGenerateContentStub).to.have.been.calledWith( - fakeApiSettings, - TEMPLATE_ID, - { - inputs: { ...inputs }, - history: [] - }, - undefined - ); - }); - - it('should pass requestOptions to templateGenerateContent', async () => { - templateGenerateContentStub.resolves( - FAKE_MODEL_RESPONSE_1 as GenerateContentResult - ); - const requestOptions = { timeout: 5000 }; - const chat = new TemplateChatSession( - fakeApiSettings, - TEMPLATE_ID, - [], - requestOptions - ); - await chat.sendMessage('A request'); - expect(templateGenerateContentStub).to.have.been.calledWith( - fakeApiSettings, - TEMPLATE_ID, - { - inputs: {}, - history: [] - }, - requestOptions - ); - }); - }); - - describe('input variations for sendMessageStream', () => { - it('should handle request as a single string', async () => { - templateGenerateContentStreamStub.resolves({ - stream: (async function* () {})(), - response: Promise.resolve(FAKE_MODEL_RESPONSE_1.response) - }); - const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); - await chat.sendMessageStream('Just a string'); - const history = await chat.getHistory(); - expect(history[0].parts[0].text).to.equal('Just a string'); - }); - - it('should handle request as an array of strings', async () => { - templateGenerateContentStreamStub.resolves({ - stream: (async function* () {})(), - response: Promise.resolve(FAKE_MODEL_RESPONSE_1.response) - }); - const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); - await chat.sendMessageStream(['string 1', 'string 2']); - const history = await chat.getHistory(); - expect(history[0].parts).to.deep.equal([ - { text: 'string 1' }, - { text: 'string 2' } - ]); - }); - - it('should handle request as an array of Part objects', async () => { - templateGenerateContentStreamStub.resolves({ - stream: (async function* () {})(), - response: Promise.resolve(FAKE_MODEL_RESPONSE_1.response) - }); - const parts: Part[] = [{ text: 'part 1' }, { text: 'part 2' }]; - const chat = new 
TemplateChatSession(fakeApiSettings, TEMPLATE_ID); - await chat.sendMessageStream(parts); - const history = await chat.getHistory(); - expect(history[0].parts).to.deep.equal(parts); - }); - - it('should pass inputs to templateGenerateContentStream', async () => { - templateGenerateContentStreamStub.resolves({ - stream: (async function* () {})(), - response: Promise.resolve(FAKE_MODEL_RESPONSE_1.response) - }); - const chat = new TemplateChatSession(fakeApiSettings, TEMPLATE_ID); - const inputs = { someVar: 'someValue' }; - await chat.sendMessageStream('A request', inputs); - expect(templateGenerateContentStreamStub).to.have.been.calledWith( - fakeApiSettings, - TEMPLATE_ID, - { - inputs: { ...inputs }, - history: [] - } - ); - }); - - it('should pass requestOptions to templateGenerateContentStream', async () => { - templateGenerateContentStreamStub.resolves({ - stream: (async function* () {})(), - response: Promise.resolve(FAKE_MODEL_RESPONSE_1.response) - }); - const requestOptions = { timeout: 5000 }; - const chat = new TemplateChatSession( - fakeApiSettings, - TEMPLATE_ID, - [], - requestOptions - ); - await chat.sendMessageStream('A request'); - expect(templateGenerateContentStreamStub).to.have.been.calledWith( - fakeApiSettings, - TEMPLATE_ID, - { - inputs: {}, - history: [] - }, - requestOptions - ); - }); - }); -}); diff --git a/packages/ai/src/methods/template-chat-session.ts b/packages/ai/src/methods/template-chat-session.ts deleted file mode 100644 index 7c8e4a21252..00000000000 --- a/packages/ai/src/methods/template-chat-session.ts +++ /dev/null @@ -1,230 +0,0 @@ -/** - * @license - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { - Content, - GenerateContentResult, - GenerateContentStreamResult, - Part, - RequestOptions -} from '../types'; -import { formatNewContent } from '../requests/request-helpers'; -import { formatBlockErrorMessage } from '../requests/response-helpers'; -import { validateChatHistory } from './chat-session-helpers'; -import { - templateGenerateContent, - templateGenerateContentStream -} from './generate-content'; -import { ApiSettings } from '../types/internal'; -import { logger } from '../logger'; - -/** - * Do not log a message for this error. - */ -const SILENT_ERROR = 'SILENT_ERROR'; - -/** - * A chat session that enables sending chat messages and stores the history of - * sent and received messages so far. - * - * This session is for multi-turn chats using a server-side template. It should - * be instantiated with {@link TemplateGenerativeModel.startChat}. - * - * @beta - */ -export class TemplateChatSession { - private _sendPromise: Promise = Promise.resolve(); - - /** - * @hideconstructor - */ - constructor( - private _apiSettings: ApiSettings, - public templateId: string, - private _history: Content[] = [], - public requestOptions?: RequestOptions - ) { - if (this._history) { - validateChatHistory(this._history); - } - } - - /** - * Gets the chat history so far. Blocked prompts are not added to history. - * Neither blocked candidates nor the prompts that generated them are added - * to history. 
- * - * @beta - */ - async getHistory(): Promise { - await this._sendPromise; - return this._history; - } - - /** - * Sends a chat message and receives a non-streaming - * {@link GenerateContentResult}. - * - * @param request - The user message to store in the history - * @param inputs - A key-value map of variables to populate the template - * with. This should likely include the user message. - * - * @beta - */ - async sendMessage( - request: string | Array, - inputs?: object - ): Promise { - await this._sendPromise; - let finalResult = {} as GenerateContentResult; - const variablesWithHistory = { - inputs: { - ...inputs - }, - history: [...this._history] - }; - // Add onto the chain. - this._sendPromise = this._sendPromise - .then(() => - templateGenerateContent( - this._apiSettings, - this.templateId, - variablesWithHistory, - this.requestOptions - ) - ) - .then(result => { - if ( - result.response.candidates && - result.response.candidates.length > 0 - ) { - // Important note: The user's message is *not* the actual message that was sent to - // the model, but the message that was passed as a parameter. - // Since the real message was the rendered server prompt template, there is no way - // to store the actual message in the client. - // It's the user's responsibility to ensure that the `message` that goes in the history - // is as close as possible to the rendered template if they want a realistic chat - // experience. - // The ideal case here is that the user defines a `message` variable in the `inputs` of - // the prompt template. The other parts of the message that the prompt template is hiding - // isn't relevant to the conversation history. For example, system instructions. - // In this case, the user would have the user's `message` that they pass as the first - // argument to this method, then *also* pass that in the `inputs`, so that it's actually - // part of the populated template that is sent to the model. 
- this._history.push(formatNewContent(request)); - const responseContent: Content = { - parts: result.response.candidates?.[0].content.parts || [], - // Response seems to come back without a role set. - role: result.response.candidates?.[0].content.role || 'model' - }; - this._history.push(responseContent); - } else { - const blockErrorMessage = formatBlockErrorMessage(result.response); - if (blockErrorMessage) { - logger.warn( - `sendMessage() was unsuccessful. ${blockErrorMessage}. Inspect response object for details.` - ); - } - } - finalResult = result as GenerateContentResult; - }); - await this._sendPromise; - return finalResult; - } - - /** - * Sends a chat message and receives the response as a - * {@link GenerateContentStreamResult} containing an iterable stream - * and a response promise. - * - * @param request - The message to send to the model. - * @param inputs - A key-value map of variables to populate the template - * with. - * - * @beta - */ - async sendMessageStream( - request: string | Array, - inputs?: object - ): Promise { - await this._sendPromise; - const variablesWithHistory = { - inputs: { - ...inputs - }, - history: [...this._history] - }; - const streamPromise = templateGenerateContentStream( - this._apiSettings, - this.templateId, - variablesWithHistory, - this.requestOptions - ); - - // Add onto the chain. - this._sendPromise = this._sendPromise - .then(() => streamPromise) - // This must be handled to avoid unhandled rejection, but jump - // to the final catch block with a label to not log this error. - .catch(_ignored => { - throw new Error(SILENT_ERROR); - }) - .then(streamResult => streamResult.response) - .then(response => { - if (response.candidates && response.candidates.length > 0) { - // Important note: The user's message is *not* the actual message that was sent to - // the model, but the message that was passed as a parameter. 
- // Since the real message was the rendered server prompt template, there is no way - // to store the actual message in the client. - // It's the user's responsibility to ensure that the `message` that goes in the history - // is as close as possible to the rendered template if they want a realistic chat - // experience. - // The ideal case here is that the user defines a `message` variable in the `inputs` of - // the prompt template. The other parts of the message that the prompt template is hiding - // isn't relevant to the conversation history. For example, system instructions. - // In this case, the user would have the user's `message` that they pass as the first - // argument to this method, then *also* pass that in the `inputs`, so that it's actually - // part of the populated template that is sent to the model. - this._history.push(formatNewContent(request)); - const responseContent = { ...response.candidates[0].content }; - // Response seems to come back without a role set. - if (!responseContent.role) { - responseContent.role = 'model'; - } - this._history.push(responseContent); - } else { - const blockErrorMessage = formatBlockErrorMessage(response); - if (blockErrorMessage) { - logger.warn( - `sendMessageStream() was unsuccessful. ${blockErrorMessage}. Inspect response object for details.` - ); - } - } - }) - .catch(e => { - // Errors in streamPromise are already catchable by the user as - // streamPromise is returned. - // Avoid duplicating the error message in logs. - if (e.message !== SILENT_ERROR) { - // Users do not have access to _sendPromise to catch errors - // downstream from streamPromise, so they should not throw. 
- logger.error(e); - } - }); - return streamPromise; - } -} diff --git a/packages/ai/src/models/template-generative-model.test.ts b/packages/ai/src/models/template-generative-model.test.ts index 4198090eab6..c3eb43af491 100644 --- a/packages/ai/src/models/template-generative-model.test.ts +++ b/packages/ai/src/models/template-generative-model.test.ts @@ -18,11 +18,10 @@ import { use, expect } from 'chai'; import sinonChai from 'sinon-chai'; import { restore, stub } from 'sinon'; -import { AI, Content } from '../public-types'; +import { AI } from '../public-types'; import { VertexAIBackend } from '../backend'; import { TemplateGenerativeModel } from './template-generative-model'; import * as generateContentMethods from '../methods/generate-content'; -import { TemplateChatSession } from '../methods/template-chat-session'; use(sinonChai); @@ -94,24 +93,4 @@ describe('TemplateGenerativeModel', () => { ); }); }); - - describe('startChat', () => { - it('should return a TemplateChatSession instance', () => { - const model = new TemplateGenerativeModel(fakeAI); - const chat = model.startChat(TEMPLATE_ID); - expect(chat).to.be.an.instanceOf(TemplateChatSession); - expect(chat.templateId).to.equal(TEMPLATE_ID); - }); - - it('should pass history and requestOptions to TemplateChatSession', () => { - const history: Content[] = [{ role: 'user', parts: [{ text: 'hi' }] }]; - const requestOptions = { timeout: 1000 }; - const model = new TemplateGenerativeModel(fakeAI, requestOptions); - const chat = model.startChat(TEMPLATE_ID, history); - - expect(chat.requestOptions).to.deep.equal(requestOptions); - // Private property, but we can check it for test purposes - expect((chat as any)._history).to.deep.equal(history); - }); - }); }); diff --git a/packages/ai/src/models/template-generative-model.ts b/packages/ai/src/models/template-generative-model.ts index 8a60728df2c..ec9e653618d 100644 --- a/packages/ai/src/models/template-generative-model.ts +++ 
b/packages/ai/src/models/template-generative-model.ts @@ -20,9 +20,8 @@ import { templateGenerateContentStream } from '../methods/generate-content'; import { GenerateContentResult, RequestOptions } from '../types'; -import { AI, Content, GenerateContentStreamResult } from '../public-types'; +import { AI, GenerateContentStreamResult } from '../public-types'; import { ApiSettings } from '../types/internal'; -import { TemplateChatSession } from '../methods/template-chat-session'; import { initApiSettings } from './utils'; /** @@ -96,23 +95,4 @@ export class TemplateGenerativeModel { this.requestOptions ); } - - /** - * Gets a new {@link TemplateChatSession} instance which can be used for - * multi-turn chats. - * - * @param templateId - The ID of the server-side template to execute. - * @param history - An array of {@link Content} objects to initialize the - * chat history with. - * - * @beta - */ - startChat(templateId: string, history?: Content[]): TemplateChatSession { - return new TemplateChatSession( - this._apiSettings, - templateId, - history, - this.requestOptions - ); - } } From b9559460cf32f03c3bb85f5099beb41587df0986 Mon Sep 17 00:00:00 2001 From: Daniel La Rocque Date: Thu, 30 Oct 2025 11:17:08 -0400 Subject: [PATCH 6/8] Fix task ref --- packages/ai/src/requests/request.test.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/ai/src/requests/request.test.ts b/packages/ai/src/requests/request.test.ts index bafb97bf855..d1988181573 100644 --- a/packages/ai/src/requests/request.test.ts +++ b/packages/ai/src/requests/request.test.ts @@ -19,7 +19,7 @@ import { expect, use } from 'chai'; import { match, restore, stub } from 'sinon'; import sinonChai from 'sinon-chai'; import chaiAsPromised from 'chai-as-promised'; -import { RequestURL, Task, getHeaders, makeRequest } from './request'; +import { RequestURL, ServerPromptTemplateTask, Task, getHeaders, makeRequest } from './request'; import { ApiSettings } from '../types/internal'; 
import { DEFAULT_API_VERSION } from '../constants'; import { AIErrorCode } from '../types'; @@ -103,7 +103,7 @@ describe('request methods', () => { it('prompt server template', async () => { const url = new RequestURL({ templateId: 'my-template', - task: 'templateGenerateContent', + task: ServerPromptTemplateTask.TEMPLATE_GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, requestOptions: {} From 47ee3942aca909bed60c5a208a6c5f096660b3d4 Mon Sep 17 00:00:00 2001 From: Daniel La Rocque Date: Thu, 30 Oct 2025 11:23:00 -0400 Subject: [PATCH 7/8] Run formatter and generate docs --- docs-devsite/_toc.yaml | 2 - docs-devsite/ai.md | 1 - docs-devsite/ai.templatechatsession.md | 154 --------------------- docs-devsite/ai.templategenerativemodel.md | 25 ---- packages/ai/src/requests/request.test.ts | 8 +- 5 files changed, 7 insertions(+), 183 deletions(-) delete mode 100644 docs-devsite/ai.templatechatsession.md diff --git a/docs-devsite/_toc.yaml b/docs-devsite/_toc.yaml index cc52bb4d395..92633c553a3 100644 --- a/docs-devsite/_toc.yaml +++ b/docs-devsite/_toc.yaml @@ -198,8 +198,6 @@ toc: path: /docs/reference/js/ai.startchatparams.md - title: StringSchema path: /docs/reference/js/ai.stringschema.md - - title: TemplateChatSession - path: /docs/reference/js/ai.templatechatsession.md - title: TemplateGenerativeModel path: /docs/reference/js/ai.templategenerativemodel.md - title: TemplateImagenModel diff --git a/docs-devsite/ai.md b/docs-devsite/ai.md index 3fb7452fe38..53e4057cade 100644 --- a/docs-devsite/ai.md +++ b/docs-devsite/ai.md @@ -49,7 +49,6 @@ The Firebase AI Web SDK. | [ObjectSchema](./ai.objectschema.md#objectschema_class) | Schema class for "object" types. The properties param must be a map of Schema objects. | | [Schema](./ai.schema.md#schema_class) | Parent class encompassing all Schema types, with static methods that allow building specific Schema types. 
This class can be converted with JSON.stringify() into a JSON string accepted by Vertex AI REST endpoints. (This string conversion is automatically done when calling SDK methods.) | | [StringSchema](./ai.stringschema.md#stringschema_class) | Schema class for "string" types. Can be used with or without enum values. | -| [TemplateChatSession](./ai.templatechatsession.md#templatechatsession_class) | (Public Preview) A chat session that enables sending chat messages and stores the history of sent and received messages so far.This session is for multi-turn chats using a server-side template. It should be instantiated with [TemplateGenerativeModel.startChat()](./ai.templategenerativemodel.md#templategenerativemodelstartchat). | | [TemplateGenerativeModel](./ai.templategenerativemodel.md#templategenerativemodel_class) | (Public Preview) [GenerativeModel](./ai.generativemodel.md#generativemodel_class) APIs that execute on a server-side template.This class should only be instantiated with [getTemplateGenerativeModel()](./ai.md#gettemplategenerativemodel_9476bbc). | | [TemplateImagenModel](./ai.templateimagenmodel.md#templateimagenmodel_class) | (Public Preview) Class for Imagen model APIs that execute on a server-side template.This class should only be instantiated with [getTemplateImagenModel()](./ai.md#gettemplateimagenmodel_9476bbc). | | [VertexAIBackend](./ai.vertexaibackend.md#vertexaibackend_class) | Configuration class for the Vertex AI Gemini API.Use this with [AIOptions](./ai.aioptions.md#aioptions_interface) when initializing the AI service via [getAI()](./ai.md#getai_a94a413) to specify the Vertex AI Gemini API as the backend. 
| diff --git a/docs-devsite/ai.templatechatsession.md b/docs-devsite/ai.templatechatsession.md deleted file mode 100644 index 41f5e71d97a..00000000000 --- a/docs-devsite/ai.templatechatsession.md +++ /dev/null @@ -1,154 +0,0 @@ -Project: /docs/reference/js/_project.yaml -Book: /docs/reference/_book.yaml -page_type: reference - -{% comment %} -DO NOT EDIT THIS FILE! -This is generated by the JS SDK team, and any local changes will be -overwritten. Changes should be made in the source code at -https://github.com/firebase/firebase-js-sdk -{% endcomment %} - -# TemplateChatSession class -> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. -> - -A chat session that enables sending chat messages and stores the history of sent and received messages so far. - -This session is for multi-turn chats using a server-side template. It should be instantiated with [TemplateGenerativeModel.startChat()](./ai.templategenerativemodel.md#templategenerativemodelstartchat). 
- -Signature: - -```typescript -export declare class TemplateChatSession -``` - -## Constructors - -| Constructor | Modifiers | Description | -| --- | --- | --- | -| [(constructor)(\_apiSettings, templateId, \_history, requestOptions)](./ai.templatechatsession.md#templatechatsessionconstructor) | | (Public Preview) Constructs a new instance of the TemplateChatSession class | - -## Properties - -| Property | Modifiers | Type | Description | -| --- | --- | --- | --- | -| [requestOptions](./ai.templatechatsession.md#templatechatsessionrequestoptions) | | [RequestOptions](./ai.requestoptions.md#requestoptions_interface) \| undefined | (Public Preview) | -| [templateId](./ai.templatechatsession.md#templatechatsessiontemplateid) | | string | (Public Preview) | - -## Methods - -| Method | Modifiers | Description | -| --- | --- | --- | -| [getHistory()](./ai.templatechatsession.md#templatechatsessiongethistory) | | (Public Preview) Gets the chat history so far. Blocked prompts are not added to history. Neither blocked candidates nor the prompts that generated them are added to history. | -| [sendMessage(request, inputs)](./ai.templatechatsession.md#templatechatsessionsendmessage) | | (Public Preview) Sends a chat message and receives a non-streaming [GenerateContentResult](./ai.generatecontentresult.md#generatecontentresult_interface). | -| [sendMessageStream(request, inputs)](./ai.templatechatsession.md#templatechatsessionsendmessagestream) | | (Public Preview) Sends a chat message and receives the response as a [GenerateContentStreamResult](./ai.generatecontentstreamresult.md#generatecontentstreamresult_interface) containing an iterable stream and a response promise. | - -## TemplateChatSession.(constructor) - -> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. 
-> - - Constructs a new instance of the `TemplateChatSession` class - -Signature: - -```typescript -constructor(_apiSettings: ApiSettings, templateId: string, _history?: Content[], requestOptions?: RequestOptions | undefined); -``` - -#### Parameters - -| Parameter | Type | Description | -| --- | --- | --- | -| \_apiSettings | ApiSettings | | -| templateId | string | | -| \_history | [Content](./ai.content.md#content_interface)\[\] | | -| requestOptions | [RequestOptions](./ai.requestoptions.md#requestoptions_interface) \| undefined | | - -## TemplateChatSession.requestOptions - -> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. -> - -Signature: - -```typescript -requestOptions?: RequestOptions | undefined; -``` - -## TemplateChatSession.templateId - -> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. -> - -Signature: - -```typescript -templateId: string; -``` - -## TemplateChatSession.getHistory() - -> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. -> - -Gets the chat history so far. Blocked prompts are not added to history. Neither blocked candidates nor the prompts that generated them are added to history. - -Signature: - -```typescript -getHistory(): Promise; -``` -Returns: - -Promise<[Content](./ai.content.md#content_interface)\[\]> - -## TemplateChatSession.sendMessage() - -> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. -> - -Sends a chat message and receives a non-streaming [GenerateContentResult](./ai.generatecontentresult.md#generatecontentresult_interface). 
- -Signature: - -```typescript -sendMessage(request: string | Array, inputs?: object): Promise; -``` - -#### Parameters - -| Parameter | Type | Description | -| --- | --- | --- | -| request | string \| Array<string \| [Part](./ai.md#part)> | The user message to store in the history | -| inputs | object | A key-value map of variables to populate the template with. This should likely include the user message. | - -Returns: - -Promise<[GenerateContentResult](./ai.generatecontentresult.md#generatecontentresult_interface)> - -## TemplateChatSession.sendMessageStream() - -> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. -> - -Sends a chat message and receives the response as a [GenerateContentStreamResult](./ai.generatecontentstreamresult.md#generatecontentstreamresult_interface) containing an iterable stream and a response promise. - -Signature: - -```typescript -sendMessageStream(request: string | Array, inputs?: object): Promise; -``` - -#### Parameters - -| Parameter | Type | Description | -| --- | --- | --- | -| request | string \| Array<string \| [Part](./ai.md#part)> | The message to send to the model. | -| inputs | object | A key-value map of variables to populate the template with. 
| - -Returns: - -Promise<[GenerateContentStreamResult](./ai.generatecontentstreamresult.md#generatecontentstreamresult_interface)> - diff --git a/docs-devsite/ai.templategenerativemodel.md b/docs-devsite/ai.templategenerativemodel.md index d7ab0955f2f..c115af62b1e 100644 --- a/docs-devsite/ai.templategenerativemodel.md +++ b/docs-devsite/ai.templategenerativemodel.md @@ -41,7 +41,6 @@ export declare class TemplateGenerativeModel | --- | --- | --- | | [generateContent(templateId, templateVariables)](./ai.templategenerativemodel.md#templategenerativemodelgeneratecontent) | | (Public Preview) Makes a single non-streaming call to the model and returns an object containing a single [GenerateContentResponse](./ai.generatecontentresponse.md#generatecontentresponse_interface). | | [generateContentStream(templateId, templateVariables)](./ai.templategenerativemodel.md#templategenerativemodelgeneratecontentstream) | | (Public Preview) Makes a single streaming call to the model and returns an object containing an iterable stream that iterates over all chunks in the streaming response as well as a promise that returns the final aggregated response. | -| [startChat(templateId, history)](./ai.templategenerativemodel.md#templategenerativemodelstartchat) | | (Public Preview) Gets a new [TemplateChatSession](./ai.templatechatsession.md#templatechatsession_class) instance which can be used for multi-turn chats. | ## TemplateGenerativeModel.(constructor) @@ -124,27 +123,3 @@ generateContentStream(templateId: string, templateVariables: object): Promise> -## TemplateGenerativeModel.startChat() - -> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment. -> - -Gets a new [TemplateChatSession](./ai.templatechatsession.md#templatechatsession_class) instance which can be used for multi-turn chats. 
- -Signature: - -```typescript -startChat(templateId: string, history?: Content[]): TemplateChatSession; -``` - -#### Parameters - -| Parameter | Type | Description | -| --- | --- | --- | -| templateId | string | The ID of the server-side template to execute. | -| history | [Content](./ai.content.md#content_interface)\[\] | An array of [Content](./ai.content.md#content_interface) objects to initialize the chat history with. | - -Returns: - -[TemplateChatSession](./ai.templatechatsession.md#templatechatsession_class) - diff --git a/packages/ai/src/requests/request.test.ts b/packages/ai/src/requests/request.test.ts index d1988181573..a54ff521bea 100644 --- a/packages/ai/src/requests/request.test.ts +++ b/packages/ai/src/requests/request.test.ts @@ -19,7 +19,13 @@ import { expect, use } from 'chai'; import { match, restore, stub } from 'sinon'; import sinonChai from 'sinon-chai'; import chaiAsPromised from 'chai-as-promised'; -import { RequestURL, ServerPromptTemplateTask, Task, getHeaders, makeRequest } from './request'; +import { + RequestURL, + ServerPromptTemplateTask, + Task, + getHeaders, + makeRequest +} from './request'; import { ApiSettings } from '../types/internal'; import { DEFAULT_API_VERSION } from '../constants'; import { AIErrorCode } from '../types'; From 7e5213446bf0fc2eb3f649d283bf0efa5ffc88e5 Mon Sep 17 00:00:00 2001 From: Daniel La Rocque Date: Tue, 4 Nov 2025 15:00:46 -0500 Subject: [PATCH 8/8] Fix docs --- common/api-review/ai.api.md | 8 +++--- docs-devsite/ai.googleaibackend.md | 46 ------------------------------ docs-devsite/ai.vertexaibackend.md | 46 ------------------------------ packages/ai/src/backend.ts | 12 ++++++++ 4 files changed, 16 insertions(+), 96 deletions(-) diff --git a/common/api-review/ai.api.md b/common/api-review/ai.api.md index 4057a85bb9a..2bf194fbaf2 100644 --- a/common/api-review/ai.api.md +++ b/common/api-review/ai.api.md @@ -580,9 +580,9 @@ export function getTemplateImagenModel(ai: AI, requestOptions?: 
RequestOptions): // @public export class GoogleAIBackend extends Backend { constructor(); - // (undocumented) + // @internal (undocumented) _getModelPath(project: string, model: string): string; - // (undocumented) + // @internal (undocumented) _getTemplatePath(project: string, templateId: string): string; } @@ -1445,9 +1445,9 @@ export interface UsageMetadata { // @public export class VertexAIBackend extends Backend { constructor(location?: string); - // (undocumented) + // @internal (undocumented) _getModelPath(project: string, model: string): string; - // (undocumented) + // @internal (undocumented) _getTemplatePath(project: string, templateId: string): string; readonly location: string; } diff --git a/docs-devsite/ai.googleaibackend.md b/docs-devsite/ai.googleaibackend.md index 68d6724762a..7ccf8834a0a 100644 --- a/docs-devsite/ai.googleaibackend.md +++ b/docs-devsite/ai.googleaibackend.md @@ -27,13 +27,6 @@ export declare class GoogleAIBackend extends Backend | --- | --- | --- | | [(constructor)()](./ai.googleaibackend.md#googleaibackendconstructor) | | Creates a configuration object for the Gemini Developer API backend. | -## Methods - -| Method | Modifiers | Description | -| --- | --- | --- | -| [\_getModelPath(project, model)](./ai.googleaibackend.md#googleaibackend_getmodelpath) | | | -| [\_getTemplatePath(project, templateId)](./ai.googleaibackend.md#googleaibackend_gettemplatepath) | | | - ## GoogleAIBackend.(constructor) Creates a configuration object for the Gemini Developer API backend. @@ -43,42 +36,3 @@ Creates a configuration object for the Gemini Developer API backend. 
```typescript constructor(); ``` - -## GoogleAIBackend.\_getModelPath() - -Signature: - -```typescript -_getModelPath(project: string, model: string): string; -``` - -#### Parameters - -| Parameter | Type | Description | -| --- | --- | --- | -| project | string | | -| model | string | | - -Returns: - -string - -## GoogleAIBackend.\_getTemplatePath() - -Signature: - -```typescript -_getTemplatePath(project: string, templateId: string): string; -``` - -#### Parameters - -| Parameter | Type | Description | -| --- | --- | --- | -| project | string | | -| templateId | string | | - -Returns: - -string - diff --git a/docs-devsite/ai.vertexaibackend.md b/docs-devsite/ai.vertexaibackend.md index e2e7fae1839..88424b75c45 100644 --- a/docs-devsite/ai.vertexaibackend.md +++ b/docs-devsite/ai.vertexaibackend.md @@ -33,13 +33,6 @@ export declare class VertexAIBackend extends Backend | --- | --- | --- | --- | | [location](./ai.vertexaibackend.md#vertexaibackendlocation) | | string | The region identifier. See [Vertex AI locations](https://firebase.google.com/docs/vertex-ai/locations#available-locations) for a list of supported locations. | -## Methods - -| Method | Modifiers | Description | -| --- | --- | --- | -| [\_getModelPath(project, model)](./ai.vertexaibackend.md#vertexaibackend_getmodelpath) | | | -| [\_getTemplatePath(project, templateId)](./ai.vertexaibackend.md#vertexaibackend_gettemplatepath) | | | - ## VertexAIBackend.(constructor) Creates a configuration object for the Vertex AI backend. @@ -65,42 +58,3 @@ The region identifier. 
See [Vertex AI locations](https://firebase.google.com/doc ```typescript readonly location: string; ``` - -## VertexAIBackend.\_getModelPath() - -Signature: - -```typescript -_getModelPath(project: string, model: string): string; -``` - -#### Parameters - -| Parameter | Type | Description | -| --- | --- | --- | -| project | string | | -| model | string | | - -Returns: - -string - -## VertexAIBackend.\_getTemplatePath() - -Signature: - -```typescript -_getTemplatePath(project: string, templateId: string): string; -``` - -#### Parameters - -| Parameter | Type | Description | -| --- | --- | --- | -| project | string | | -| templateId | string | | - -Returns: - -string - diff --git a/packages/ai/src/backend.ts b/packages/ai/src/backend.ts index 21852b3608d..2eaec59448f 100644 --- a/packages/ai/src/backend.ts +++ b/packages/ai/src/backend.ts @@ -67,10 +67,16 @@ export class GoogleAIBackend extends Backend { super(BackendType.GOOGLE_AI); } + /** + * @internal + */ _getModelPath(project: string, model: string): string { return `/${DEFAULT_API_VERSION}/projects/${project}/${model}`; } + /** + * @internal + */ _getTemplatePath(project: string, templateId: string): string { return `/${DEFAULT_API_VERSION}/projects/${project}/templates/${templateId}`; } @@ -108,10 +114,16 @@ export class VertexAIBackend extends Backend { } } + /** + * @internal + */ _getModelPath(project: string, model: string): string { return `/${DEFAULT_API_VERSION}/projects/${project}/locations/${this.location}/${model}`; } + /** + * @internal + */ _getTemplatePath(project: string, templateId: string): string { return `/${DEFAULT_API_VERSION}/projects/${project}/locations/${this.location}/templates/${templateId}`; }