From 33a14a34006c7e48abaa093b95f1d830a80adbed Mon Sep 17 00:00:00 2001
From: Thach Nguyen
Date: Tue, 14 Oct 2025 21:54:42 +0000
Subject: [PATCH] feat: Add DeepInfra as an inference provider

---
 packages/inference/README.md                  |  2 +
 .../inference/src/lib/getProviderHelper.ts    |  5 ++
 packages/inference/src/providers/consts.ts    |  1 +
 packages/inference/src/providers/deepinfra.ts | 78 +++++++++++++++++++
 packages/inference/src/types.ts               |  1 +
 .../inference/test/InferenceClient.spec.ts    | 76 ++++++++++++++++++
 packages/tasks/src/inference-providers.ts     |  1 +
 7 files changed, 164 insertions(+)
 create mode 100644 packages/inference/src/providers/deepinfra.ts

diff --git a/packages/inference/README.md b/packages/inference/README.md
index 664c224583..b16fef1e7b 100644
--- a/packages/inference/README.md
+++ b/packages/inference/README.md
@@ -66,6 +66,7 @@ Currently, we support the following providers:
 - [Blackforestlabs](https://blackforestlabs.ai)
 - [Cohere](https://cohere.com)
 - [Cerebras](https://cerebras.ai/)
+- [DeepInfra](https://deepinfra.com)
 - [Groq](https://groq.com)
 - [Wavespeed.ai](https://wavespeed.ai/)
 - [Z.ai](https://z.ai/)
@@ -104,6 +105,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Clarifai supported models](https://huggingface.co/api/partners/clarifai/models)
 - [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
 - [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
+- [DeepInfra supported models](https://huggingface.co/api/partners/deepinfra/models)
 - [Groq supported models](https://console.groq.com/docs/models)
 - [Novita AI supported models](https://huggingface.co/api/partners/novita/models)
 - [Wavespeed.ai supported models](https://huggingface.co/api/partners/wavespeed/models)
diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index 5f5f16b044..85164a2afb 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -3,6 +3,7 @@ import * as Clarifai from "../providers/clarifai.js";
 import * as BlackForestLabs from "../providers/black-forest-labs.js";
 import * as Cerebras from "../providers/cerebras.js";
 import * as Cohere from "../providers/cohere.js";
+import * as DeepInfra from "../providers/deepinfra.js";
 import * as FalAI from "../providers/fal-ai.js";
 import * as FeatherlessAI from "../providers/featherless-ai.js";
 import * as Fireworks from "../providers/fireworks-ai.js";
@@ -73,6 +74,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
 	cohere: {
 		conversational: new Cohere.CohereConversationalTask(),
 	},
+	deepinfra: {
+		conversational: new DeepInfra.DeepInfraConversationalTask(),
+		"text-generation": new DeepInfra.DeepInfraTextGenerationTask(),
+	},
 	"fal-ai": {
 		"text-to-image": new FalAI.FalAITextToImageTask(),
 		"text-to-speech": new FalAI.FalAITextToSpeechTask(),
diff --git a/packages/inference/src/providers/consts.ts b/packages/inference/src/providers/consts.ts
--- a/packages/inference/src/providers/consts.ts
+++ b/packages/inference/src/providers/consts.ts
@@ -19,6 +19,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
 	cerebras: {},
 	clarifai: {},
 	cohere: {},
+	deepinfra: {},
 	"fal-ai": {},
 	"featherless-ai": {},
 	"fireworks-ai": {},
diff --git a/packages/inference/src/providers/deepinfra.ts b/packages/inference/src/providers/deepinfra.ts
new file mode 100644
--- /dev/null
+++ b/packages/inference/src/providers/deepinfra.ts
@@ -0,0 +1,78 @@
+/**
+ * See the registered mapping of HF model ID => DeepInfra model ID here:
+ *
+ * https://huggingface.co/api/partners/deepinfra/models
+ *
+ * This is a publicly available mapping.
+ *
+ * If you work at DeepInfra and want to update this mapping, please use the model mapping API on huggingface.co.
+ */
+import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper.js";
+import type { BodyParams } from "../types.js";
+import type { TextGenerationOutput } from "@huggingface/tasks";
+import { InferenceClientProviderOutputError } from "../errors.js";
+import { omit } from "../utils/omit.js";
+
+const DEEPINFRA_API_BASE_URL = "https://api.deepinfra.com";
+
+interface DeepInfraCompletionResponse {
+	choices: Array<{
+		text?: string;
+		message?: { content?: string };
+	}>;
+}
+
+export class DeepInfraConversationalTask extends BaseConversationalTask {
+	constructor() {
+		super("deepinfra", DEEPINFRA_API_BASE_URL);
+	}
+}
+
+export class DeepInfraTextGenerationTask extends BaseTextGenerationTask {
+	constructor() {
+		super("deepinfra", DEEPINFRA_API_BASE_URL);
+	}
+
+	override makeRoute(): string {
+		return "v1/openai/completions";
+	}
+
+	override preparePayload(params: BodyParams): Record<string, unknown> {
+		const parameters = params.args.parameters as Record<string, unknown> | undefined;
+		const res = {
+			model: params.model,
+			prompt: params.args.inputs,
+			...omit(params.args, ["inputs", "parameters"]),
+			...(parameters
+				? {
+						max_tokens: parameters.max_new_tokens,
+						stop: parameters.stop_strings,
+						...omit(parameters, ["max_new_tokens", "stop_strings"]),
+				  }
+				: undefined),
+		};
+		return res;
+	}
+
+	override async getResponse(response: DeepInfraCompletionResponse): Promise<TextGenerationOutput> {
+		if (
+			typeof response === "object" &&
+			response !== null &&
+			Array.isArray(response.choices) &&
+			response.choices.length > 0
+		) {
+			const choice = response.choices[0];
+			const completion =
+				choice.text ??
+				(typeof choice.message?.content === "string" ?
+					choice.message.content : undefined);
+			if (typeof completion === "string") {
+				return { generated_text: completion };
+			}
+		}
+
+		throw new InferenceClientProviderOutputError(
+			"Received malformed response from DeepInfra text-generation API: expected OpenAI completion payload"
+		);
+	}
+}
diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts
index 29df9f09ec..016a28a04c 100644
--- a/packages/inference/src/types.ts
+++ b/packages/inference/src/types.ts
@@ -50,6 +50,7 @@ export const INFERENCE_PROVIDERS = [
 	"cerebras",
 	"clarifai",
 	"cohere",
+	"deepinfra",
 	"fal-ai",
 	"featherless-ai",
 	"fireworks-ai",
diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts
index fcb6e55cb3..58c5c16719 100644
--- a/packages/inference/test/InferenceClient.spec.ts
+++ b/packages/inference/test/InferenceClient.spec.ts
@@ -1967,6 +1967,82 @@ describe.skip("InferenceClient", () => {
 		},
 		TIMEOUT
 	);
+	describe.concurrent(
+		"DeepInfra",
+		() => {
+			const client = new InferenceClient(
+				env.DEEPINFRA_API_KEY ?? env.HF_DEEPINFRA_KEY ?? "dummy"
+			);
+
+			const HF_MODEL = "google/gemma-3-4b-it";
+			const PROVIDER_ID = "google/gemma-3-4b-it";
+
+			const setMapping = (task: "conversational" | "text-generation") => {
+				HARDCODED_MODEL_INFERENCE_MAPPING["deepinfra"] = {
+					[HF_MODEL]: {
+						provider: "deepinfra",
+						hfModelId: HF_MODEL,
+						providerId: PROVIDER_ID,
+						status: "live",
+						task,
+					},
+				};
+			};
+
+			it("chatCompletion", async () => {
+				setMapping("conversational");
+				const res = await client.chatCompletion({
+					model: HF_MODEL,
+					provider: "deepinfra",
+					messages: [{ role: "user", content: "Name one use case for open-source AI." }],
+					max_tokens: 64,
+				});
+				if (res.choices && res.choices.length > 0) {
+					const completion = res.choices[0].message?.content;
+					expect(typeof completion).toBe("string");
+					expect(completion?.length ?? 0).toBeGreaterThan(0);
+				}
+			});
+
+			it("chatCompletion stream", async () => {
+				setMapping("conversational");
+				const stream = client.chatCompletionStream({
+					model: HF_MODEL,
+					provider: "deepinfra",
+					messages: [
+						{
+							role: "user",
+							content: "Respond with a two-word description of Hugging Face.",
+						},
+					],
+					max_tokens: 32,
+				}) as AsyncGenerator<ChatCompletionStreamOutput>;
+				let streamed = "";
+				for await (const chunk of stream) {
+					if (chunk.choices && chunk.choices.length > 0) {
+						const delta = chunk.choices[0].delta?.content;
+						if (delta) {
+							streamed += delta;
+						}
+					}
+				}
+				expect(streamed.trim().length).toBeGreaterThan(0);
+			});
+
+			it("textGeneration", async () => {
+				setMapping("text-generation");
+				const generation = await client.textGeneration({
+					model: HF_MODEL,
+					provider: "deepinfra",
+					inputs: "Describe Hugging Face in one short sentence.",
+					parameters: { max_new_tokens: 64 },
+				});
+				expect(typeof generation.generated_text).toBe("string");
+				expect(generation.generated_text?.length ?? 0).toBeGreaterThan(0);
+			});
+		},
+		TIMEOUT
+	);
 	describe.concurrent(
 		"Cerebras",
 		() => {
diff --git a/packages/tasks/src/inference-providers.ts b/packages/tasks/src/inference-providers.ts
index ee08d12943..a7a0c2c26e 100644
--- a/packages/tasks/src/inference-providers.ts
+++ b/packages/tasks/src/inference-providers.ts
@@ -3,6 +3,7 @@
 const INFERENCE_PROVIDERS = [
 	"cerebras",
 	"cohere",
+	"deepinfra",
 	"fal-ai",
 	"fireworks-ai",
 	"hf-inference",
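
---
Reviewer note (not part of the patch): a minimal usage sketch of the two new task
helpers, mirroring the test cases above. The model ID comes from the test fixture;
the env var name, file name, and the main() wrapper are illustrative assumptions.

// usage-sketch.ts -- hypothetical file, for illustration only
import { InferenceClient } from "@huggingface/inference";

async function main(): Promise<void> {
	// Token env var name is an assumption; any valid API key works here.
	const client = new InferenceClient(process.env.DEEPINFRA_API_KEY);

	// Conversational task: routed through DeepInfra's OpenAI-compatible chat endpoint.
	const chat = await client.chatCompletion({
		model: "google/gemma-3-4b-it", // same model the new tests use
		provider: "deepinfra",
		messages: [{ role: "user", content: "Name one use case for open-source AI." }],
		max_tokens: 64,
	});
	console.log(chat.choices[0]?.message?.content);

	// Text-generation task: preparePayload above remaps `max_new_tokens` -> `max_tokens`
	// and `stop_strings` -> `stop` to fit the OpenAI-style completions payload.
	const generation = await client.textGeneration({
		model: "google/gemma-3-4b-it",
		provider: "deepinfra",
		inputs: "Describe Hugging Face in one short sentence.",
		parameters: { max_new_tokens: 64 },
	});
	console.log(generation.generated_text);
}

main().catch(console.error);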