Commit: python

julien-c committed Jan 24, 2025
1 parent a604bb7 commit 3646224
Showing 2 changed files with 96 additions and 55 deletions.
2 changes: 2 additions & 0 deletions packages/tasks/src/inference-providers.ts
@@ -6,6 +6,8 @@ export const HF_HUB_INFERENCE_PROXY_TEMPLATE = `https://huggingface.co/api/infer

/**
* URL to set as baseUrl in the OpenAI SDK.
*
* TODO(Expose this from HfInference in the future?)
*/
export function openAIbaseUrl(provider: InferenceProvider): string {
return provider === "hf-inference"
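
For illustration, a minimal sketch of how the returned URL could be plugged into the OpenAI JavaScript SDK. The @huggingface/tasks import path, the "together" provider id, and the use of a Hugging Face token as apiKey are assumptions for this sketch, not part of the commit:

import OpenAI from "openai";
import { openAIbaseUrl } from "@huggingface/tasks"; // assumed export path

const client = new OpenAI({
  baseURL: openAIbaseUrl("together"), // hypothetical provider id
  apiKey: process.env.HF_TOKEN, // assumption: the HF token doubles as the API key
});
const completion = await client.chat.completions.create({
  model: "meta-llama/Llama-3.1-8B-Instruct", // placeholder model id
  messages: [{ role: "user", content: "Hello!" }],
});
console.log(completion.choices[0].message);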
149 changes: 94 additions & 55 deletions packages/tasks/src/snippets/python.ts
@@ -7,12 +7,14 @@ import type { InferenceSnippet, ModelDataMinimal } from "./types.js";

const snippetImportInferenceClient = (model: ModelDataMinimal, accessToken: string): string =>
`from huggingface_hub import InferenceClient
client = InferenceClient("${model.id}", token="${accessToken || "{API_TOKEN}"}")
`;

export const snippetConversational = (
model: ModelDataMinimal,
accessToken: string,
provider: InferenceProvider,
opts?: {
streaming?: boolean;
messages?: ChatCompletionInputMessage[];
@@ -119,19 +121,25 @@ print(completion.choices[0].message)`,
}
};
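
A hedged usage sketch of the new four-argument signature (the import paths, model stub fields, token, and provider id are placeholders; the opts shape follows the signature above):

import { snippetConversational } from "@huggingface/tasks"; // assumed export path
import type { ModelDataMinimal } from "@huggingface/tasks";

const model = { id: "meta-llama/Llama-3.1-8B-Instruct", tags: ["conversational"], pipeline_tag: "text-generation" } as ModelDataMinimal; // minimal stub, fields assumed
const snippets = snippetConversational(model, "hf_xxx", "together", {
  streaming: true,
  messages: [{ role: "user", content: "What is the capital of France?" }],
});
for (const { client, content } of snippets) {
  console.log(`# client: ${client}\n${content}`);
}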

export const snippetZeroShotClassification = (model: ModelDataMinimal): InferenceSnippet => ({
content: `def query(payload):
export const snippetZeroShotClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
return [
{
content: `def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()
output = query({
"inputs": ${getModelInputSnippet(model)},
"parameters": {"candidate_labels": ["refund", "legal", "faq"]},
})`,
});
},
];
};

export const snippetZeroShotImageClassification = (model: ModelDataMinimal): InferenceSnippet => ({
content: `def query(data):
export const snippetZeroShotImageClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
return [
{
content: `def query(data):
with open(data["image_path"], "rb") as f:
img = f.read()
payload={
@@ -145,38 +153,53 @@ output = query({
"image_path": ${getModelInputSnippet(model)},
"parameters": {"candidate_labels": ["cat", "dog", "llama"]},
})`,
});
},
];
};

export const snippetBasic = (model: ModelDataMinimal): InferenceSnippet => ({
content: `def query(payload):
export const snippetBasic = (model: ModelDataMinimal): InferenceSnippet[] => {
return [
{
content: `def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()
output = query({
"inputs": ${getModelInputSnippet(model)},
})`,
});

export const snippetFile = (model: ModelDataMinimal): InferenceSnippet => ({
content: `def query(filename):
with open(filename, "rb") as f:
data = f.read()
response = requests.post(API_URL, headers=headers, data=data)
return response.json()
output = query(${getModelInputSnippet(model)})`,
});

export const snippetTextToImage = (model: ModelDataMinimal, accessToken: string): InferenceSnippet[] => [
{
client: "huggingface_hub",
content: `${snippetImportInferenceClient(model, accessToken)}
},
];
};

export const snippetFile = (model: ModelDataMinimal): InferenceSnippet[] => {
return [
{
content: `def query(filename):
with open(filename, "rb") as f:
data = f.read()
response = requests.post(API_URL, headers=headers, data=data)
return response.json()
output = query(${getModelInputSnippet(model)})`,
},
];
};

export const snippetTextToImage = (
model: ModelDataMinimal,
accessToken: string,
provider: InferenceProvider
): InferenceSnippet[] => {
return [
{
client: "huggingface_hub",
content: `${snippetImportInferenceClient(model, accessToken)}
# output is a PIL.Image object
image = client.text_to_image(${getModelInputSnippet(model)})`,
},
{
client: "requests",
content: `def query(payload):
},
{
client: "requests",
content: `def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.content
image_bytes = query({
@@ -187,25 +210,35 @@ image_bytes = query({
import io
from PIL import Image
image = Image.open(io.BytesIO(image_bytes))`,
},
];
},
];
};

export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet => ({
content: `def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.content
response = query({
"inputs": {"data": ${getModelInputSnippet(model)}},
})`,
});
export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet[] => {
return [
{
content: `def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.content
response = query({
"inputs": {"data": ${getModelInputSnippet(model)}},
})`,
},
];
};

export const snippetTextToAudio = (model: ModelDataMinimal): InferenceSnippet => {
export const snippetTextToAudio = (
model: ModelDataMinimal,
accessToken: string,
provider: InferenceProvider
): InferenceSnippet[] => {
// Transformers TTS pipeline and api-inference-community (AIC) pipeline outputs have diverged
// with the latest update to inference-api (IA).
// Transformers via IA returns a byte object (wav file), whereas AIC returns wav and sampling_rate.
if (model.library_name === "transformers") {
return {
content: `def query(payload):
return [
{
content: `def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.content
@@ -215,10 +248,12 @@ audio_bytes = query({
# You can access the audio with IPython.display for example
from IPython.display import Audio
Audio(audio_bytes)`,
};
},
];
} else {
return {
content: `def query(payload):
return [
{
content: `def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()
Expand All @@ -228,12 +263,15 @@ audio, sampling_rate = query({
# You can access the audio with IPython.display for example
from IPython.display import Audio
Audio(audio, rate=sampling_rate)`,
};
},
];
}
};
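
To make the library_name branch above concrete, a sketch of the two paths (model stubs and the non-transformers id are hypothetical):

// transformers model: the generated snippet reads the HTTP response as raw wav bytes
const bark = { id: "suno/bark", library_name: "transformers", tags: [], pipeline_tag: "text-to-speech" } as ModelDataMinimal;
// any other library: the generated snippet expects JSON carrying audio and sampling_rate
const espnet = { id: "espnet/some-tts-model", library_name: "espnet", tags: [], pipeline_tag: "text-to-speech" } as ModelDataMinimal; // hypothetical id

const barkSnippets = snippetTextToAudio(bark, "hf_xxx", "hf-inference");
const espnetSnippets = snippetTextToAudio(espnet, "hf_xxx", "hf-inference");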

export const snippetDocumentQuestionAnswering = (model: ModelDataMinimal): InferenceSnippet => ({
content: `def query(payload):
export const snippetDocumentQuestionAnswering = (model: ModelDataMinimal): InferenceSnippet[] => {
return [
{
content: `def query(payload):
with open(payload["image"], "rb") as f:
img = f.read()
payload["image"] = base64.b64encode(img).decode("utf-8")
Expand All @@ -243,16 +281,19 @@ export const snippetDocumentQuestionAnswering = (model: ModelDataMinimal): Infer
output = query({
"inputs": ${getModelInputSnippet(model)},
})`,
});
},
];
};

export const pythonSnippets: Partial<
Record<
PipelineType,
(
model: ModelDataMinimal,
accessToken: string,
provider: InferenceProvider,
opts?: Record<string, unknown>
) => InferenceSnippet | InferenceSnippet[]
) => InferenceSnippet[]
>
> = {
// Same order as in tasks/src/pipelines.ts
@@ -293,15 +334,13 @@ export function getPythonInferenceSnippet(
): InferenceSnippet[] {
if (model.tags.includes("conversational")) {
// Conversational model detected, so we display a code snippet that features the Messages API
return snippetConversational(model, accessToken, opts);
return snippetConversational(model, accessToken, provider, opts);
} else {
let snippets =
const snippets =
model.pipeline_tag && model.pipeline_tag in pythonSnippets
? pythonSnippets[model.pipeline_tag]?.(model, accessToken) ?? [{ content: "" }]
? pythonSnippets[model.pipeline_tag]?.(model, accessToken, provider) ?? [{ content: "" }]
: [{ content: "" }];

snippets = Array.isArray(snippets) ? snippets : [snippets];

return snippets.map((snippet) => {
return {
...snippet,
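
End to end, every task handler now returns InferenceSnippet[]; a sketch of a call site follows (export paths, model stub, token, and provider id are assumptions, with the argument order inferred from the function body):

import { getPythonInferenceSnippet } from "@huggingface/tasks"; // assumed export path
import type { ModelDataMinimal } from "@huggingface/tasks";

const model = { id: "stabilityai/stable-diffusion-xl-base-1.0", pipeline_tag: "text-to-image", tags: [] } as ModelDataMinimal; // minimal stub
const snippets = getPythonInferenceSnippet(model, "hf_xxx", "fal-ai"); // placeholder token; assumed provider id
for (const { client = "requests", content } of snippets) {
  console.log(`--- ${client} ---\n${content}`);
}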
