Skip to content

Commit

Permalink
allow request option pass through (#164)
Browse files Browse the repository at this point in the history
  • Loading branch information
roodboi authored Apr 20, 2024
1 parent 224a0d1 commit 6942d65
Show file tree
Hide file tree
Showing 22 changed files with 268 additions and 96 deletions.
5 changes: 5 additions & 0 deletions .changeset/curly-ants-tie.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@instructor-ai/instructor": minor
---

Adds request-option pass-through; also improves handling of non-validation errors and no longer retries when the error is not specifically a validation error.
2 changes: 1 addition & 1 deletion docs/concepts/streaming.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ A follow-up meeting is scheduled for January 25th at 3 PM GMT to finalize the ag

const extractionStream = await client.chat.completions.create({
messages: [{ role: "user", content: textBlock }],
model: "gpt-4-1106-preview",
model: "gpt-4-turbo",
response_model: {
schema: ExtractionValuesSchema,
name: "value extraction"
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/action_items.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ const extractActionItems = async (data: string): Promise<ActionItems | undefined
"content": `Create the action items for the following transcript: ${data}`,
},
],
model: "gpt-4-1106-preview",
model: "gpt-4-turbo",
response_model: { schema: ActionItemsSchema },
max_tokens: 1000,
temperature: 0.0,
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/query_decomposition.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ const createQueryPlan = async (question: string): Promise<QueryPlan | undefined>
"content": `Consider: ${question}\nGenerate the correct query plan.`,
},
],
model: "gpt-4-1106-preview",
model: "gpt-4-turbo",
response_model: { schema: QueryPlanSchema },
max_tokens: 1000,
temperature: 0.0,
Expand Down
2 changes: 1 addition & 1 deletion examples/action_items/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ const extractActionItems = async (data: string) => {
content: `Create the action items for the following transcript: ${data}`
}
],
model: "gpt-4-1106-preview",
model: "gpt-4-turbo",
response_model: { schema: ActionItemsSchema, name: "ActionItems" },
max_tokens: 1000,
temperature: 0.0,
Expand Down
2 changes: 1 addition & 1 deletion examples/extract_user_stream/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ let extraction = {}

const extractionStream = await client.chat.completions.create({
messages: [{ role: "user", content: textBlock }],
model: "gpt-4-1106-preview",
model: "gpt-4-turbo",
response_model: {
schema: ExtractionValuesSchema,
name: "value extraction"
Expand Down
7 changes: 3 additions & 4 deletions examples/llm-validator/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ const openAi = new OpenAI({ apiKey: process.env.OPENAI_API_KEY ?? "" })

const instructor = Instructor({
client: openAi,
mode: "TOOLS",
debug: true
mode: "TOOLS"
})

const statement = "Do not say questionable things"
Expand All @@ -17,7 +16,7 @@ const QuestionAnswer = z.object({
question: z.string(),
answer: z.string().superRefine(
LLMValidator(instructor, statement, {
model: "gpt-4"
model: "gpt-4-turbo"
})
)
})
Expand All @@ -26,7 +25,7 @@ const question = "What is the meaning of life?"

const check = async (context: string) => {
return await instructor.chat.completions.create({
model: "gpt-3.5-turbo",
model: "gpt-4-turbo",
max_retries: 2,
response_model: { schema: QuestionAnswer, name: "Question and Answer" },
messages: [
Expand Down
2 changes: 1 addition & 1 deletion examples/query_decomposition/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ const createQueryPlan = async (question: string) => {
content: `Consider: ${question}\nGenerate the correct query plan.`
}
],
model: "gpt-4-1106-preview",
model: "gpt-4-turbo",
response_model: { schema: QueryPlanSchema, name: "Query Plan Decomposition" },
max_tokens: 1000,
temperature: 0.0,
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@instructor-ai/instructor",
"version": "1.1.2",
"version": "1.1.1",
"description": "structured outputs for llms",
"publishConfig": {
"access": "public"
Expand Down
26 changes: 17 additions & 9 deletions src/constants/providers.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { omit } from "@/lib"
import OpenAI from "openai"
import { z } from "zod"
import { MODE, withResponseModel, type Mode } from "zod-stream"
import { withResponseModel, MODE as ZMODE, type Mode } from "zod-stream"

export const MODE = ZMODE
export const PROVIDERS = {
OAI: "OAI",
ANYSCALE: "ANYSCALE",
Expand All @@ -11,7 +12,6 @@ export const PROVIDERS = {
GROQ: "GROQ",
OTHER: "OTHER"
} as const

export type Provider = keyof typeof PROVIDERS

export const PROVIDER_SUPPORTED_MODES: {
Expand All @@ -34,6 +34,19 @@ export const NON_OAI_PROVIDER_URLS = {
} as const

export const PROVIDER_PARAMS_TRANSFORMERS = {
[PROVIDERS.GROQ]: {
[MODE.TOOLS]: function groqToolsParamsTransformer<
T extends z.AnyZodObject,
P extends OpenAI.ChatCompletionCreateParams
>(params: ReturnType<typeof withResponseModel<T, "TOOLS", P>>) {
if (params.tools.some(tool => tool) && params.stream) {
console.warn("Streaming may not be supported when using tools in Groq, try MD_JSON instead")
return params
}

return params
}
},
[PROVIDERS.ANYSCALE]: {
[MODE.JSON_SCHEMA]: function removeAdditionalPropertiesKeyJSONSchema<
T extends z.AnyZodObject,
Expand Down Expand Up @@ -90,12 +103,7 @@ export const PROVIDER_SUPPORTED_MODES_BY_MODEL = {
[PROVIDERS.OAI]: {
[MODE.FUNCTIONS]: ["*"],
[MODE.TOOLS]: ["*"],
[MODE.JSON]: [
"gpt-3.5-turbo-1106",
"gpt-4-1106-preview",
"gpt-4-0125-preview",
"gpt-4-turbo-preview"
],
[MODE.JSON]: ["gpt-3.5-turbo-1106", "gpt-4-turbo", "gpt-4-0125-preview", "gpt-4-turbo-preview"],
[MODE.MD_JSON]: ["*"]
},
[PROVIDERS.TOGETHER]: {
Expand Down Expand Up @@ -124,7 +132,7 @@ export const PROVIDER_SUPPORTED_MODES_BY_MODEL = {
[MODE.TOOLS]: ["*"]
},
[PROVIDERS.GROQ]: {
[MODE.TOOLS]: ["llama2-70b-4096", "mixtral-8x7b-32768", "gemma-7b-it"],
[MODE.TOOLS]: ["mixtral-8x7b-32768", "gemma-7b-it"],
[MODE.MD_JSON]: ["*"]
}
}
8 changes: 1 addition & 7 deletions src/dsl/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,9 @@ export const LLMValidator = <C extends GenericClient | OpenAI>(
}
}

export const moderationValidator = <C extends GenericClient | OpenAI>(
client: InstructorClient<C>
) => {
export const moderationValidator = (client: InstructorClient<OpenAI>) => {
return async (value: string, ctx: z.RefinementCtx) => {
try {
if (!(client instanceof OpenAI)) {
throw new Error("ModerationValidator only supports OpenAI clients")
}

const response = await client.moderations.create({ input: value })
const flaggedResults = response.results.filter(result => result.flagged)

Expand Down
89 changes: 60 additions & 29 deletions src/instructor.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import {
ChatCompletionCreateParamsWithModel,
ClientTypeChatCompletionRequestOptions,
GenericChatCompletion,
GenericClient,
InstructorConfig,
Expand All @@ -8,7 +9,7 @@ import {
ReturnTypeBasedOnParams
} from "@/types"
import OpenAI from "openai"
import { z } from "zod"
import { z, ZodError } from "zod"
import ZodStream, { OAIResponseParser, OAIStream, withResponseModel, type Mode } from "zod-stream"
import { fromZodError } from "zod-validation-error"

Expand Down Expand Up @@ -102,11 +103,14 @@ class Instructor<C extends GenericClient | OpenAI> {
}
}

private async chatCompletionStandard<T extends z.AnyZodObject>({
max_retries = MAX_RETRIES_DEFAULT,
response_model,
...params
}: ChatCompletionCreateParamsWithModel<T>): Promise<z.infer<T>> {
private async chatCompletionStandard<T extends z.AnyZodObject>(
{
max_retries = MAX_RETRIES_DEFAULT,
response_model,
...params
}: ChatCompletionCreateParamsWithModel<T>,
requestOptions?: ClientTypeChatCompletionRequestOptions<C>
): Promise<z.infer<T>> {
let attempts = 0
let validationIssues = ""
let lastMessage: OpenAI.ChatCompletionMessageParam | null = null
Expand Down Expand Up @@ -147,13 +151,17 @@ class Instructor<C extends GenericClient | OpenAI> {

try {
if (this.client.chat?.completions?.create) {
const result = await this.client.chat.completions.create({
...resolvedParams,
stream: false
})
const result = await this.client.chat.completions.create(
{
...resolvedParams,
stream: false
},
requestOptions
)

completion = result as GenericChatCompletion<typeof result>
} else {
throw new Error("Unsupported client type")
throw new Error("Unsupported client type -- no completion method found.")
}
this.log("debug", "raw standard completion response: ", completion)
} catch (error) {
Expand All @@ -176,7 +184,17 @@ class Instructor<C extends GenericClient | OpenAI> {
const data = JSON.parse(parsedCompletion) as z.infer<T> & { _meta?: CompletionMeta }
return { ...data, _meta: { usage: completion?.usage ?? undefined } }
} catch (error) {
this.log("error", "failed to parse completion", parsedCompletion, this.mode)
this.log(
"error",
"failed to parse completion",
parsedCompletion,
this.mode,
"attempt: ",
attempts,
"max attempts: ",
max_retries
)

throw error
}
}
Expand All @@ -202,26 +220,38 @@ class Instructor<C extends GenericClient | OpenAI> {
throw new Error("Validation failed.")
}
}

return validation.data
} catch (error) {
if (!(error instanceof ZodError)) {
throw error
}

if (attempts < max_retries) {
this.log(
"debug",
`response model: ${response_model.name} - Retrying, attempt: `,
attempts
)

this.log(
"warn",
`response model: ${response_model.name} - Validation issues: `,
validationIssues
validationIssues,
" - Attempt: ",
attempts,
" - Max attempts: ",
max_retries
)

attempts++
return await makeCompletionCallWithRetries()
} else {
this.log(
"debug",
`response model: ${response_model.name} - Max attempts reached: ${attempts}`
)

this.log(
"error",
`response model: ${response_model.name} - Validation issues: `,
Expand All @@ -236,13 +266,10 @@ class Instructor<C extends GenericClient | OpenAI> {
return makeCompletionCallWithRetries()
}

private async chatCompletionStream<T extends z.AnyZodObject>({
max_retries,
response_model,
...params
}: ChatCompletionCreateParamsWithModel<T>): Promise<
AsyncGenerator<Partial<T> & { _meta?: CompletionMeta }, void, unknown>
> {
private async chatCompletionStream<T extends z.AnyZodObject>(
{ max_retries, response_model, ...params }: ChatCompletionCreateParamsWithModel<T>,
requestOptions?: ClientTypeChatCompletionRequestOptions<C>
): Promise<AsyncGenerator<Partial<T> & { _meta?: CompletionMeta }, void, unknown>> {
if (max_retries) {
this.log("warn", "max_retries is not supported for streaming completions")
}
Expand All @@ -269,10 +296,13 @@ class Instructor<C extends GenericClient | OpenAI> {
return streamClient.create({
completionPromise: async () => {
if (this.client.chat?.completions?.create) {
const completion = await this.client.chat.completions.create({
...completionParams,
stream: true
})
const completion = await this.client.chat.completions.create(
{
...completionParams,
stream: true
},
requestOptions
)

this.log("debug", "raw stream completion response: ", completion)

Expand Down Expand Up @@ -306,18 +336,19 @@ class Instructor<C extends GenericClient | OpenAI> {
P extends T extends z.AnyZodObject ? ChatCompletionCreateParamsWithModel<T>
: ClientTypeChatCompletionParams<OpenAILikeClient<C>> & { response_model: never }
>(
params: P
params: P,
requestOptions?: ClientTypeChatCompletionRequestOptions<C>
): Promise<ReturnTypeBasedOnParams<typeof this.client, P>> => {
this.validateModelModeSupport(params)

if (this.isChatCompletionCreateParamsWithModel(params)) {
if (params.stream) {
return this.chatCompletionStream(params) as ReturnTypeBasedOnParams<
return this.chatCompletionStream(params, requestOptions) as ReturnTypeBasedOnParams<
typeof this.client,
P & { stream: true }
>
} else {
return this.chatCompletionStandard(params) as ReturnTypeBasedOnParams<
return this.chatCompletionStandard(params, requestOptions) as ReturnTypeBasedOnParams<
typeof this.client,
P
>
Expand All @@ -326,8 +357,8 @@ class Instructor<C extends GenericClient | OpenAI> {
if (this.client.chat?.completions?.create) {
const result =
this.isStandardStream(params) ?
await this.client.chat.completions.create(params)
: await this.client.chat.completions.create(params)
await this.client.chat.completions.create(params, requestOptions)
: await this.client.chat.completions.create(params, requestOptions)

return result as unknown as ReturnTypeBasedOnParams<OpenAILikeClient<C>, P>
} else {
Expand Down
9 changes: 7 additions & 2 deletions src/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ export type GenericCreateParams<M = unknown> = Omit<
[key: string]: unknown
}

export type GenericRequestOptions = Partial<OpenAI.RequestOptions> & {
[key: string]: unknown
}

export type GenericChatCompletion<T = unknown> = Partial<OpenAI.Chat.Completions.ChatCompletion> & {
[key: string]: unknown
choices?: T
Expand All @@ -43,15 +47,16 @@ export type GenericClient = {
export type ClientTypeChatCompletionParams<C> =
C extends OpenAI ? OpenAI.ChatCompletionCreateParams : GenericCreateParams

export type ClientTypeChatCompletionRequestOptions<C> =
C extends OpenAI ? OpenAI.RequestOptions : GenericRequestOptions

export type ClientType<C> =
C extends OpenAI ? "openai"
: C extends GenericClient ? "generic"
: never

export type OpenAILikeClient<C> = C extends OpenAI ? OpenAI : C & GenericClient

export type SupportedInstructorClient = GenericClient | OpenAI

export type LogLevel = "debug" | "info" | "warn" | "error"

export type CompletionMeta = Partial<ZCompletionMeta> & {
Expand Down
Loading

0 comments on commit 6942d65

Please sign in to comment.