-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathmerge_and_push_model.py
33 lines (28 loc) · 1.95 KB
/
merge_and_push_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Reference:
# https://discuss.huggingface.co/t/further-finetuning-a-lora-finetuned-causallm-model/36987/4
# https://github.com/TrelisResearch/install-guides/blob/main/Pushing_to_Hub.ipynb
from peft import AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
from huggingface_hub import login
from config.config import *
# Merge LoRA adapters from the best fine-tuning checkpoint into the base model
# and push the merged model (plus the base tokenizer) to the Hugging Face Hub.
#
# NOTE(review): the best checkpoint is hard-coded; update this constant (or move
# it into config/config.py) when a newer checkpoint becomes the best one.
BEST_CHECKPOINT = "checkpoint-10200"

best_checkpoint_path = os.path.join(
    FINE_TUNED_MODELS_PATH, BASE_MODEL_NAME + "_checkpoints", BEST_CHECKPOINT
)
merged_model_path = os.path.join(FINE_TUNED_MODELS_PATH, MERGED_MODEL_NAME)

if os.path.exists(merged_model_path):
    # A merged model was already produced by an earlier run — reuse it instead
    # of re-merging (merging is slow and memory-hungry).
    print(f"Load previous merged model from {merged_model_path}.")
    model_to_push = AutoModelForCausalLM.from_pretrained(
        merged_model_path, torch_dtype="auto", device_map=DEVICE
    )
else:
    print(f"Using {best_checkpoint_path} to merge model.")
    # AutoPeftModelForCausalLM reads the adapter config in the checkpoint,
    # loads the referenced base model, and attaches the adapters in one step.
    base_with_adapters_model = AutoPeftModelForCausalLM.from_pretrained(
        best_checkpoint_path, torch_dtype="auto", device_map=DEVICE
    )
    ## Or load the base model and the adapter separately:
    # base_model_path = os.path.join(BASE_MODELS_PATH, BASE_MODEL_NAME)
    # base_model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype="auto", device_map=DEVICE)
    # model_to_push = PeftModel.from_pretrained(base_model, best_checkpoint_path)  # apply the adapter to the base model
    model_to_push = base_with_adapters_model.merge_and_unload()  # fuse adapter weights into the base model
    model_to_push.save_pretrained(merged_model_path)  # cache the merged model locally for future runs

# The merged model directory contains only model weights; copy the tokenizer
# files from the base model so the pushed repo is usable end to end.
base_model_path = os.path.join(BASE_MODELS_PATH, BASE_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
tokenizer.save_pretrained(merged_model_path)  # Save the tokenizer from the base model, it's necessary for the model to work

# Requires network access to huggingface.co (a proxy/VPN may be needed in some regions).
login(HUGGING_FACE_TOKEN)
# token=True -> reuse the token stored by login() above.
model_to_push.push_to_hub(
    MERGED_MODEL_NAME, token=True, max_shard_size="5GB", safe_serialization=True
)