diff --git a/packages/inference/README.md b/packages/inference/README.md
index 21e46625b..3e1acdfd6 100644
--- a/packages/inference/README.md
+++ b/packages/inference/README.md
@@ -62,6 +62,7 @@ Currently, we support the following providers:
 - [Blackforestlabs](https://blackforestlabs.ai)
 - [Cohere](https://cohere.com)
 - [Cerebras](https://cerebras.ai/)
+- [CentML](https://centml.ai)
 - [Groq](https://groq.com)
 
 To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. The default value of the `provider` parameter is "auto", which will select the first of the providers available for the model, sorted by your preferred order in https://hf.co/settings/inference-providers.
@@ -95,6 +96,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Together supported models](https://huggingface.co/api/partners/together/models)
 - [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
 - [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
+- [CentML supported models](https://huggingface.co/api/partners/centml/models)
 - [Groq supported models](https://console.groq.com/docs/models)
 - [Novita AI supported models](https://huggingface.co/api/partners/novita/models)
diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index 9afcb8980..a0519e884 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -1,5 +1,6 @@
 import * as BlackForestLabs from "../providers/black-forest-labs.js";
 import * as Cerebras from "../providers/cerebras.js";
+import * as CentML from "../providers/centml.js";
 import * as Cohere from "../providers/cohere.js";
 import * as FalAI from "../providers/fal-ai.js";
 import * as FeatherlessAI from "../providers/featherless-ai.js";
@@ -12,6 +13,7 @@ import * as Novita from "../providers/novita.js";
 import * as Nscale from "../providers/nscale.js";
 import * as OpenAI from "../providers/openai.js";
 import * as OvhCloud from "../providers/ovhcloud.js";
+
 import type {
 	AudioClassificationTaskHelper,
 	AudioToAudioTaskHelper,
@@ -56,6 +58,9 @@ export const PROVIDERS: Record
 	| "conversational";
 export const INFERENCE_PROVIDERS = [
 	"black-forest-labs",
 	"cerebras",
+	"centml",
 	"cohere",
 	"fal-ai",
 	"featherless-ai",
diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts
index 602e034cd..2ddab71a1 100644
--- a/packages/inference/test/InferenceClient.spec.ts
+++ b/packages/inference/test/InferenceClient.spec.ts
@@ -2023,4 +2023,89 @@ describe.skip("InferenceClient", () => {
 		},
 		TIMEOUT
 	);
+	describe.concurrent(
+		"CentML",
+		() => {
+			const client = new InferenceClient(env.HF_CENTML_KEY ?? "dummy");
+
+			HARDCODED_MODEL_INFERENCE_MAPPING["centml"] = {
+				"meta-llama/Llama-3.2-3B-Instruct": {
+					hfModelId: "meta-llama/Llama-3.2-3B-Instruct",
+					providerId: "meta-llama/Llama-3.2-3B-Instruct",
+					status: "live",
+					task: "conversational",
+				},
+			};
+
+			describe("chat completions", () => {
+				it("basic chat completion", async () => {
+					const res = await client.chatCompletion({
+						model: "meta-llama/Llama-3.2-3B-Instruct",
+						provider: "centml",
+						messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+					});
+					if (res.choices && res.choices.length > 0) {
+						const completion = res.choices[0].message?.content;
+						expect(completion).toContain("two");
+					}
+				});
+
+				it("chat completion with multiple messages", async () => {
+					const res = await client.chatCompletion({
+						model: "meta-llama/Llama-3.2-3B-Instruct",
+						provider: "centml",
+						messages: [
+							{ role: "system", content: "You are a helpful assistant." },
+							{ role: "user", content: "What is 2+2?" },
+							{ role: "assistant", content: "The answer is 4." },
+							{ role: "user", content: "What is 3+3?" },
+						],
+					});
+					if (res.choices && res.choices.length > 0) {
+						const completion = res.choices[0].message?.content;
+						expect(completion).toContain("6");
+					}
+				});
+
+				it("chat completion with parameters", async () => {
+					const res = await client.chatCompletion({
+						model: "meta-llama/Llama-3.2-3B-Instruct",
+						provider: "centml",
+						messages: [{ role: "user", content: "Write a short poem about AI" }],
+						temperature: 0.7,
+						max_tokens: 100,
+						top_p: 0.9,
+					});
+					if (res.choices && res.choices.length > 0 && res.choices[0].message?.content) {
+						const completion = res.choices[0].message.content;
+						expect(completion).toBeTruthy();
+						expect(completion.length).toBeGreaterThan(0);
+					}
+				});
+
+				it("chat completion stream", async () => {
+					const stream = client.chatCompletionStream({
+						model: "meta-llama/Llama-3.2-3B-Instruct",
+						provider: "centml",
+						messages: [{ role: "user", content: "Say 'this is a test'" }],
+						stream: true,
+					}) as AsyncGenerator<ChatCompletionStreamOutput>;
+
+					let fullResponse = "";
+					for await (const chunk of stream) {
+						if (chunk.choices && chunk.choices.length > 0) {
+							const content = chunk.choices[0].delta?.content;
+							if (content) {
+								fullResponse += content;
+							}
+						}
+					}
+
+					expect(fullResponse).toBeTruthy();
+					expect(fullResponse.length).toBeGreaterThan(0);
+				});
+			});
+		},
+		TIMEOUT
+	);
 });