Skip to content

Commit

Permalink
Non-oai usage meta + non-oai client types (#182)
Browse files Browse the repository at this point in the history
  • Loading branch information
roodboi authored Jun 13, 2024
1 parent f386ad7 commit 0a5bbd8
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 27 deletions.
5 changes: 5 additions & 0 deletions .changeset/light-chefs-clean.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@instructor-ai/instructor": minor
---

update client types to better support non oai clients + updates to allow for passing usage properties into meta from non-oai clients
Binary file modified bun.lockb
Binary file not shown.
6 changes: 3 additions & 3 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
"zod": ">=3.22.4"
},
"devDependencies": {
"@anthropic-ai/sdk": "latest",
"@anthropic-ai/sdk": "0.22.0",
"@changesets/changelog-github": "^0.5.0",
"@changesets/cli": "^2.27.1",
"@ianvs/prettier-plugin-sort-imports": "4.1.0",
Expand All @@ -75,8 +75,8 @@
"eslint-plugin-only-warn": "^1.1.0",
"eslint-plugin-prettier": "^5.1.2",
"husky": "^8.0.3",
"llm-polyglot": "1.0.0",
"openai": "latest",
"llm-polyglot": "2.0.0",
"openai": "4.50.0",
"prettier": "latest",
"ts-inference-check": "^0.3.0",
"tsup": "^8.0.1",
Expand Down
57 changes: 50 additions & 7 deletions src/instructor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ import {
PROVIDER_SUPPORTED_MODES_BY_MODEL,
PROVIDERS
} from "./constants/providers"
import { iterableTee } from "./lib"
import { ClientTypeChatCompletionParams, CompletionMeta } from "./types"

const MAX_RETRIES_DEFAULT = 0

class Instructor<C extends GenericClient | OpenAI> {
class Instructor<C> {
readonly client: OpenAILikeClient<C>
readonly mode: Mode
readonly provider: Provider
Expand All @@ -46,7 +47,17 @@ class Instructor<C extends GenericClient | OpenAI> {
logger = undefined,
retryAllErrors = false
}: InstructorConfig<C>) {
this.client = client
if (!isGenericClient(client) && !(client instanceof OpenAI)) {
throw new Error("Client does not match the required structure")
}

if (client instanceof OpenAI) {
this.client = client as OpenAI
} else {
this.client = client as C & GenericClient
}

// this.client = client
this.mode = mode
this.debug = debug
this.retryAllErrors = retryAllErrors
Expand Down Expand Up @@ -308,7 +319,9 @@ class Instructor<C extends GenericClient | OpenAI> {
debug: this.debug ?? false
})

async function checkForUsage(reader: Stream<OpenAI.ChatCompletionChunk>) {
async function checkForUsage(
reader: Stream<OpenAI.ChatCompletionChunk> | AsyncIterable<OpenAI.ChatCompletionChunk>
) {
for await (const chunk of reader) {
if ("usage" in chunk) {
streamUsage = chunk.usage as CompletionMeta["usage"]
Expand Down Expand Up @@ -345,6 +358,24 @@ class Instructor<C extends GenericClient | OpenAI> {
})
}

//check if async iterator
if (
this.provider !== "OAI" &&
completionParams?.stream &&
completion?.[Symbol.asyncIterator]
) {
const [completion1, completion2] = await iterableTee(
completion as AsyncIterable<OpenAI.ChatCompletionChunk>,
2
)

checkForUsage(completion1)

return OAIStream({
res: completion2
})
}

return OAIStream({
res: completion as unknown as AsyncIterable<OpenAI.ChatCompletionChunk>
})
Expand Down Expand Up @@ -419,7 +450,7 @@ class Instructor<C extends GenericClient | OpenAI> {
}
}

export type InstructorClient<C extends GenericClient | OpenAI> = Instructor<C> & OpenAILikeClient<C>
export type InstructorClient<C> = Instructor<C> & OpenAILikeClient<C>

/**
* Creates an instance of the `Instructor` class.
Expand All @@ -442,9 +473,7 @@ export type InstructorClient<C extends GenericClient | OpenAI> = Instructor<C> &
* @param args
* @returns
*/
export default function createInstructor<C extends GenericClient | OpenAI>(
args: InstructorConfig<C>
): InstructorClient<C> {
export default function createInstructor<C>(args: InstructorConfig<C>): InstructorClient<C> {
const instructor = new Instructor<C>(args)
const instructorWithProxy = new Proxy(instructor, {
get: (target, prop, receiver) => {
Expand All @@ -458,3 +487,17 @@ export default function createInstructor<C extends GenericClient | OpenAI>(

return instructorWithProxy as InstructorClient<C>
}
//eslint-disable-next-line @typescript-eslint/no-explicit-any
function isGenericClient(client: any): client is GenericClient {
return (
typeof client === "object" &&
client !== null &&
"baseURL" in client &&
"chat" in client &&
typeof client.chat === "object" &&
"completions" in client.chat &&
typeof client.chat.completions === "object" &&
"create" in client.chat.completions &&
typeof client.chat.completions.create === "function"
)
}
42 changes: 42 additions & 0 deletions src/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,45 @@ export function omit<T extends object, K extends keyof T>(keys: K[], obj: T): Om
}
return result
}

export async function iterableTee<T>(
iterable: AsyncIterable<T>,
n: number
): Promise<AsyncGenerator<T>[]> {
const buffers: T[][] = Array.from({ length: n }, () => [])
const resolvers: (() => void)[] = []
const iterator = iterable[Symbol.asyncIterator]()
let done = false

async function* reader(index: number) {
while (true) {
if (buffers[index].length > 0) {
yield buffers[index].shift()!
} else if (done) {
break
} else {
await new Promise<void>(resolve => resolvers.push(resolve))
}
}
}

;(async () => {
for await (const item of {
[Symbol.asyncIterator]: () => iterator
}) {
for (const buffer of buffers) {
buffer.push(item)
}

while (resolvers.length > 0) {
resolvers.shift()!()
}
}
done = true
while (resolvers.length > 0) {
resolvers.shift()!()
}
})()

return Array.from({ length: n }, (_, i) => reader(i))
}
6 changes: 3 additions & 3 deletions src/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ export type GenericClient = {
baseURL?: string
chat?: {
completions?: {
create?: (params: GenericCreateParams) => Promise<unknown>
create?: <P extends GenericCreateParams>(params: P) => Promise<unknown>
}
}
}
Expand All @@ -55,7 +55,7 @@ export type ClientType<C> =
: C extends GenericClient ? "generic"
: never

export type OpenAILikeClient<C> = C extends OpenAI ? OpenAI : C & GenericClient
export type OpenAILikeClient<C> = OpenAI | (C & GenericClient)
export type SupportedInstructorClient = GenericClient | OpenAI
export type LogLevel = "debug" | "info" | "warn" | "error"

Expand All @@ -68,7 +68,7 @@ export type Mode = ZMode
export type ResponseModel<T extends z.AnyZodObject> = ZResponseModel<T>

export interface InstructorConfig<C> {
client: OpenAILikeClient<C>
client: C
mode: Mode
debug?: boolean
logger?: <T extends unknown[]>(level: LogLevel, ...args: T) => void
Expand Down
33 changes: 20 additions & 13 deletions tests/anthropic.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,16 @@ describe("LLMClient Anthropic Provider - mode: TOOLS", () => {
})
})

describe("LLMClient Anthropic Provider - mode: MD_JSON", () => {
describe("LLMClient Anthropic Provider - mode: TOOLS - stream", () => {
const instructor = Instructor({
client: anthropicClient,
mode: "MD_JSON"
mode: "TOOLS"
})

test("basic completion", async () => {
const completion = await instructor.chat.completions.create({
model: "claude-3-sonnet-20240229",
stream: true,
max_tokens: 1000,
messages: [
{
Expand All @@ -135,17 +136,24 @@ describe("LLMClient Anthropic Provider - mode: MD_JSON", () => {
}
],
response_model: {
name: "get_name",
name: "extract_name",
schema: z.object({
name: z.string()
})
}
})

expect(omit(["_meta"], completion)).toEqual({ name: "Dimitri Kennedy" })
let final = {}

for await (const result of completion) {
final = result
}

//@ts-expect-error ignore for testing
expect(omit(["_meta"], final)).toEqual({ name: "Dimitri Kennedy" })
})

test("complex schema - streaming", async () => {
test("complex schema", async () => {
const completion = await instructor.chat.completions.create({
model: "claude-3-sonnet-20240229",
max_tokens: 1000,
Expand Down Expand Up @@ -173,14 +181,15 @@ describe("LLMClient Anthropic Provider - mode: MD_JSON", () => {
Programming
Leadership
Communication
`
}
],
response_model: {
name: "process_user_data",
schema: z.object({
story: z
.string()
.describe("A long and mostly made up story about the user - minimum 500 words"),
userDetails: z.object({
firstName: z.string(),
lastName: z.string(),
Expand All @@ -196,21 +205,19 @@ describe("LLMClient Anthropic Provider - mode: MD_JSON", () => {
years: z.number().optional()
})
),
skills: z.array(z.string()),
summaryOfWorldWarOne: z
.string()
.describe("A detailed summary of World War One and its major events - min 500 words")
skills: z.array(z.string())
})
}
})

let final = {}

for await (const result of completion) {
final = result
}

//@ts-expect-error - lazy
expect(omit(["_meta", "summaryOfWorldWarOne"], final)).toEqual({
//@ts-expect-error ignore for testing
expect(omit(["_meta", "story"], final)).toEqual({
userDetails: {
firstName: "John",
lastName: "Doe",
Expand Down
1 change: 0 additions & 1 deletion tests/stream.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ async function extractUser() {
let extraction: Extraction = {}

for await (const result of extractionStream) {
console.log(result)
try {
extraction = result
expect(result).toHaveProperty("users")
Expand Down

0 comments on commit 0a5bbd8

Please sign in to comment.