-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathRunChatModel.py
59 lines (56 loc) · 1.65 KB
/
RunChatModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import openai

# Major version of the installed ``openai`` SDK, taken from the text before
# the first "." of ``openai.__version__`` (e.g. "0.28.1" -> 0, "1.3.5" -> 1).
# The two SDK generations expose the exception classes and the response base
# type from different locations, so the same names are bound per generation.
OPENAI_VER_MAJ = int(openai.__version__.partition(".")[0])

if OPENAI_VER_MAJ < 1:
    # Legacy SDK: exceptions live in ``openai.error`` and responses are
    # ``OpenAIObject`` instances.
    from openai.error import APIError, AuthenticationError, APIConnectionError
    from openai.openai_object import OpenAIObject as CompletionObject
else:
    # openai>=1: exceptions are exported from the package top level and
    # responses are pydantic models (hence ``model_dump`` elsewhere).
    from openai import APIError, AuthenticationError, APIConnectionError
    from pydantic import BaseModel as CompletionObject
def run_chat_completion(
    model_name,
    messages,
    token,
    endpoint,
    max_tokens=300,
    n=1,
    stream=False,
    stop=None,
    temperature=0.0,
    top_p=1.0,
    frequency_penalty=0,
    presence_penalty=0,
):
    """Send a chat-completion request to an OpenAI-compatible server.

    Works with both the legacy (<1.0) and the current (>=1.0) ``openai``
    SDK; the generation is chosen by the module-level ``OPENAI_VER_MAJ``.

    Args:
        model_name: Model identifier, sent as ``model``.
        messages: Chat messages, forwarded unchanged.
        token: API key used to authenticate the request.
        endpoint: Server root URL. ``"/v1"`` is appended to it; a trailing
            ``"/"`` on ``endpoint`` is dropped first so the path is not
            doubled.
        max_tokens, n, stream, stop, temperature, top_p, frequency_penalty,
        presence_penalty: Forwarded unchanged to the chat-completions call.

    Returns:
        * openai>=1, ``stream`` false: the response as a ``dict``
          (``model_dump(exclude_unset=True)``).
        * openai>=1, ``stream`` true: a generator yielding each streamed
          chunk as such a ``dict``.
        * openai<1: whatever ``openai.ChatCompletion.create`` returns.

    Side effects:
        Sets ``openai.api_key`` and the module-level base URL
        (``openai.base_url`` for >=1, ``openai.api_base`` for <1).
    """
    base_url = endpoint.rstrip("/") + "/v1"
    # Identical keyword arguments are sent by both SDK generations.
    request = {
        "model": model_name,
        "messages": messages,
        "max_tokens": max_tokens,
        "stream": stream,
        "n": n,
        "stop": stop,
        "top_p": top_p,
        "temperature": temperature,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
    }

    openai.api_key = token
    if OPENAI_VER_MAJ >= 1:
        # Kept for callers relying on the module-level client configuration.
        openai.base_url = base_url
        # ``openai.base_url`` only configures the SDK's implicit module-level
        # client; an explicitly constructed ``OpenAI`` client ignores it, so
        # the endpoint must be passed here or requests go to the default
        # OpenAI API instead of ``endpoint``.
        client = openai.OpenAI(api_key=token, base_url=base_url)
        completion = client.chat.completions.create(**request)
        if stream:
            # A streamed response is an iterator of chunk models and has no
            # ``model_dump`` itself; convert chunk by chunk so streamed output
            # uses the same dict shape as non-streamed output.
            return (
                chunk.model_dump(exclude_unset=True) for chunk in completion
            )
        return completion.model_dump(exclude_unset=True)

    # Legacy SDK: configuration is read from module-level attributes.
    openai.api_base = base_url
    return openai.ChatCompletion.create(**request)