
Commit 322169a

Update models API
nicolay-r committed Nov 26, 2024
1 parent 1807f98 commit 322169a
Showing 6 changed files with 11 additions and 11 deletions.
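
All six model wrappers receive the same change: each constructor now accepts **kwargs and forwards them to BaseLM.__init__, so base-class options can be passed straight through a concrete wrapper. A minimal sketch of the pattern, assuming a BaseLM signature of which only the name argument is actually visible in this diff:

    class BaseLM:
        def __init__(self, name, **kwargs):
            # Assumed base signature: only `name` appears in the diff;
            # how BaseLM consumes the extra kwargs is not shown here.
            self.name = name
            self._extra = kwargs

    class FlanT5(BaseLM):
        def __init__(self, model_name="google/flan-t5-base", **kwargs):
            # After this commit the wrapper forwards unrecognized kwargs upward.
            super(FlanT5, self).__init__(name=model_name, **kwargs)

    # Hypothetical usage: `some_base_option` is an illustration, not a parameter
    # defined anywhere in this repository.
    model = FlanT5(model_name="google/flan-t5-base", some_base_option=True)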
ext/flan_t5.py (2 changes: 1 addition & 1 deletion)
@@ -7,7 +7,7 @@
 class FlanT5(BaseLM):
 
     def __init__(self, model_name="google/flan-t5-base", temp=0.1, device='cuda', max_length=None, use_bf16=False, **kwargs):
-        super(FlanT5, self).__init__(name=model_name)
+        super(FlanT5, self).__init__(name=model_name, **kwargs)
         self.__device = device
         self.__max_length = 512 if max_length is None else max_length
         self.__model = T5ForConditionalGeneration.from_pretrained(
ext/gemma.py (4 changes: 2 additions & 2 deletions)
@@ -7,8 +7,8 @@
 class Gemma(BaseLM):
 
     def __init__(self, model_name="google/gemma-7b-it", temp=0.1, device='cuda',
-                 max_length=None, api_token=None, use_bf16=False):
-        super(Gemma, self).__init__(name=model_name)
+                 max_length=None, api_token=None, use_bf16=False, **kwargs):
+        super(Gemma, self).__init__(name=model_name, **kwargs)
         self.__device = device
         self.__max_length = 1024 if max_length is None else max_length
         self.__model = AutoModelForCausalLM.from_pretrained(
ext/llama32.py (4 changes: 2 additions & 2 deletions)
@@ -7,8 +7,8 @@
 class Llama32(BaseLM):
 
     def __init__(self, model_name="meta-llama/Llama-3.2-3B-Instruct", api_token=None,
-                 temp=0.1, device='cuda', max_length=256, use_bf16=False):
-        super(Llama32, self).__init__(name=model_name)
+                 temp=0.1, device='cuda', max_length=256, use_bf16=False, **kwargs):
+        super(Llama32, self).__init__(name=model_name, **kwargs)
 
         if use_bf16:
             print("Warning: Experimental mode with bf-16!")
ext/microsoft_phi_2.py (4 changes: 2 additions & 2 deletions)
@@ -11,8 +11,8 @@ class MicrosoftPhi2(BaseLM):
     """ https://huggingface.co/microsoft/phi-2
     """
 
-    def __init__(self, model_name="microsoft/phi-2", device='cuda', max_length=None, use_bf16=False):
-        super(MicrosoftPhi2, self).__init__(model_name)
+    def __init__(self, model_name="microsoft/phi-2", device='cuda', max_length=None, use_bf16=False, **kwargs):
+        super(MicrosoftPhi2, self).__init__(model_name, **kwargs)
 
         # Default parameters.
         kwargs = {"device_map": device, "trust_remote_code": True}
ext/mistral.py (4 changes: 2 additions & 2 deletions)
@@ -7,9 +7,9 @@
 class Mistral(BaseLM):
 
     def __init__(self, model_name="mistralai/Mistral-7B-Instruct-v0.1", temp=0.1, device='cuda', max_length=None,
-                 use_bf16=False):
+                 use_bf16=False, **kwargs):
         assert(isinstance(max_length, int) or max_length is None)
-        super(Mistral, self).__init__(name=model_name)
+        super(Mistral, self).__init__(name=model_name, **kwargs)
 
         if use_bf16:
             print("Warning: Experimental mode with bf-16!")
ext/openai_gpt.py (4 changes: 2 additions & 2 deletions)
@@ -5,9 +5,9 @@
 class OpenAIGPT(BaseLM):
 
     def __init__(self, api_key, model_name="gpt-4-1106-preview", temp=0.1, max_tokens=None, assistant_prompt=None,
-                 freq_penalty=0.0, kwargs=None):
+                 freq_penalty=0.0, **kwargs):
         assert(isinstance(assistant_prompt, str) or assistant_prompt is None)
-        super(OpenAIGPT, self).__init__(name=model_name)
+        super(OpenAIGPT, self).__init__(name=model_name, **kwargs)
 
         # dynamic import of the OpenAI library.
         OpenAI = auto_import("openai._client.OpenAI", is_class=False)
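
The OpenAIGPT change differs slightly from the others: the explicit kwargs=None dictionary parameter is replaced by **kwargs, which is then forwarded to BaseLM.__init__. A hedged before/after usage sketch; the option name and API key below are placeholders, not parameters defined by this repository, and how the old dict was consumed is not visible in this hunk:

    # Before this commit: extra options went in as a dict via the `kwargs` parameter.
    gpt = OpenAIGPT(api_key="sk-...", kwargs={"some_option": 1})

    # After this commit: the same options become ordinary keyword arguments,
    # forwarded to BaseLM.__init__ through **kwargs.
    gpt = OpenAIGPT(api_key="sk-...", some_option=1)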
