Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Currently, we support the following providers:
- [Blackforestlabs](https://blackforestlabs.ai)
- [Cohere](https://cohere.com)
- [Cerebras](https://cerebras.ai/)
- [DeepInfra](https://deepinfra.com)
- [Groq](https://groq.com)
- [Wavespeed.ai](https://wavespeed.ai/)
- [Z.ai](https://z.ai/)
Expand Down Expand Up @@ -104,6 +105,7 @@ Only a subset of models are supported when requesting third-party providers. You
- [Clarifai supported models](https://huggingface.co/api/partners/clarifai/models)
- [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
- [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
- [DeepInfra supported models](https://huggingface.co/api/partners/deepinfra/models)
- [Groq supported models](https://console.groq.com/docs/models)
- [Novita AI supported models](https://huggingface.co/api/partners/novita/models)
- [Wavespeed.ai supported models](https://huggingface.co/api/partners/wavespeed/models)
Expand Down
5 changes: 5 additions & 0 deletions packages/inference/src/lib/getProviderHelper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import * as Clarifai from "../providers/clarifai.js";
import * as BlackForestLabs from "../providers/black-forest-labs.js";
import * as Cerebras from "../providers/cerebras.js";
import * as Cohere from "../providers/cohere.js";
import * as DeepInfra from "../providers/deepinfra.js";
import * as FalAI from "../providers/fal-ai.js";
import * as FeatherlessAI from "../providers/featherless-ai.js";
import * as Fireworks from "../providers/fireworks-ai.js";
Expand Down Expand Up @@ -73,6 +74,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
cohere: {
conversational: new Cohere.CohereConversationalTask(),
},
deepinfra: {
conversational: new DeepInfra.DeepInfraConversationalTask(),
"text-generation": new DeepInfra.DeepInfraTextGenerationTask(),
},
"fal-ai": {
"text-to-image": new FalAI.FalAITextToImageTask(),
"text-to-speech": new FalAI.FalAITextToSpeechTask(),
Expand Down
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
cerebras: {},
clarifai: {},
cohere: {},
deepinfra: {},
"fal-ai": {},
"featherless-ai": {},
"fireworks-ai": {},
Expand Down
78 changes: 78 additions & 0 deletions packages/inference/src/providers/deepinfra.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import type { TextGenerationOutput } from "@huggingface/tasks";
import { InferenceClientProviderOutputError } from "../errors.js";
import type { BodyParams } from "../types.js";
import { omit } from "../utils/omit.js";
import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper.js";

/**
 * DeepInfra exposes OpenAI-compatible endpoints under the /v1/openai namespace.
 */
const DEEPINFRA_API_BASE_URL = "https://api.deepinfra.com";

/**
 * One choice of an OpenAI-style completion response.
 * `text` is populated by the legacy completions endpoint, while
 * `message.content` is populated by the chat-completions endpoint;
 * `getResponse` accepts either.
 */
interface DeepInfraCompletionChoice {
	text?: string;
	message?: { content?: string };
}

/**
 * Minimal shape of DeepInfra's OpenAI-compatible completion payload —
 * only the fields this module actually reads.
 */
interface DeepInfraCompletionResponse {
	choices: DeepInfraCompletionChoice[];
	model: string;
}

/**
 * Conversational (chat-completions) task for DeepInfra.
 *
 * All request building and response parsing is inherited from
 * BaseConversationalTask; this subclass only pins the provider id,
 * the API base URL, and the OpenAI-compatible chat route.
 */
export class DeepInfraConversationalTask extends BaseConversationalTask {
	constructor() {
		super("deepinfra", DEEPINFRA_API_BASE_URL);
	}

	/** Chat completions live under DeepInfra's /v1/openai namespace. */
	override makeRoute(): string {
		const chatRoute = "v1/openai/chat/completions";
		return chatRoute;
	}
}

/**
 * Text-generation task for DeepInfra, backed by the provider's
 * OpenAI-compatible /v1/openai/completions endpoint.
 */
export class DeepInfraTextGenerationTask extends BaseTextGenerationTask {
	constructor() {
		super("deepinfra", DEEPINFRA_API_BASE_URL);
	}

	override makeRoute(): string {
		return "v1/openai/completions";
	}

	/**
	 * Map an HF text-generation request onto DeepInfra's OpenAI-compatible
	 * completions body: `inputs` becomes `prompt`, and the HF parameters
	 * `max_new_tokens` / `stop_strings` become `max_tokens` / `stop`.
	 *
	 * Mapped keys are only emitted when the source parameter is defined.
	 * The original unconditional mapping produced `max_tokens: undefined` /
	 * `stop: undefined`, which clobbered any same-named key already spread
	 * from `params.args`. Remaining parameters are spread last so an explicit
	 * caller-supplied `max_tokens`/`stop` still wins, as before.
	 */
	override preparePayload(params: BodyParams): Record<string, unknown> {
		const parameters = params.args.parameters as Record<string, unknown> | undefined;
		const payload: Record<string, unknown> = {
			model: params.model,
			prompt: params.args.inputs,
			...omit(params.args, ["inputs", "parameters"]),
		};
		if (parameters) {
			// NOTE(review): the HF text-generation spec names the stop list `stop`,
			// not `stop_strings` — confirm which key callers send. A caller-supplied
			// `stop` passes through unchanged via the Object.assign below.
			if (parameters.max_new_tokens !== undefined) {
				payload.max_tokens = parameters.max_new_tokens;
			}
			if (parameters.stop_strings !== undefined) {
				payload.stop = parameters.stop_strings;
			}
			Object.assign(payload, omit(parameters, ["max_new_tokens", "stop_strings"]));
		}
		return payload;
	}

	/**
	 * Extract the generated text from an OpenAI-style completion payload,
	 * accepting either the legacy `choices[0].text` field or the chat-style
	 * `choices[0].message.content` field.
	 *
	 * @throws InferenceClientProviderOutputError when the payload has no
	 *         choices or the first choice carries no string completion.
	 */
	override async getResponse(response: DeepInfraCompletionResponse): Promise<TextGenerationOutput> {
		if (
			typeof response === "object" &&
			response !== null &&
			Array.isArray(response.choices) &&
			response.choices.length > 0
		) {
			const choice = response.choices[0];
			const completion =
				choice.text ??
				(typeof choice.message?.content === "string" ? choice.message.content : undefined);
			if (typeof completion === "string") {
				return { generated_text: completion };
			}
		}

		throw new InferenceClientProviderOutputError(
			"Received malformed response from DeepInfra text-generation API: expected OpenAI completion payload"
		);
	}
}
1 change: 1 addition & 0 deletions packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ export const INFERENCE_PROVIDERS = [
"cerebras",
"clarifai",
"cohere",
"deepinfra",
"fal-ai",
"featherless-ai",
"fireworks-ai",
Expand Down
76 changes: 76 additions & 0 deletions packages/inference/test/InferenceClient.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1967,6 +1967,82 @@ describe.skip("InferenceClient", () => {
},
TIMEOUT
);
describe.concurrent(
	"DeepInfra",
	() => {
		// Accept either env var spelling for the DeepInfra credential; "dummy"
		// keeps client construction from throwing when neither is set.
		const client = new InferenceClient(
			env.DEEPINFRA_API_KEY ?? env.HF_DEEPINFRA_KEY ?? "dummy"
		);

		const HF_MODEL = "google/gemma-3-4b-it";
		const PROVIDER_ID = "google/gemma-3-4b-it";

		// Register a hardcoded HF-model -> provider-model mapping so the tests
		// do not depend on the live mapping API.
		const setMapping = (task: "conversational" | "text-generation") => {
			HARDCODED_MODEL_INFERENCE_MAPPING["deepinfra"] = {
				[HF_MODEL]: {
					provider: "deepinfra",
					hfModelId: HF_MODEL,
					providerId: PROVIDER_ID,
					status: "live",
					task,
				},
			};
		};

		it("chatCompletion", async () => {
			setMapping("conversational");
			const res = await client.chatCompletion({
				model: HF_MODEL,
				provider: "deepinfra",
				messages: [{ role: "user", content: "Name one use case for open-source AI." }],
				max_tokens: 64,
			});
			// Assert unconditionally: the previous `if (res.choices && …)` guard
			// let an empty or malformed response pass the test silently.
			expect(res.choices).toBeDefined();
			expect(res.choices?.length ?? 0).toBeGreaterThan(0);
			const completion = res.choices?.[0].message?.content;
			expect(typeof completion).toBe("string");
			expect(completion?.length ?? 0).toBeGreaterThan(0);
		});

		it("chatCompletion stream", async () => {
			setMapping("conversational");
			const stream = client.chatCompletionStream({
				model: HF_MODEL,
				provider: "deepinfra",
				messages: [
					{
						role: "user",
						content: "Respond with a two-word description of Hugging Face.",
					},
				],
				max_tokens: 32,
			}) as AsyncGenerator<ChatCompletionStreamOutput>;
			let streamed = "";
			for await (const chunk of stream) {
				if (chunk.choices && chunk.choices.length > 0) {
					const delta = chunk.choices[0].delta?.content;
					if (delta) {
						streamed += delta;
					}
				}
			}
			// The accumulated deltas must form a non-empty completion.
			expect(streamed.trim().length).toBeGreaterThan(0);
		});

		it("textGeneration", async () => {
			setMapping("text-generation");
			const generation = await client.textGeneration({
				model: HF_MODEL,
				provider: "deepinfra",
				inputs: "Describe Hugging Face in one short sentence.",
				parameters: { max_new_tokens: 64 },
			});
			expect(typeof generation.generated_text).toBe("string");
			expect(generation.generated_text?.length ?? 0).toBeGreaterThan(0);
		});
	},
	TIMEOUT
);
describe.concurrent(
"Cerebras",
() => {
Expand Down
1 change: 1 addition & 0 deletions packages/tasks/src/inference-providers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
const INFERENCE_PROVIDERS = [
"cerebras",
"cohere",
"deepinfra",
"fal-ai",
"fireworks-ai",
"hf-inference",
Expand Down