
Commit

Code Llama Fine-tuning Support (#194)
* Print statements for debugging and initial support for Code Llama

* Added multiple print statements for debugging fine tuning
* Added support for Code Llama 7b
* Depending on the training parameters I set, I either get a GPU out-of-memory error or ValueError("optimizer got an empty parameter list")

* Code Llama fine-tuning runs but fails when saving a checkpoint

* commenting print statements

* updating default config behavior

* Begin adding encoding for Code Llama

* adding BOS and EOS tokens for Code Llama, model running properly

* getting rid of #?

adam-weinberger authored and mitya52 committed Nov 17, 2023
1 parent c83fcc9 commit 4d6c35b
Showing 10 changed files with 93,539 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -136,6 +136,7 @@ dmypy.json
.idea/

### VisualStudioCode ###
.vscode
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
2 changes: 1 addition & 1 deletion known_models_db/refact_known_models/huggingface.py
@@ -119,7 +119,7 @@
},
"required_memory_mb": 14000,
"T": 2048,
"filter_caps": ["completion"],
"filter_caps": ["completion", "finetune"],
},
"wizardlm/30b": {
"backend": "transformers",
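Note: adding "finetune" to filter_caps is presumably what exposes this model entry to the fine-tuning flow. The helper below is a hypothetical illustration (not part of this commit) of how such a capability flag could be used to filter the known-models database.

from typing import Any, Dict, List

def models_with_capability(models_db: Dict[str, Any], cap: str) -> List[str]:
    # Return the names of models whose filter_caps list includes the capability.
    return [name for name, info in models_db.items()
            if cap in info.get("filter_caps", [])]

# With this change, the entry edited above would now be returned for cap="finetune".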
8 changes: 8 additions & 0 deletions refact_data_pipeline/finetune/finetune_config.py
@@ -136,6 +136,14 @@ def set_lora_dropout(self, dropout: float) -> 'ConfigBuilder':
def set_lora_init_scale(self, init_scale: float) -> 'ConfigBuilder':
self.cfg['model_info']['lora']['lora_init_scale'] = init_scale
return self

def set_freeze_exceptions(self, exceptions: List[str]) -> 'ConfigBuilder':
self.cfg['model_info']['freeze_exceptions'] = exceptions
return self

def set_save_every(self, save_every: int) -> 'ConfigBuilder':
self.cfg['save_every'] = save_every
return self

def set_limit_time_seconds(self, seconds: int) -> 'ConfigBuilder':
self.cfg['limit_time_seconds'] = seconds
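A minimal sketch of how the two new setters chain and what they write into the config. _BuilderSketch is a stand-in written for this note, not the real ConfigBuilder; its method bodies mirror the assignments added above.

class _BuilderSketch:
    def __init__(self):
        self.cfg = {'model_info': {}}

    def set_freeze_exceptions(self, exceptions):
        # Same effect as ConfigBuilder.set_freeze_exceptions above.
        self.cfg['model_info']['freeze_exceptions'] = exceptions
        return self

    def set_save_every(self, save_every):
        # Same effect as ConfigBuilder.set_save_every above.
        self.cfg['save_every'] = save_every
        return self

b = _BuilderSketch().set_freeze_exceptions(["wte", "lm_head", "lora"]).set_save_every(10)
print(b.cfg)  # {'model_info': {'freeze_exceptions': ['wte', 'lm_head', 'lora']}, 'save_every': 10}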
15 changes: 13 additions & 2 deletions refact_data_pipeline/finetune/finetune_train.py
@@ -103,6 +103,8 @@ def _get_ds_len_per_epoch(model_name, cfg_builder):
.set_schedule_by_heuristics(ds_len=ds_len)
.set_low_gpu_mem_mode_by_heuristics())
else:
traces.log('Not using heuristics')
traces.log('low_gpu_mem_mode: %s' % user_cfg['low_gpu_mem_mode'])
(cfg_builder
.set_train_steps(user_cfg['train_steps'])
.set_lr_decay_steps(user_cfg['lr_decay_steps'])
@@ -116,7 +118,10 @@ def _get_ds_len_per_epoch(model_name, cfg_builder):
.set_batch_size(user_cfg['batch_size'])
.set_warmup_steps(user_cfg['warmup_num_steps'])
.set_limit_time_seconds(user_cfg['limit_time_seconds'])
.set_weight_decay(user_cfg['weight_decay']))
.set_weight_decay(user_cfg['weight_decay'])
.set_lora_target_modules(user_cfg['lora_target_modules'])
.set_freeze_exceptions(user_cfg['freeze_exceptions'])
.set_save_every(user_cfg['save_every']))

traces.log(f'Freeze exceptions: {cfg_builder.cfg["model_info"]["freeze_exceptions"]}')
for k, v in cfg_builder.cfg["model_info"]["lora"].items():
@@ -143,6 +148,7 @@ def create_data(model_name, cfg, enc) -> Tuple[Any, Optional[Any]]:
test_pipe = getattr(finetune_datasource, model_config["test_ds_pipeline"]["pipeline_name"])

train_ds = train_pipe(filtered_train, train_dataopts)
traces.log('train batch size: %s' % cfg['train_batch_size'])
train_ds = BatchIterator(train_ds, dataopts=dict(
batch_size=cfg['train_batch_size'],
drop_last=True
@@ -182,6 +188,7 @@ def _save_checkpoint(force: bool = False):
tag = "iter%04d-trainloss%0.3f" % (iter_n, progress["loss"])
traces.log("saving checkpoint %s" % tag)
save_model_state(model, save_path=save_path, tag=tag)
traces.log("finished saving checkpoint %s" % tag)

model_config = supported_models.config[model_name]
save_path = os.path.join(traces.context().path, "checkpoints")
@@ -220,7 +227,9 @@ def _save_checkpoint(force: bool = False):
f"({batch['mask'].sum()}/{batch['mask'].numel()})"
)

print("train batch size: %s" % cfg.get("train_batch_size"))
for b0 in range(0, cfg.get("train_batch_size"), cfg.get("micro_batch_size")):

try:
input = batch['input'][b0:b0 + micro_bs].contiguous()
logits = forward(input=input, low_gpu_mem_mode=low_gpu_mem_mode)
@@ -309,7 +318,7 @@ def _save_checkpoint(force: bool = False):
def finetune(status_dict, models_db: Dict[str, Any]):
logging.info("starting finetune at %s" % traces.context().path)
cfg = load_finetune_config(models_db)
traces.log("creating model...")
traces.log("Creating model %s..." % cfg['model_name'])
t0 = time.time()
model = make_model(
model_name=cfg['model_name'],
@@ -328,8 +337,10 @@ def finetune(status_dict, models_db: Dict[str, Any]):
)
t1 = time.time()
traces.log("/model %0.1fms" % ((t1 - t0) * 1000))
traces.log(cfg)
if cfg['debug']:
summary(model, depth=4, col_names=['num_params', 'params_percent', 'trainable'])

model, optimizer, _, _ = deepspeed.initialize(
config=cfg,
model=model,
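The loop logged above (for b0 in range(0, train_batch_size, micro_batch_size)) slices each full training batch into micro-batches so that only one slice is processed on the GPU at a time; with DeepSpeed, gradients typically accumulate across the slices before the optimizer step. A standalone sketch of the slicing pattern, with illustrative names and sizes (in the real code they come from cfg):

train_batch_size = 8
micro_batch_size = 2
batch_input = list(range(train_batch_size))  # stands in for batch['input']

for b0 in range(0, train_batch_size, micro_batch_size):
    # Same slicing as input = batch['input'][b0:b0 + micro_bs] above.
    micro_batch = batch_input[b0:b0 + micro_batch_size]
    # Forward/backward on micro_batch would run here; gradients accumulate
    # across micro-batches until the full train batch has been seen.
    print("micro-batch starting at", b0, "->", micro_batch)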
5 changes: 4 additions & 1 deletion refact_data_pipeline/finetune/finetune_train_defaults.py
@@ -12,5 +12,8 @@
"lora_alpha": 32,
"lora_init_scale": 0.01,
"lora_dropout": 0.01,
"low_gpu_mem_mode": True
"low_gpu_mem_mode": True,
"lora_target_modules": ["qkv", "out", "mlp"],
"freeze_exceptions": ["wte", "lm_head", "lora"],
"save_every": 10
}
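The three new keys (lora_target_modules, freeze_exceptions, save_every) are read in finetune_train.py via user_cfg[...]. As a rough sketch only: the real merge of user settings over these defaults happens in the config-loading code, which is not part of this diff, and the dict name below is assumed.

# Assumed name for the defaults dict defined in this file.
finetune_train_defaults = {
    "low_gpu_mem_mode": True,
    "lora_target_modules": ["qkv", "out", "mlp"],
    "freeze_exceptions": ["wte", "lm_head", "lora"],
    "save_every": 10,
}

# Overriding a single setting while keeping the other defaults.
user_cfg = {**finetune_train_defaults, "save_every": 25}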
7 changes: 5 additions & 2 deletions refact_data_pipeline/finetune/model_handling.py
@@ -107,8 +107,8 @@ def save_model_state(model, save_path, tag):
'skipped_steps', 'global_steps', 'global_samples', 'dp_world_size', 'mp_world_size',
'ds_config', 'ds_version'
}

model.save_checkpoint(save_path, tag=tag)
model.save_checkpoint(save_dir=save_path, tag=tag)
cp_path = Path(save_path) / tag
model_cps = [p for p in cp_path.iterdir() if 'model_states' in p.name]
_ = [p.unlink() for p in cp_path.iterdir() if 'model_states' not in p.name]
@@ -141,6 +141,9 @@ def setup_encoding(
encoding.INFILL = model_config["tokenizer"]["fim_middle"]
encoding.SUFFIX = model_config["tokenizer"]["fim_suffix"]
encoding.ESCAPE = model_config["tokenizer"]["escape"]
encoding.BOS = model_config["tokenizer"]["bos"] if model_config["tokenizer"]["bos"] else ""
encoding.EOS = model_config["tokenizer"]["eos"] if model_config["tokenizer"]["eos"] else ""

return encoding


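setup_encoding now exposes BOS and EOS, falling back to "" when the model's tokenizer config defines neither. As an illustration of why: a data pipeline consuming the encoding can bracket each sample with these tokens. The helper below is hypothetical and not part of this diff.

def encode_sample_with_bos_eos(enc, token_ids):
    # "" means the tokenizer config defines no BOS/EOS token for this model.
    bos = [enc.BOS] if enc.BOS != "" else []
    eos = [enc.EOS] if enc.EOS != "" else []
    return bos + list(token_ids) + eos

# For codellama/7b the supported_models entry below maps bos=1 and eos=2,
# so a sample [t1, ..., tn] becomes [1, t1, ..., tn, 2].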
40 changes: 40 additions & 0 deletions refact_data_pipeline/finetune/supported_models.py
@@ -151,5 +151,45 @@
},
"train_model_modifiers": [],
"force_enable_checkpointing": True
},

"codellama/7b": {
"lora_target_modules_mapping": {
"qkv": ["attn.q_attn", "attn.c_attn"],
"out": ["attn.c_proj"],
"backproj": ["attn.c_proj"],
"mlp": ["mlp.c_fc", "mlp.c_proj"],
},
"freeze_exceptions_mapping": {
"wte": "wte",
"lm_head": "lm_head",
"lora": "lora"
},
"tokenizer": {
"eot_idx": 6,
"padding_idx": 7,
"fim_prefix": 3,
"fim_middle": 4,
"fim_suffix": 5,
"escape": 7,
"bos": 1,
"eos": 2
},
"train_ds_pipeline": {
"ds_opts": "n_ctx={n_ctx},fim_probability=0.9,fim_drop_residual=1,"
"tkr_stochastic_tokens=3,shuffle_depth=3000,debug=0,"
"random_trim_context_prob=0.01,fim_random_seed=42",
"pipeline_name": "local_fim"
},
"test_ds_pipeline": {
"ds_opts": "n_ctx={n_ctx},fim_probability=0.9,fim_drop_residual=1,"
"tkr_stochastic_tokens=3,shuffle_depth=3000,debug=0,"
"random_trim_context_prob=0.01,fim_random_seed=42,"
"pack_single=1,pack_complete=0,pack_buffer_size=25,"
"quit_on_epoch=1,seed=42",
"pipeline_name": "local_fim"
},
"train_model_modifiers": [],
"force_enable_checkpointing": True
}
}
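The lora_target_modules_mapping above translates the friendly group names used in the config ("qkv", "out", "mlp") into module-path fragments inside the model. A hypothetical helper (not part of this commit) showing how that expansion could look:

def resolve_lora_targets(model_config, selected=("qkv", "out", "mlp")):
    # Expand friendly group names into the module-name fragments they map to.
    mapping = model_config["lora_target_modules_mapping"]
    targets = []
    for group in selected:
        targets.extend(mapping[group])
    return targets

# For the codellama/7b entry above this yields:
# ['attn.q_attn', 'attn.c_attn', 'attn.c_proj', 'mlp.c_fc', 'mlp.c_proj']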
1 change: 1 addition & 0 deletions refact_encoding/encoding.py
@@ -142,6 +142,7 @@ def __init__(self, name: str):
self.EOT = self._sentencepiece_tokenizer.eos_id()
self.LF = 13


elif name in ['bigcode_largemodel']:
import tokenizers
filename = Path(__file__).resolve().parent / f"{name}.json"
