Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Currently, we support the following providers:
- [Replicate](https://replicate.com)
- [Sambanova](https://sambanova.ai)
- [Scaleway](https://www.scaleway.com/en/generative-apis/)
- [SiliconFlow](https://siliconflow.com)
- [Clarifai](https://clarifai.com)
- [Together](https://together.xyz)
- [Baseten](https://baseten.co)
Expand Down Expand Up @@ -99,6 +100,7 @@ Only a subset of models are supported when requesting third-party providers. You
- [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
- [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
- [Scaleway supported models](https://huggingface.co/api/partners/scaleway/models)
- [SiliconFlow supported models](https://huggingface.co/api/partners/siliconflow/models)
- [Together supported models](https://huggingface.co/api/partners/together/models)
- [Baseten supported models](https://huggingface.co/api/partners/baseten/models)
- [Clarifai supported models](https://huggingface.co/api/partners/clarifai/models)
Expand Down
4 changes: 4 additions & 0 deletions packages/inference/src/lib/getProviderHelper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import * as Nscale from "../providers/nscale.js";
import * as OpenAI from "../providers/openai.js";
import * as OvhCloud from "../providers/ovhcloud.js";
import * as PublicAI from "../providers/publicai.js";
import * as SiliconFlow from "../providers/siliconflow.js";
import type {
AudioClassificationTaskHelper,
AudioToAudioTaskHelper,
Expand Down Expand Up @@ -169,6 +170,9 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
"text-generation": new Scaleway.ScalewayTextGenerationTask(),
"feature-extraction": new Scaleway.ScalewayFeatureExtractionTask(),
},
siliconflow: {
conversational: new SiliconFlow.SiliconFlowConversationalTask(),
},
together: {
"text-to-image": new Together.TogetherTextToImageTask(),
conversational: new Together.TogetherConversationalTask(),
Expand Down
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
replicate: {},
sambanova: {},
scaleway: {},
siliconflow: {},
together: {},
wavespeed: {},
"zai-org": {},
Expand Down
16 changes: 16 additions & 0 deletions packages/inference/src/providers/siliconflow.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/**
 * SiliconFlow provider implementation.
 *
 * SiliconFlow exposes an OpenAI-compatible chat-completions endpoint, so all
 * request building and response parsing is inherited from the shared
 * conversational base class; this module only pins the provider identity.
 *
 * API Documentation: https://docs.siliconflow.com
 */
import { BaseConversationalTask } from "./providerHelper.js";

/**
 * Conversational (chat-completion) task helper for SiliconFlow.
 *
 * Inherits the full OpenAI-style request/response handling from
 * {@link BaseConversationalTask}; only the provider name and API root differ.
 */
export class SiliconFlowConversationalTask extends BaseConversationalTask {
	constructor() {
		// The provider id must match the "siliconflow" key used in the
		// provider registry and model-mapping tables.
		super("siliconflow", "https://api.siliconflow.com");
	}
}
2 changes: 2 additions & 0 deletions packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ export const INFERENCE_PROVIDERS = [
"replicate",
"sambanova",
"scaleway",
"siliconflow",
"together",
"wavespeed",
"zai-org",
Expand Down Expand Up @@ -102,6 +103,7 @@ export const PROVIDERS_HUB_ORGS: Record<InferenceProvider, string> = {
replicate: "replicate",
sambanova: "sambanovasystems",
scaleway: "scaleway",
siliconflow: "siliconflow",
together: "togethercomputer",
wavespeed: "wavespeed",
"zai-org": "zai-org",
Expand Down
95 changes: 95 additions & 0 deletions packages/inference/test/InferenceClient.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2572,4 +2572,99 @@ describe.skip("InferenceClient", () => {
},
TIMEOUT
);
describe.concurrent(
"SiliconFlow",
() => {
const client = new InferenceClient(env.HF_SILICONFLOW_KEY ?? "dummy");

HARDCODED_MODEL_INFERENCE_MAPPING["siliconflow"] = {
"deepseek-ai/DeepSeek-R1": {
provider: "siliconflow",
hfModelId: "deepseek-ai/DeepSeek-R1",
providerId: "deepseek-ai/DeepSeek-R1",
status: "live",
task: "conversational",
},
"deepseek-ai/DeepSeek-V3": {
provider: "siliconflow",
hfModelId: "deepseek-ai/DeepSeek-V3",
providerId: "deepseek-ai/DeepSeek-V3",
status: "live",
task: "conversational",
},
};

it("chatCompletion - DeepSeek-R1", async () => {
const res = await client.chatCompletion({
model: "deepseek-ai/DeepSeek-R1",
provider: "siliconflow",
messages: [{ role: "user", content: "What is the capital of France?" }],
max_tokens: 20,
});
if (res.choices && res.choices.length > 0) {
const completion = res.choices[0].message?.content;
expect(completion).toBeDefined();
expect(typeof completion).toBe("string");
expect(completion).toMatch(/Paris/i);
}
});

it("chatCompletion - DeepSeek-V3", async () => {
const res = await client.chatCompletion({
model: "deepseek-ai/DeepSeek-V3",
provider: "siliconflow",
messages: [{ role: "user", content: "The weather today is" }],
max_tokens: 10,
});
expect(res.choices).toBeDefined();
expect(res.choices?.length).toBeGreaterThan(0);
expect(res.choices?.[0].message?.content).toBeDefined();
expect(typeof res.choices?.[0].message?.content).toBe("string");
expect(res.choices?.[0].message?.content?.length).toBeGreaterThan(0);
});

it("chatCompletion stream - DeepSeek-R1", async () => {
const stream = client.chatCompletionStream({
model: "deepseek-ai/DeepSeek-R1",
provider: "siliconflow",
messages: [{ role: "user", content: "Say 'this is a test'" }],
stream: true,
}) as AsyncGenerator<ChatCompletionStreamOutput>;

let fullResponse = "";
for await (const chunk of stream) {
if (chunk.choices && chunk.choices.length > 0) {
const content = chunk.choices[0].delta?.content;
if (content) {
fullResponse += content;
}
}
}
expect(fullResponse).toBeTruthy();
expect(fullResponse.length).toBeGreaterThan(0);
});

it("chatCompletion stream - DeepSeek-V3", async () => {
const stream = client.chatCompletionStream({
model: "deepseek-ai/DeepSeek-V3",
provider: "siliconflow",
messages: [{ role: "user", content: "Say 'this is a test'" }],
stream: true,
}) as AsyncGenerator<ChatCompletionStreamOutput>;

let fullResponse = "";
for await (const chunk of stream) {
if (chunk.choices && chunk.choices.length > 0) {
const content = chunk.choices[0].delta?.content;
if (content) {
fullResponse += content;
}
}
}
expect(fullResponse).toBeTruthy();
expect(fullResponse.length).toBeGreaterThan(0);
});
},
TIMEOUT
);
});