diff --git a/ext/flan_t5.py b/ext/flan_t5.py
index 8025c34..353e41a 100644
--- a/ext/flan_t5.py
+++ b/ext/flan_t5.py
@@ -7,7 +7,7 @@ class FlanT5(BaseLM):
 
     def __init__(self, model_name="google/flan-t5-base", temp=0.1, device='cuda', max_length=None,
                  use_bf16=False, **kwargs):
-        super(FlanT5, self).__init__(name=model_name)
+        super(FlanT5, self).__init__(name=model_name, **kwargs)
         self.__device = device
         self.__max_length = 512 if max_length is None else max_length
         self.__model = T5ForConditionalGeneration.from_pretrained(
diff --git a/ext/gemma.py b/ext/gemma.py
index 5a32ddc..81d8418 100644
--- a/ext/gemma.py
+++ b/ext/gemma.py
@@ -7,8 +7,8 @@ class Gemma(BaseLM):
 
     def __init__(self, model_name="google/gemma-7b-it", temp=0.1, device='cuda',
-                 max_length=None, api_token=None, use_bf16=False):
-        super(Gemma, self).__init__(name=model_name)
+                 max_length=None, api_token=None, use_bf16=False, **kwargs):
+        super(Gemma, self).__init__(name=model_name, **kwargs)
         self.__device = device
         self.__max_length = 1024 if max_length is None else max_length
         self.__model = AutoModelForCausalLM.from_pretrained(
diff --git a/ext/llama32.py b/ext/llama32.py
index d9b099d..b94cf50 100644
--- a/ext/llama32.py
+++ b/ext/llama32.py
@@ -7,8 +7,8 @@ class Llama32(BaseLM):
 
     def __init__(self, model_name="meta-llama/Llama-3.2-3B-Instruct", api_token=None,
-                 temp=0.1, device='cuda', max_length=256, use_bf16=False):
-        super(Llama32, self).__init__(name=model_name)
+                 temp=0.1, device='cuda', max_length=256, use_bf16=False, **kwargs):
+        super(Llama32, self).__init__(name=model_name, **kwargs)
 
         if use_bf16:
             print("Warning: Experimental mode with bf-16!")
diff --git a/ext/microsoft_phi_2.py b/ext/microsoft_phi_2.py
index 3d5aae1..e66b849 100644
--- a/ext/microsoft_phi_2.py
+++ b/ext/microsoft_phi_2.py
@@ -11,8 +11,8 @@ class MicrosoftPhi2(BaseLM):
     """ https://huggingface.co/microsoft/phi-2
     """
 
-    def __init__(self, model_name="microsoft/phi-2", device='cuda', max_length=None, use_bf16=False):
-        super(MicrosoftPhi2, self).__init__(model_name)
+    def __init__(self, model_name="microsoft/phi-2", device='cuda', max_length=None, use_bf16=False, **kwargs):
+        super(MicrosoftPhi2, self).__init__(model_name, **kwargs)
 
         # Default parameters.
         kwargs = {"device_map": device, "trust_remote_code": True}
diff --git a/ext/mistral.py b/ext/mistral.py
index 20fe828..04d8ad9 100644
--- a/ext/mistral.py
+++ b/ext/mistral.py
@@ -7,9 +7,9 @@ class Mistral(BaseLM):
 
     def __init__(self, model_name="mistralai/Mistral-7B-Instruct-v0.1",
                  temp=0.1, device='cuda', max_length=None,
-                 use_bf16=False):
+                 use_bf16=False, **kwargs):
         assert(isinstance(max_length, int) or max_length is None)
-        super(Mistral, self).__init__(name=model_name)
+        super(Mistral, self).__init__(name=model_name, **kwargs)
 
         if use_bf16:
             print("Warning: Experimental mode with bf-16!")
diff --git a/ext/openai_gpt.py b/ext/openai_gpt.py
index f95d7a1..67ff413 100644
--- a/ext/openai_gpt.py
+++ b/ext/openai_gpt.py
@@ -5,9 +5,9 @@ class OpenAIGPT(BaseLM):
 
     def __init__(self, api_key, model_name="gpt-4-1106-preview", temp=0.1,
                  max_tokens=None, assistant_prompt=None,
-                 freq_penalty=0.0, kwargs=None):
+                 freq_penalty=0.0, **kwargs):
         assert(isinstance(assistant_prompt, str) or assistant_prompt is None)
-        super(OpenAIGPT, self).__init__(name=model_name)
+        super(OpenAIGPT, self).__init__(name=model_name, **kwargs)
 
         # dynamic import of the OpenAI library.
         OpenAI = auto_import("openai._client.OpenAI", is_class=False)
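
Note: every hunk above applies the same change: each provider constructor now accepts arbitrary keyword arguments and forwards them to the BaseLM constructor instead of dropping them. The sketch below illustrates the pass-through pattern only; BaseLM's real constructor is not part of this diff, so the base-class parameters and the option name used here are assumptions for illustration.

    # Minimal sketch of the **kwargs pass-through (hypothetical BaseLM).
    class BaseLM:
        def __init__(self, name, **kwargs):
            # The base class keeps whatever provider-agnostic options
            # were passed through by the concrete provider.
            self.name = name
            self.base_options = kwargs


    class FlanT5(BaseLM):
        def __init__(self, model_name="google/flan-t5-base", temp=0.1, **kwargs):
            # Unknown keyword arguments no longer raise TypeError here;
            # they are forwarded to BaseLM instead.
            super(FlanT5, self).__init__(name=model_name, **kwargs)
            self.temp = temp


    if __name__ == "__main__":
        # A BaseLM-level option (hypothetical name) reaches the base class
        # without touching the FlanT5 signature again.
        lm = FlanT5(enable_log=True)
        print(lm.base_options)  # {'enable_log': True}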