diff --git a/Dockerfile b/Dockerfile
index 1048f37c..270981b1 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -44,6 +44,10 @@ ENV TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0;7.5;8.0;8.6;8.9;9.0+PTX"
 COPY . /tmp/app
 RUN pip install /tmp/app && rm -rf /tmp/app
 
+RUN git clone -b feat/alibi https://github.com/smallcloudai/flash-attention.git /tmp/flash-attention \
+    && cd /tmp/flash-attention \
+    && MAX_JOBS=8 python3 setup.py install
+
 ENV REFACT_PERM_DIR "/perm_storage"
 ENV REFACT_TMP_DIR "/tmp"
 
diff --git a/code_contrast/format_2022q3/contrast.py b/code_contrast/format_2022q3/contrast.py
index 293bb697..4fbc4c0d 100644
--- a/code_contrast/format_2022q3/contrast.py
+++ b/code_contrast/format_2022q3/contrast.py
@@ -122,7 +122,7 @@ def from_odm_dict(
     if tight_shrink:
         files.reverse()
     else:
-        random.shuffle(files)
+        np_random.shuffle(files)
     file_poi = defaultdict(set)
     file_deltokens = defaultdict(list)
     file_dellines = defaultdict(list)
@@ -135,7 +135,7 @@ def from_odm_dict(
         # dest_lines = odm["dest"][fn].replace('\r\n', '\n').replace('\r', '\n').splitlines()
         orig_lines = [x+"\n" for x in odm["orig"][fn].splitlines()]
         dest_lines = [x+"\n" for x in odm["dest"][fn].splitlines()]
-        if len(orig_lines)==0:
+        if len(orig_lines) == 0:
             orig_lines.append("\n")
         if orig_lines[-1][-1] != "\n":
             orig_lines[-1] += "\n"
@@ -228,7 +228,7 @@ def orig_app(line):
             opblocks.append(opblock)
             self.orig_tokens[fn] = orig_all_tokens
             self.dest_tokens[fn] = dest_all_tokens
-        random.shuffle(opblocks)
+        np_random.shuffle(opblocks)
         raw_ops: List[Tuple[str, str, int, int, int, int]] = list()
         for opblock in opblocks:
             raw_ops.extend(opblock)
@@ -359,7 +359,7 @@ def app(t, m):
         self.fn2tstart = dict()
         self.fn2cut = dict()
         tpos_unused = list(self.enc.tpos)
-        random.shuffle(tpos_unused)
+        np_random.shuffle(tpos_unused)
         tpos_unused *= 2
         need_to_cut_main = 0
         need_to_cut_supp = 0
@@ -403,9 +403,9 @@ def app(t, m):
                 else:
                     move_r2 = min(cut_step, cut_more, relax2[fn])
             else:
-                if random.random() < 0.5 and relax1[fn] > 1:
+                if np_random.random() < 0.5 and relax1[fn] > 1:
                     move_r1 = random.randint(0, min(cut_more, relax1[fn]))
-                if random.random() < 0.5 and relax2[fn] > 1:
+                if np_random.random() < 0.5 and relax2[fn] > 1:
                     move_r2 = random.randint(0, min(cut_more, relax2[fn]))
             assert move_r1 >= 0 and move_r2 >= 0, f"i1={i1} i2={i2} r1={r1} r2={r2}"
             if SHRINK_DUMP:
diff --git a/code_contrast/format_2022q3/contrast_stochastic.py b/code_contrast/format_2022q3/contrast_stochastic.py
index e9ac6525..715cdef1 100644
--- a/code_contrast/format_2022q3/contrast_stochastic.py
+++ b/code_contrast/format_2022q3/contrast_stochastic.py
@@ -89,7 +89,7 @@ def poisson():
     for n in range(1, len(result)-1):
         lop, li1, li2, lj1, lj2 = result[n-1]
         mop, mi1, mi2, mj1, mj2 = result[n]
-        if lop == "equal" and mop != "equal" and random.random() < left_prob:
+        if lop == "equal" and mop != "equal" and np_random.random() < left_prob:
             assert li2 == mi1
             if exact_cx_lines0 >= 0:
                 move = exact_cx_lines0
@@ -104,7 +104,7 @@ def poisson():
         mop, mi1, mi2, mj1, mj2 = result[n]
         rop, ri1, ri2, rj1, rj2 = result[n+1]
         # if mop != "equal" and rop == "equal" and (random.random() < right_prob or (mi1==mi2 and disable_insert)):
-        if mop != "equal" and rop == "equal" and random.random() < right_prob:
+        if mop != "equal" and rop == "equal" and np_random.random() < right_prob:
             assert ri1 == mi2
             if exact_cx_lines1 >= 0:
                 move = exact_cx_lines1
diff --git a/code_contrast/format_2023q2/from_orig_dest_message.py b/code_contrast/format_2023q2/from_orig_dest_message.py
index 635987c5..298b79c7 100644
--- a/code_contrast/format_2023q2/from_orig_dest_message.py
+++ b/code_contrast/format_2023q2/from_orig_dest_message.py
@@ -22,6 +22,7 @@ def from_odm_dict(
         exact_cx_lines1 = -1,
         external_poi_ranges: Optional[DefaultDict[str, List[Tuple[int, int]]]] = None,
         want_cursor_token: bool = False,
+        random_state: np.random.RandomState = np.random.RandomState(42)
 ) -> Tuple[Packer, int]:
     pack = Packer(fmt)
     files1 = list(odm["orig"].keys())
@@ -33,7 +34,7 @@ def from_odm_dict(
         # This moves it to the end, more visible to the model
         fns.reverse()
     else:
-        random.shuffle(fns)
+        random_state.shuffle(fns)
     files = []
     chunks: List[ChunkElement] = []
     for fn in fns:
@@ -53,7 +54,7 @@ def from_odm_dict(
         if fn not in odm["dest"]:
             continue
         chunks.extend(_run_diff_for_single_file(f, [(x + "\n") for x in odm["dest"][fn].splitlines()], exact_cx_lines0, exact_cx_lines1))
-    random.shuffle(chunks)
+    random_state.shuffle(chunks)
     for chunk in chunks:
         pack.add_to_plan(chunk)
     if want_cursor_token and len(chunks) == 1:
@@ -62,9 +63,9 @@ def from_odm_dict(
         thischunk_lines = set(range(chunks[0].line_n, chunks[0].line_n + len(chunks[0].to_del) + 1))
         thischunk_modlines = list(thischunk_lines & modlines)
         if len(thischunk_modlines) > 0:   # Can be zero for whatever reason, cursor appearance is random anyway
-            aim = random.choice(thischunk_modlines)
-            shift = np.random.poisson(2)
-            sign = np.random.choice([-1, 1])
+            aim = random_state.choice(thischunk_modlines)
+            shift = random_state.poisson(2)
+            sign = random_state.choice([-1, 1])
             file0._cursor_token_at_line = aim + shift * sign
     return pack, msg_plan_n
diff --git a/known_models_db/refact_known_models/huggingface.py b/known_models_db/refact_known_models/huggingface.py
index eefc0fc5..689cc78c 100644
--- a/known_models_db/refact_known_models/huggingface.py
+++ b/known_models_db/refact_known_models/huggingface.py
@@ -35,8 +35,8 @@
         "diff_scratchpad_class": "refact_scratchpads:ScratchpadPSM",
         "chat_scratchpad_class": None,
         "model_class_kwargs": {},
-        "required_memory_mb": 6000,
-        "T": 4096,
+        "required_memory_mb": 8000,
+        "T": 8192,
         "filter_caps": ["completion", "finetune"],
     },
     "starcoder/3b/base": {
@@ -45,7 +45,7 @@
         "diff_scratchpad_class": "refact_scratchpads:ScratchpadPSM",
         "chat_scratchpad_class": None,
         "model_class_kwargs": {},
-        "required_memory_mb": 9000,
+        "required_memory_mb": 12000,
         "T": 4096,
         "filter_caps": ["completion", "finetune"],
     },
@@ -55,8 +55,8 @@
         "diff_scratchpad_class": "refact_scratchpads:ScratchpadPSM",
         "chat_scratchpad_class": None,
         "model_class_kwargs": {},
-        "required_memory_mb": 18000,
-        "T": 2048,
+        "required_memory_mb": 20000,
+        "T": 4096,
         "filter_caps": ["completion", "finetune"],
     },
     "wizardcoder/15b": {
diff --git a/refact_data_pipeline/datautils.py b/refact_data_pipeline/datautils.py
index b544c950..6988316d 100644
--- a/refact_data_pipeline/datautils.py
+++ b/refact_data_pipeline/datautils.py
@@ -1,6 +1,10 @@
+import os
+
 import torch as th
 from collections import defaultdict
-from typing import Iterator, Tuple, Dict, Any, Callable, Sequence
+from typing import Iterator, Tuple, Dict, Any, Callable, Iterable, List
+
+from refact_data_pipeline import DatasetOpts
 
 
 def str2dtype(s: str) -> th.dtype:
@@ -14,6 +18,66 @@
     }[s]
 
 
+_prefer_dtypes = {
+    "logits": th.int64,
+    "first": th.bool,
+    "mask": th.bool
+}
+
+
+def _after_collate(result: Dict[str, th.Tensor]) -> Dict[str, th.Tensor]:
+    if 'first' in result:
+        result['first'] = result.pop("first")[:, :-1]
+    if 'mask' in result:
+        result['mask'] = result.pop("mask")[:, 1:]
result["labels"] = result["tokens"][:, 1:] + result["input"] = result["tokens"][:, :-1] + return { + k: (v if isinstance(v, th.Tensor) else v) + for k, v in result.items() + } + + +def collate_fn(records: List[Dict[str, Any]]) -> Dict[str, Any]: + output = defaultdict(list) + last_stats = None + for idx, record in enumerate(records): + for k, v in record.items(): + if k == "stats": + last_stats = v + continue + output[k].append( + th.tensor(record[k], dtype=_prefer_dtypes.get(k, th.int64)) + ) + return _after_collate({ + "stats": last_stats, + **{k: th.stack(v).contiguous() for k, v in output.items()} + }) + + +def data_parallel_split_and_collate_fn(records: List[Dict[str, Any]]) -> Dict[str, Any]: + rank = int(os.environ.get('RANK', 0)) + world_size = int(os.environ.get('WORLD_SIZE', 1)) + + output = defaultdict(list) + last_stats = None + for idx, record in enumerate(records): + for k, v in record.items(): + if k == "stats": + last_stats = v + continue + output[k].append( + th.tensor(record[k], dtype=_prefer_dtypes.get(k, th.int64)) + ) + assert len(records) % world_size == 0, "effective batch size %s" % len(records) + effective_bs = len(records) // world_size + from_, to = rank * effective_bs, (rank + 1) * effective_bs + return _after_collate({ + "stats": last_stats, + **{k: th.stack(v)[from_:to].contiguous() for k, v in output.items()} + }) + + def read_and_collate( data_iter: Iterator, prefer_dtypes: Dict[str, str], @@ -58,45 +122,44 @@ def read_and_collate( class BatchIterator: def __init__( self, - seq: Sequence, - dataopts: Dict[str, Any], + inner_filter: Iterable[Any], + dataopts: DatasetOpts ): - self.seq_iter = iter(seq) + self.inner_filter = inner_filter self.dataopts = dataopts self.batch_size = dataopts.get("batch_size", 1) self.device = dataopts.get("device", "cuda") self.drop_last = dataopts.get("drop_last", False) - def __next__(self): - data, datastats = read_and_collate( - data_iter=self.seq_iter, - prefer_dtypes=dict(mask='torch.bool', first='torch.bool'), - B=self.batch_size, - device=self.device, - cold_restart_dict=dict(), - log_stats=True, - progress_callback=lambda *args, **kwargs: None - ) - if len(data) == 0: - raise StopIteration() - - if self.drop_last and len(data['tokens']) < self.batch_size: - raise StopIteration() - - extra = dict() - if 'first' in data: - extra['first'] = data.pop("first")[:, :-1] - if 'mask' in data: - extra['mask'] = data.pop("mask")[:, 1:] - - tokens = data.pop("tokens") - batch = dict( - labels=tokens[:, 1:], - input=tokens[:, :-1], - **extra - ) - batch.update({k: v for k, v in data.items() if k not in batch}) - return batch, datastats - def __iter__(self): - return self \ No newline at end of file + seq_iter = iter(self.inner_filter) + while True: + data, datastats = read_and_collate( + data_iter=seq_iter, + prefer_dtypes=dict(mask='torch.bool', first='torch.bool'), + B=self.batch_size, + device=self.device, + cold_restart_dict=dict(), + log_stats=True, + progress_callback=lambda *args, **kwargs: None + ) + if len(data) == 0: + break + + if self.drop_last and len(data['tokens']) < self.batch_size: + break + + extra = dict() + if 'first' in data: + extra['first'] = data.pop("first")[:, :-1] + if 'mask' in data: + extra['mask'] = data.pop("mask")[:, 1:] + + tokens = data.pop("tokens") + batch = dict( + labels=tokens[:, 1:], + input=tokens[:, :-1], + **extra + ) + batch.update({k: v for k, v in data.items() if k not in batch}) + yield batch, datastats diff --git a/refact_data_pipeline/filters_chat.py 
index eef39fff..8460e50f 100644
--- a/refact_data_pipeline/filters_chat.py
+++ b/refact_data_pipeline/filters_chat.py
@@ -20,12 +20,11 @@ def __init__(
         self.inner_filter = inner_filter
         self.n_ctx = dataopts.get("n_ctx", 2048)
         self.no_format_prob = dataopts.get("chat_no_format_prob", 0.0)
-        self.chat_random_seed = dataopts.get("chat_random_seed", 42)
         self.debug = bool(dataopts.get("debug", 0))
         self.tkr_stochastic_tokens = bool(dataopts.get("tkr_stochastic_tokens", 0.0))
         self.enc: RefactEncoding = dataopts.encoding
         self.fmt: Format2023q2 = format.format_2023q2_escape(self.enc)
-        self.random = np.random.RandomState(self.chat_random_seed)
+        self.random = np.random.RandomState(dataopts.get("seed", 42))
 
     def _pack_format(self, plan: List[MsgElement], odm: Dict, stats: Dict):
         try:
@@ -87,7 +86,10 @@ def _pack_plain(self, plan: List[MsgElement], odm: Dict, stats: Dict):
         if self.debug:
             print(f'Chat2023Q2:\n{text}\n\n')
 
-        tokens, _ = self.enc.encode_stochastic(text, [], 0.01 * self.tkr_stochastic_tokens)
+        if hasattr(self.enc, 'encode_stochastic'):
+            tokens, _ = self.enc.encode_stochastic(text, [], 0.01 * self.tkr_stochastic_tokens)
+        else:
+            tokens = self.enc.encode(text)
         tokens += [self.enc.EOT]
         emit = {
             "tokens": tokens,
diff --git a/refact_data_pipeline/filters_diff.py b/refact_data_pipeline/filters_diff.py
index b07c40ea..33709e6a 100644
--- a/refact_data_pipeline/filters_diff.py
+++ b/refact_data_pipeline/filters_diff.py
@@ -1,10 +1,12 @@
+import logging
 import random
 import traceback
 import copy
+
+import numpy as np
+
 from refact_encoding import RefactEncoding
 from refact_data_pipeline import DatasetOpts
-from refact_data_pipeline.finetune import traces
 from code_contrast.format_2022q3 import contrast
 
 from typing import Dict
@@ -18,6 +20,8 @@ def __init__(self,
         self.enc: RefactEncoding = dataopts.encoding
         self.n_ctx = dataopts.get("n_ctx", 2048)
         self.selftest = dataopts.get("selftest", 0)
+        self.random = random.Random(dataopts.get("seed", 42))
+        self.np_random = np.random.RandomState(dataopts.get("seed", 42))
 
     def __iter__(self):
         stats: Dict[str, int] = {
@@ -33,7 +37,7 @@ def __iter__(self):
             if source_files_empty_cnt == len(odm["orig"]):
                 stats["diffskip_onlyadd"] += 1
                 continue
-            make_no_changes = random.random() < 0.05
+            make_no_changes = self.random.random() < 0.05
             if make_no_changes:
                 odm["orig"] = copy.deepcopy(odm["dest"])
             if self.selftest:
@@ -52,6 +56,7 @@ def __iter__(self):
                 diff.from_odm_dict(
                     odm,
                     n_ctx=self.n_ctx,
+                    np_random=self.np_random
                 )
                 if len(diff.edits) == 0 and not make_no_changes:
                     stats["diffskip_noedit"] += 1
@@ -73,8 +78,8 @@ def __iter__(self):
                     stats["diffskip_toobig"] += 1
                     continue
             except Exception as e:
-                traces.log(str(odm))
-                traces.log(traceback.format_exc())
+                logging.error(str(odm))
+                logging.error(traceback.format_exc())
                 stats["diffskip_failed"] += 1
                 continue
             edits_within_context = self.n_ctx - diff.offset_edits
diff --git a/refact_data_pipeline/filters_diff2023q2.py b/refact_data_pipeline/filters_diff2023q2.py
index c8f2a36a..1e506e06 100644
--- a/refact_data_pipeline/filters_diff2023q2.py
+++ b/refact_data_pipeline/filters_diff2023q2.py
@@ -21,7 +21,7 @@ def __init__(self,
         self.inner_filter = inner_filter
         self.n_ctx = dataopts.get("n_ctx", 2048)
         self.selftest = dataopts.get("selftest", 0)
-        self.seed = dataopts.get("seed", 0)
+        self.seed = dataopts.get("seed", 42)
         self.py_random = random.Random(self.seed if self.seed else None)
         self.np_random = np.random.RandomState(self.seed if self.seed else None)
         self.enc: RefactEncoding = dataopts.encoding
@@ -59,9 +59,10 @@ def __iter__(self):
                     self.fmt,
                     odm,
                     for_training=True,
-                    exact_cx_lines0 = -1,
-                    exact_cx_lines1 = -1,
+                    exact_cx_lines0=-1,
+                    exact_cx_lines1=-1,
                     want_cursor_token=True,
+                    random_state=self.np_random
                 )
                 if len(pack.plan) - 1 == msg_plan_n and not make_no_changes:
                     stats["diffskip_noedit"] += 1
@@ -72,8 +73,8 @@ def __iter__(self):
                     limit_ctx_n=self.n_ctx,
                     limit_aux_n=0,
                     add_eot=True,
-                    for_training=True,
-                    )
+                    for_training=True
+                )
                 # edits_made = len(pack.plan) - 1 - msg_plan_n
                 # print("edits: %i" % edits_made)
                 # if edits_made == 1:
diff --git a/refact_data_pipeline/filters_fim.py b/refact_data_pipeline/filters_fim.py
index 983ca2e1..b8d605a3 100644
--- a/refact_data_pipeline/filters_fim.py
+++ b/refact_data_pipeline/filters_fim.py
@@ -9,9 +9,11 @@ class SymbolsMiddleSplit:
     def __init__(self,
+                 random_state,
                  min_symbols: int = 1,
                  max_symbols: int = 4000,    # 4k is one dense screen of text
                  ):
+        self.random = random_state
         self._min_symbols = min_symbols
         self._max_symbols = max_symbols
 
@@ -22,9 +24,9 @@ def split(self, text: str):
         ])
         if self._min_symbols > max_symbols:
             raise RuntimeError
-        mid_symbols = random.randint(self._min_symbols, max_symbols)
+        mid_symbols = self.random.randint(self._min_symbols, max_symbols)
         assert len(text) - mid_symbols - 1 >= 0
-        split_pos = random.randint(0, len(text) - mid_symbols - 1)
+        split_pos = self.random.randint(0, len(text) - mid_symbols - 1)
         middle = text[split_pos:split_pos + mid_symbols]
         assert len(middle) == mid_symbols
 
@@ -47,6 +49,8 @@ def __init__(
         self.fim_probability = dataopts.get("fim_probability", 0.5)
         self.tkr_stochastic_tokens = dataopts.get("tkr_stochastic_tokens", 3)
         self.enc: RefactEncoding = dataopts.encoding
+        if hasattr(self.enc, "set_random_seed"):
+            self.enc.set_random_seed(dataopts.get("seed", 42))
         self.special_tokens = [
             self.enc.PREFIX,
             self.enc.SUFFIX,
@@ -54,7 +58,8 @@ def __init__(
             self.enc.EOT,
         ]
         assert len(set(self.special_tokens)) == len(self.special_tokens)
-        self.splitter = SymbolsMiddleSplit()
+        self.random = random.Random(dataopts.get("seed", 42))
+        self.splitter = SymbolsMiddleSplit(self.random)
 
     def __iter__(self):
         stats: Dict[str, Union[int, float]] = {
@@ -63,10 +68,13 @@ def __iter__(self):
             "fim_out": 0,
         }
         for sample in self.inner_filter:
-            tokens, _ = self.enc.encode_stochastic(sample["text"], [], 0.01*self.tkr_stochastic_tokens)
+            if hasattr(self.enc, 'encode_stochastic'):
+                tokens, _ = self.enc.encode_stochastic(sample["text"], [], 0.01 * self.tkr_stochastic_tokens)
+            else:
+                tokens = self.enc.encode(sample["text"])
             cursor = 0
             while cursor < len(tokens):
-                if random.random() > self.fim_probability:
+                if self.random.random() > self.fim_probability:
                     # plain text branch
                     plain = tokens[cursor : cursor + self.n_ctx]
                     cursor += len(plain)
@@ -88,15 +96,18 @@ def __iter__(self):
                     }
                 else:
                     # FIM
-                    wiggle_low = (self.n_ctx * 9 // 20) if random.randint(0, 2) == 0 else (self.n_ctx * 18 // 20)
-                    wiggle = random.randint(wiggle_low, self.n_ctx * 21 // 20)
+                    wiggle_low = (self.n_ctx * 9 // 20) if self.random.randint(0, 2) == 0 else (self.n_ctx * 18 // 20)
+                    wiggle = self.random.randint(wiggle_low, self.n_ctx * 21 // 20)
                     # n_ctx   *9//20  *18//20  *21//20
                     #  4096 ->  2048     3686     4300
                     #  2048 ->  1024     1843     2150
                     pre_fim_toks = tokens[cursor : cursor + wiggle]
                     cursor += len(pre_fim_toks)
                     try:
-                        text = self.enc.decode_utf8(pre_fim_toks)
+                        if hasattr(self.enc, 'decode_utf8'):
+                            text = self.enc.decode_utf8(pre_fim_toks)
+                        else:
+                            text = self.enc.decode(pre_fim_toks)
                     except:
                         stats["fim_unicode_split"] += 1
                         continue
@@ -108,15 +119,23 @@ def __iter__(self): except (RuntimeError, ValueError): stats["fim_unable_to_split"] += 1 continue - prefix_toks, _ = self.enc.encode_stochastic(prefix, [], 0.01*self.tkr_stochastic_tokens) - suffix_toks, _ = self.enc.encode_stochastic(suffix, [], 0.01*self.tkr_stochastic_tokens) - if random.random() < 0.5: + if hasattr(self.enc, 'encode_stochastic'): + prefix_toks, _ = self.enc.encode_stochastic(prefix, [], 0.01 * self.tkr_stochastic_tokens) + suffix_toks, _ = self.enc.encode_stochastic(suffix, [], 0.01 * self.tkr_stochastic_tokens) + else: + prefix_toks = self.enc.encode(prefix) + suffix_toks = self.enc.encode(suffix) + + if self.random.random() < 0.5: tokens_context = [self.enc.PREFIX] + prefix_toks + [self.enc.SUFFIX] + suffix_toks mask_context = [0] + [1] * len(prefix_toks) + [0] + [1] * len(suffix_toks) else: tokens_context = [self.enc.SUFFIX] + suffix_toks + [self.enc.PREFIX] + prefix_toks mask_context = [0] + [1] * len(suffix_toks) + [0] + [1] * len(prefix_toks) - middle_toks, _ = self.enc.encode_stochastic(middle, [], 0.01*self.tkr_stochastic_tokens) + if hasattr(self.enc, 'encode_stochastic'): + middle_toks, _ = self.enc.encode_stochastic(middle, [], 0.01 * self.tkr_stochastic_tokens) + else: + middle_toks = self.enc.encode(middle) middle_mask = [1] * len(middle_toks) yield { "tokens": tokens_context + [self.enc.INFILL] + middle_toks + [self.enc.EOT], diff --git a/refact_data_pipeline/filters_fim_v2.py b/refact_data_pipeline/filters_fim_v2.py index 48b0dbe1..4adbaa47 100644 --- a/refact_data_pipeline/filters_fim_v2.py +++ b/refact_data_pipeline/filters_fim_v2.py @@ -213,11 +213,12 @@ def __init__( self.n_ctx = dataopts.get("n_ctx", 2048) self.fim_probability = dataopts.get("fim_probability", 0.5) self.tkr_stochastic_tokens = dataopts.get("tkr_stochastic_tokens", 3) - self.fim_random_seed = dataopts.get("fim_random_seed", 42) self.fim_drop_residuals = bool(dataopts.get("fim_drop_residuals", 0)) self.random_trim_context_prob = bool(dataopts.get("random_trim_context_prob", 0.0)) self.debug = bool(dataopts.get("debug", 0)) self.enc = dataopts.encoding + if hasattr(self.enc, "set_random_seed"): + self.enc.set_random_seed(dataopts.get("seed", 42)) self.special_tokens = [ self.enc.PREFIX, self.enc.SUFFIX, @@ -225,7 +226,7 @@ def __init__( self.enc.EOT, ] assert len(set(self.special_tokens)) == len(self.special_tokens) - self.random = np.random.RandomState(self.fim_random_seed) + self.random = np.random.RandomState(dataopts.get("seed", 42)) self.splitters_probs = [ (InsideSingleRow(random=self.random), 0.2), (MiddleToEndSingleRow(random=self.random), 0.399), @@ -245,7 +246,10 @@ def __iter__(self): text = sample["text"] if self.random.random() < self.random_trim_context_prob: text = _random_trim_context(text, self.random) - tokens, _ = self.enc.encode_stochastic(text, [], 0.01 * self.tkr_stochastic_tokens) + if hasattr(self.enc, 'encode_stochastic'): + tokens, _ = self.enc.encode_stochastic(text, [], 0.01 * self.tkr_stochastic_tokens) + else: + tokens = self.enc.encode(text) cursor = 0 while cursor < len(tokens): if self.random.random() > self.fim_probability: @@ -287,7 +291,10 @@ def _generate_fim(self, tokens, cursor, sample, stats) \ cursor += len(pre_fim_toks) is_cut_file = len(tokens[cursor:]) > 0 try: - text = self.enc.decode_utf8(pre_fim_toks) + if hasattr(self.enc, 'decode_utf8'): + text = self.enc.decode_utf8(pre_fim_toks) + else: + text = self.enc.decode(pre_fim_toks) except: stats["fim_unicode_split"] += 1 return None, cursor @@ -306,15 +313,22 @@ def 
_generate_fim(self, tokens, cursor, sample, stats) \ stats["fim_unable_to_split"] += 1 return None, cursor - prefix_toks, _ = self.enc.encode_stochastic(prefix, [], 0.01 * self.tkr_stochastic_tokens) - suffix_toks, _ = self.enc.encode_stochastic(suffix, [], 0.01 * self.tkr_stochastic_tokens) - if random.random() < 0.5: + if hasattr(self.enc, 'encode_stochastic'): + prefix_toks, _ = self.enc.encode_stochastic(prefix, [], 0.01 * self.tkr_stochastic_tokens) + suffix_toks, _ = self.enc.encode_stochastic(suffix, [], 0.01 * self.tkr_stochastic_tokens) + else: + prefix_toks = self.enc.encode(prefix) + suffix_toks = self.enc.encode(suffix) + if self.random.random() < 0.5: tokens_context = [self.enc.PREFIX] + prefix_toks + [self.enc.SUFFIX] + suffix_toks mask_context = [0] + [1] * len(prefix_toks) + [0] + [1] * len(suffix_toks) else: tokens_context = [self.enc.SUFFIX] + suffix_toks + [self.enc.PREFIX] + prefix_toks mask_context = [0] + [1] * len(suffix_toks) + [0] + [1] * len(prefix_toks) - middle_toks, _ = self.enc.encode_stochastic(middle, [], 0.01 * self.tkr_stochastic_tokens) + if hasattr(self.enc, 'encode_stochastic'): + middle_toks, _ = self.enc.encode_stochastic(middle, [], 0.01 * self.tkr_stochastic_tokens) + else: + middle_toks = self.enc.encode(middle) middle_mask = [1] * len(middle_toks) stats["fim_out"] += 1 if self.debug: diff --git a/refact_data_pipeline/filters_hdfs.py b/refact_data_pipeline/filters_hdfs.py index 24e34f4d..312d8aae 100644 --- a/refact_data_pipeline/filters_hdfs.py +++ b/refact_data_pipeline/filters_hdfs.py @@ -1,4 +1,3 @@ -import random from pathlib import Path from typing import Tuple, Optional, Any, List, Dict @@ -17,7 +16,6 @@ def _try_open(path: Path) -> Optional[Any]: return None - class Hdf5Dataset: """ A class that maps HDF5 files to flat array of data @@ -41,12 +39,9 @@ def __init__( self.files = files self.tables = [file.root.data for file in self.files] self.keys = dataopts.get("keys", "tokens;mask").split(';') - self.manual_seed = dataopts.get("hdfs_seed", None) + self.seed = dataopts.get("seed", 42) self.comm = comm self.cold_restart_skip = cold_restart_skip - if self.cold_restart_skip is not None: - assert self.manual_seed is not None, \ - "`cold_restart_skip` requires the manual seed, otherwise it doesn't make sence" self.tables_lengths = [len(t) for t in self.tables] self.tables_lengths_cumsum = np.cumsum(self.tables_lengths) self.overall_length = self.tables_lengths_cumsum[-1] @@ -58,14 +53,7 @@ def __del__(self): file.close() def __reshuffle(self) -> np.ndarray: - if self.manual_seed is None: - seed = random.randint(0, 2 ** 32 - 1) - if self.comm is not None: - seed = self.comm.bcast(seed, root=0) - else: - seed = self.manual_seed - - rng = np.random.default_rng(seed) + rng = np.random.default_rng(self.seed) index = rng.choice(self.overall_length, self.overall_length, replace=False) if self.comm is not None: diff --git a/refact_data_pipeline/filters_packing.py b/refact_data_pipeline/filters_packing.py index a67fc0e8..7009e3fb 100644 --- a/refact_data_pipeline/filters_packing.py +++ b/refact_data_pipeline/filters_packing.py @@ -143,6 +143,8 @@ def __init__( self.inner_filter_iter = iter(inner_filter) self.enc = dataopts.encoding self.n_ctx: int = dataopts['n_ctx'] + self.random = random.Random(dataopts.get('seed', 42)) + self.np_random = np.random.RandomState(dataopts.get('seed', 42)) self.pack_single: bool = dataopts.get('pack_single', 0) == 1 self.pack_complete: bool = dataopts.get('pack_complete', 1) == 1 self.drop_less_than_t: int = 
dataopts.get('pack_drop_less_than_t', 6) @@ -224,7 +226,7 @@ def _pop_item_by_length(length: int) -> ItemT: return [] if force_random_get or not self.pack_complete: - item = self.buffer.pop(random.randint(0, len(self.buffer) - 1)) + item = self.buffer.pop(self.random.randint(0, len(self.buffer) - 1)) return [item] else: lengths = [self.__item_len(i) for i in self.buffer] @@ -238,7 +240,7 @@ def _pop_item_by_length(length: int) -> ItemT: # prioritize items with larger lengths p = softmax(np.exp(np.array([sum(b) for b in bins]) / budget * 2)) - bin = bins[np.random.choice(list(range(len(bins))), p=p)] + bin = bins[self.np_random.choice(list(range(len(bins))), p=p)] items = [_pop_item_by_length(l) for l in bin] return items @@ -250,7 +252,7 @@ def __merge_items( assert len(items_acc) > 0 if random_order: - np.random.shuffle(items_acc) + self.random.shuffle(items_acc) last_item = items_acc[-1] if self.__items_len(items_acc) < self.n_ctx: items_acc.append(self.__make_padded_item(self.n_ctx - self.__items_len(items_acc))) diff --git a/refact_data_pipeline/finetune/finetune_filter.py b/refact_data_pipeline/finetune/finetune_filter.py deleted file mode 100644 index 55dc5b62..00000000 --- a/refact_data_pipeline/finetune/finetune_filter.py +++ /dev/null @@ -1,322 +0,0 @@ -import math -import os -import time -import json -import random - -import jsonlines -import textwrap -import sys -import signal -import logging -import torch as th - -from functools import partial -from torchinfo import summary - -import refact_data_pipeline.finetune.traces as traces -from refact_data_pipeline import DatasetOpts, finetune_datasource -from refact_data_pipeline.datautils import BatchIterator -from refact_data_pipeline.finetune.finetune_utils import get_finetune_config -from refact_data_pipeline.finetune.finetune_utils import get_finetune_filter_stat -from refact_data_pipeline.finetune.finetune_filtering_defaults import finetune_filtering_defaults -from refact_data_pipeline.finetune.finetune_config import base_config -from refact_data_pipeline.finetune.model_handling import make_model, masked_loss, model_forward -from refact_data_pipeline.finetune.process_uploaded_files import make_matcher -from self_hosting_machinery import env - -from typing import List, Dict, Any - - -unfiltered_train = os.path.join(env.DIR_UNPACKED, "train_set.jsonl") -unfiltered_test = os.path.join(env.DIR_UNPACKED, "test_set.jsonl") - -filtered_train = os.path.join(env.DIR_UNPACKED, "train_set_filtered.jsonl") -filtered_test = os.path.join(env.DIR_UNPACKED, "test_set_filtered.jsonl") - - -def _update_and_dump_status(stats_dict: Dict[str, Any], new_status): - traces.touch() - env.report_status("filter", new_status) - stats_dict["filterting_status"] = new_status - with open(env.CONFIG_FINETUNE_FILTER_STAT + ".tmp", "w") as f: - json.dump(stats_dict, f, indent=4) - os.rename(env.CONFIG_FINETUNE_FILTER_STAT + ".tmp", env.CONFIG_FINETUNE_FILTER_STAT) - return stats_dict - - -def _file_accepted(reason, path): - with open(env.LOG_FILES_ACCEPTED_FTF, "a", encoding="utf-8") as f: - f.write("%s %s\n" % (reason, path)) - - -def _file_rejected(reason, path): - with open(env.LOG_FILES_REJECTED_FTF, "a", encoding="utf-8") as f: - f.write("%s %s\n" % (reason, path)) - - -def get_force_included_excluded_matchers(): - fcfg = { - "filetypes_finetune": {}, - "filetypes_db": {} - } - if os.path.exists(env.CONFIG_HOW_TO_FILETYPES): - traces.log("Reading %s" % env.CONFIG_HOW_TO_FILETYPES) - with open(env.CONFIG_HOW_TO_FILETYPES, "r") as f: - 
fcfg.update(**json.load(f)) - - force_include_matcher, _ = make_matcher(fcfg.get('force_include', '')) - force_exclude_matcher, _ = make_matcher(fcfg.get('force_exclude', '')) - - return force_include_matcher, force_exclude_matcher - - -@th.inference_mode() -def loss_based_filter( - train_files: List, - model, - loss_function, - dataopts, - *, - fcfg, - stats_dict, - cfg, -): - t0 = time.time() - iter_times = [] - model.eval() - batch_iter_fn = partial(BatchIterator, dataopts=dict(batch_size=1, drop_last=False)) - all_losses, rejected = [], set() - stats_dict['total_steps'] = len(train_files) - is_force_included, is_force_excluded = get_force_included_excluded_matchers() - forward = partial(model_forward, model=model, low_gpu_mem_mode=False, backend=cfg['model_info']['backend']) - for iter_n, file in enumerate(train_files): - t0_iter = time.time() - stats_dict = _update_and_dump_status(stats_dict, "working") - file_losses = [] - if is_force_included(file['path']): - _file_accepted("FILTER1 INCLUDED_BY_MASK", file["path"]) - stats_dict["accepted"] += 1 - continue - elif is_force_excluded(file['path']): - traces.log("REJECTED FILTER %-100s EXCLUDED_BY_MASK" % file["path"]) - rejected.add(file["path"]) - _file_rejected("FILTER1 EXCLUDED_BY_MASK", file["path"]) - stats_dict["rejected"] += 1 - continue - - file_iter = iter(batch_iter_fn(finetune_datasource.local_plain([file], dataopts))) - exception = None - while True: - try: - batch, stats = next(file_iter) - except StopIteration: - break - except Exception as e: - exception = e - break - logits = forward(input=batch['input']) - loss = float(loss_function( - logits=logits.to(th.float32), - labels=batch['labels'], - mask=batch['mask'], - ).item()) - if math.isnan(loss) or math.isinf(loss): - traces.log(f"Skipping invalid loss={loss:.2f} value in file {file['path']}") - else: - file_losses.append(loss) - - if exception is not None: - traces.log("REJECTED FILTER %-100s %s" % (file["path"], str(exception))) - rejected.add(file["path"]) - _file_rejected(f"FILTER1 {exception}", file["path"]) - stats_dict["rejected"] += 1 - continue - - if len(file_losses) == 0: - traces.log("REJECTED FILTER %-100s empty" % file["path"]) - rejected.add(file["path"]) - _file_rejected("FILTER1 EMPTY", file["path"]) - stats_dict["rejected"] += 1 - continue - - file_loss = sum(file_losses) / len(file_losses) - - if file_loss > fcfg['filter_loss_threshold']: - traces.log("REJECTED FILTER %-100s loss %0.3f" % (file["path"], file_loss)) - rejected.add(file["path"]) - _file_rejected("FILTER1 %0.3f" % file_loss, file["path"]) - stats_dict["rejected"] += 1 - else: - _file_accepted("LOSS %0.3f" % file_loss, file["path"]) - stats_dict["accepted"] += 1 - all_losses.append(file_loss) - stats_dict['avg_loss'] = sum(all_losses) / len(all_losses) - - iter_times.append(time.time() - t0_iter) - eta = (len(train_files) - iter_n) * (sum(iter_times) / len(iter_times)) - stats_dict["eta_minutes"] = int(round(eta / 60)) - stats_dict["worked_steps"] = iter_n - stats_dict["worked_minutes"] = int((time.time() - t0) / 60) - - traces.log("calculated frames %i " % len(train_files)) - traces.log("avg loss %0.4f" % stats_dict['avg_loss']) - - return rejected - - -def pre_filtering(stats_dict, models_db: Dict[str, Any]): - finetune_cfg = get_finetune_config(models_db, logger=traces.log) - - fcfg = {**finetune_filtering_defaults} - if os.path.exists(env.CONFIG_HOW_TO_FILTER): - traces.log("Reading %s" % env.CONFIG_HOW_TO_FILTER) - fcfg.update(**json.load(open(env.CONFIG_HOW_TO_FILTER))) - - 
has_train_files = os.path.exists(os.path.join(env.DIR_UNPACKED, unfiltered_train)) and \ - len(list(jsonlines.open(os.path.join(env.DIR_UNPACKED, unfiltered_train)))) - if not has_train_files: - raise RuntimeError("No train files have been provided for filtering") - - logging.info("Train set filtering, loading model...") - traces.log("Train set filtering, loading model...") - t0 = time.time() - cfg = base_config(finetune_cfg["model_name"], models_db) - model = make_model( - model_name=finetune_cfg["model_name"], - weights_path=cfg['model_info']['weight_path'], - repo_id=cfg['model_info']['repo_id'], - backend=cfg['model_info']['backend'], - freeze_exceptions=cfg['model_info']['freeze_exceptions'], - lora_target_modules=cfg['model_info']['lora']['lora_target_modules'], - lora_r=cfg['model_info']['lora']['lora_r'], - lora_alpha=cfg['model_info']['lora']['lora_alpha'], - lora_dropout=0, - lora_init_scale=1e-5, - dtype=th.bfloat16 if 'bf16' in cfg and cfg['bf16']['enabled'] else th.float16, - init_device="cuda", - device="cuda", - ) - t1 = time.time() - logging.info("/model load %0.1fms" % ((t1 - t0) * 1000)) - model.train() - - if fcfg["debug"]: - logging.info("1 gpumem_p0 %0.2fG" % (th.cuda.max_memory_allocated() / 1e9)) - summary(model, depth=4, col_names=['num_params', 'params_percent', 'trainable']) - - dataopts = DatasetOpts("n_ctx=%d,pack_at_most=1,quit_on_epoch=1,seed=42" % (cfg['model_info']['ctx_size'] + 1)) - dataopts.set_encoding(model.encoding) - train_files = list(jsonlines.open(unfiltered_train)) - train_files = train_files[:fcfg["limit_train_files"]] - loss_function = partial( - masked_loss, average_elements=cfg['model_info']['loss_average_elements'], enc=model.encoding - ) - - test_files = list(jsonlines.open(unfiltered_test)) - if len(test_files) > fcfg["limit_test_files"]: - traces.log(f"Manually selected test set contains {len(test_files)} files, " - f"more than allowed {fcfg['limit_test_files']}.\n" - f"It could heavily slow down the training process") - - text = "FILTER explanation: initial loss too big calculated on a single file, threshold is %0.3f. " \ - "Likely means the file doesn't contain code." % fcfg["filter_loss_threshold"] - traces.log(textwrap.fill(text, width=100)) - - filtered = loss_based_filter( - train_files, model, loss_function, dataopts, fcfg=fcfg, stats_dict=stats_dict, - cfg=cfg - ) - - test_filenames = set() - if len(test_files) == 0: - test_files_count = min(fcfg["limit_test_files"], len(train_files) // 2) - if test_files_count == 0: - traces.log("Warning: It is too little files to choose a test set from. 
" - "It's strongly recommended to choose a test set manually to be able to prevent overfitting") - else: - test_files = random.choices(train_files, k=fcfg["limit_test_files"]) - test_filenames.update([p['path'] for p in test_files]) - - with open(filtered_train, "w") as f: - for fdict in train_files: - p = fdict["path"] - rejected_by_filters = p in filtered - included_in_test_set = p in test_filenames - if rejected_by_filters or included_in_test_set: - continue - f.write(json.dumps(fdict) + "\n") - - traces.log("-" * 40 + "TEST SET" + "-" * 40) - with open(filtered_test, "w") as f: - for fdict in test_files: - traces.log("test set file: %s" % (fdict["path"])) - f.write(json.dumps(fdict) + "\n") - - -def needs_any_work(): - try: - has_updates = [os.path.getmtime(unfiltered_train) > os.path.getmtime(filtered_train), - os.path.getmtime(unfiltered_test) > os.path.getmtime(filtered_test)] - if os.path.exists(env.CONFIG_HOW_TO_FILTER): - has_updates.append(os.path.getmtime(env.CONFIG_HOW_TO_FILTER) > os.path.getmtime(filtered_train)) - if os.path.exists(env.CONFIG_HOW_TO_FILETYPES): - has_updates.append(os.path.getmtime(env.CONFIG_HOW_TO_FILETYPES) > os.path.getmtime(filtered_train)) - except OSError: - return True - return any(has_updates) - - -def main(models_db: Dict[str, Any]): - stats_dict = get_finetune_filter_stat(default=True) - - def catch_sigusr1(signum, frame): - logging.info("catched SIGUSR1, interrupted") - stats_dict["error"] = "interrupted" - _update_and_dump_status(stats_dict, "interrupted") # saves whatever numbers reached so far - exit(99) - - signal.signal(signal.SIGUSR1, catch_sigusr1) - - if not needs_any_work(): - logging.info("Train set filtering: nothing changed since last time, quit") - return - - stats_dict = _update_and_dump_status(stats_dict, "starting") # writes zeros - with open(env.LOG_FILES_ACCEPTED_FTF, "w") as f: - f.write("") - with open(env.LOG_FILES_REJECTED_FTF, "w") as f: - f.write("") - try: - pre_filtering(stats_dict, models_db) - _update_and_dump_status(stats_dict, "finished") - except SystemExit: - # catched sigusr1, interrupt by watchdog - exit(99) # this has to be there, even if catch_sigusr1() already called exit with 99, otherwise exit code is zero - except KeyboardInterrupt: - # interrupt by user - # 99 is code for interrupted - exit(99) - except Exception as e: - if traces.context(): - logging.error("FAILED finetune filter at %s" % traces.context().path) - if "error" not in stats_dict: # if there is, a more detailed error is already in place - t = str(e) or str(type(e)) - stats_dict["error"] = t - logging.error("FAILED: %s" % t) - traces.log("FAILED: %s" % t) - _update_and_dump_status(stats_dict, "failed") - exit(1) - if isinstance(e, ValueError): # don't print stack for ValueError which is used for mundane data problems - exit(1) - raise e - # finetune_sequence relies on exit code to continue or stop - - -if __name__ == "__main__": - from known_models_db.refact_known_models import models_mini_db - - YMD_hms = os.environ.get("LORA_LOGDIR", "") or time.strftime("lora-%Y%m%d-%H%M%S") - traces.configure(task_dir="loras", task_name=YMD_hms, work_dir=env.PERMDIR) - - main(models_mini_db) diff --git a/refact_data_pipeline/finetune/finetune_filtering_defaults.py b/refact_data_pipeline/finetune/finetune_filtering_defaults.py deleted file mode 100644 index 52544bfd..00000000 --- a/refact_data_pipeline/finetune/finetune_filtering_defaults.py +++ /dev/null @@ -1,8 +0,0 @@ -finetune_filtering_defaults = { - "limit_train_files": 1000000, - "limit_test_files": 
5, - "filter_loss_threshold": 3.5, - "filter_gradcosine_threshold": 0.1, - "low_gpu_mem_mode": True, - "debug": False -} diff --git a/refact_data_pipeline/finetune/finetune_train.py b/refact_data_pipeline/finetune/finetune_train.py deleted file mode 100644 index a71fa48b..00000000 --- a/refact_data_pipeline/finetune/finetune_train.py +++ /dev/null @@ -1,397 +0,0 @@ -import os -import time -import json -import subprocess -import sys -import signal - -import deepspeed -import logging -import torch as th - -from functools import partial -from pathlib import Path -from jsonlines import jsonlines -from torchinfo import summary - -from refact_data_pipeline.finetune import traces, supported_models -from refact_data_pipeline import DatasetOpts, finetune_datasource -from refact_data_pipeline.datautils import BatchIterator -from refact_data_pipeline.finetune.finetune_config import base_config, ConfigBuilder -from refact_data_pipeline.finetune.finetune_utils import get_finetune_config -from refact_data_pipeline.finetune.model_handling import make_model, masked_loss, save_model_state, model_forward, \ - setup_encoding -from self_hosting_machinery import env - -from typing import Optional, Callable, Dict, Any, Tuple - - -filtered_train = "train_set_filtered.jsonl" -filtered_test = "test_set_filtered.jsonl" - - -class EarlyStopper: - def __init__(self, patience=1, min_delta=0): - self.patience = patience - self.min_delta = min_delta - self.counter = 0 - self.min_validation_loss = float('inf') - - def __call__(self, validation_loss): - if validation_loss < self.min_validation_loss: - self.min_validation_loss = validation_loss - self.counter = 0 - elif validation_loss > (self.min_validation_loss + self.min_delta): - self.counter += 1 - if self.counter >= self.patience: - return True - return False - - -def save_status_json(status_dict, status_string): - # FIXME: rank == 0 - rank = 0 - if rank != 0: - return - traces.touch() - env.report_status("ftune", status_string) - status_dict["status"] = status_string - if not traces.context(): - return - try: - with open(os.path.join(traces.context().path, "status.json.tmp"), "w") as f: - json.dump(status_dict, f, indent=4) - os.rename(os.path.join(traces.context().path, "status.json.tmp"), - os.path.join(traces.context().path, "status.json")) - except Exception as e: - traces.log("ERROR SAVING STATS: %s" % (e or str(type(e)))) - traces.log("(no big deal, will try again next iteration)") - - -def load_finetune_config(models_db: Dict[str, Any]) -> Dict[str, Any]: - def _get_ds_len_per_epoch(model_name, cfg_builder): - model_config = supported_models.config[model_name] - ds_opts = DatasetOpts(model_config["train_ds_pipeline"]["ds_opts"].format( - n_ctx=cfg_builder.cfg['model_info']['ctx_size'] + 1 - ) + ",quit_on_epoch=1") - ds_opts.set_encoding(setup_encoding( - model_name=model_name, - weights_path=cfg_builder.cfg['model_info']['weight_path'], - repo_id=cfg_builder.cfg['model_info']['repo_id'] - )) - pipe = getattr(finetune_datasource, model_config["train_ds_pipeline"]["pipeline_name"]) - ds = pipe(filtered_train, ds_opts) - ds_len = 0 - try: - for _ in ds: - ds_len += 1 - return ds_len - except Exception as e: - return ds_len - - with open(env.CONFIG_FINETUNE_FILTER_STAT, 'r') as f: - initial_loss = json.load(f)["avg_loss"] - - user_cfg = get_finetune_config(models_db, logger=traces.log) - cfg_builder = ConfigBuilder(base_config(user_cfg['model_name'], models_db)) - if user_cfg['use_heuristics']: - traces.log("Retrieving dataset length per epoch, it may take a 
while...") - ds_len = _get_ds_len_per_epoch(user_cfg['model_name'], cfg_builder) - traces.log(f"Dataset length per epoch = {ds_len}") - (cfg_builder - .set_lora_quality_by_heuristics(ds_len=ds_len, initial_loss=initial_loss) - .set_schedule_by_heuristics(ds_len=ds_len) - .set_low_gpu_mem_mode_by_heuristics()) - else: - (cfg_builder - .set_train_steps(user_cfg['train_steps']) - .set_lr_decay_steps(user_cfg['lr_decay_steps']) - .set_lora_r(user_cfg['lora_r']) - .set_lora_alpha(user_cfg['lora_alpha']) - .set_lora_init_scale(user_cfg['lora_init_scale']) - .set_lora_dropout(user_cfg['lora_dropout']) - .set_low_gpu_mem_mode(user_cfg['low_gpu_mem_mode'])) - (cfg_builder - .set_lr(user_cfg['lr']) - .set_batch_size(user_cfg['batch_size']) - .set_warmup_steps(user_cfg['warmup_num_steps']) - .set_limit_time_seconds(user_cfg['limit_time_seconds']) - .set_weight_decay(user_cfg['weight_decay'])) - - traces.log(f'Freeze exceptions: {cfg_builder.cfg["model_info"]["freeze_exceptions"]}') - for k, v in cfg_builder.cfg["model_info"]["lora"].items(): - traces.log(f'Lora config: {k:>20} {v}') - - with open(os.path.join(traces.context().path, "config.json"), "w") as f: - json.dump(cfg_builder.cfg, f, indent=4) - - return cfg_builder.cfg - - -def create_data(model_name, cfg, enc) -> Tuple[Any, Optional[Any]]: - model_config = supported_models.config[model_name] - train_dataopts = DatasetOpts(model_config["train_ds_pipeline"]["ds_opts"].format( - n_ctx=cfg['model_info']['ctx_size'] + 1 - )) - train_dataopts.set_encoding(enc) - test_dataopts = DatasetOpts(model_config["test_ds_pipeline"]["ds_opts"].format( - n_ctx=cfg['model_info']['ctx_size'] + 1 - )) - test_dataopts.set_encoding(enc) - - train_pipe = getattr(finetune_datasource, model_config["train_ds_pipeline"]["pipeline_name"]) - test_pipe = getattr(finetune_datasource, model_config["test_ds_pipeline"]["pipeline_name"]) - - train_ds = train_pipe(filtered_train, train_dataopts) - train_ds = BatchIterator(train_ds, dataopts=dict( - batch_size=cfg['train_batch_size'], - drop_last=True - )) - has_train_files = os.path.exists(os.path.join(env.DIR_UNPACKED, filtered_train)) and \ - len(list(jsonlines.open(os.path.join(env.DIR_UNPACKED, filtered_train)))) > 0 - if not has_train_files: - raise RuntimeError("No train files provided") - - has_test_files = os.path.exists(os.path.join(env.DIR_UNPACKED, filtered_test)) \ - and len(list(jsonlines.open(os.path.join(env.DIR_UNPACKED, filtered_test)))) > 0 - if has_test_files: - test_ds = test_pipe(filtered_test, test_dataopts) - test_ds = list(test_ds) - else: - traces.log("Warning: no test set provided, the number of files is zero") - test_ds = None - return train_ds, test_ds - - -def loop( - cfg, - model, - optimizer, - loss_function: Callable, - model_name: str, - *, - status_dict, - train_ds, - test_ds: Optional[Any] -): - def _save_checkpoint(force: bool = False): - if force or (iter_n != 0 and iter_n % cfg['save_every'] == 0): - if "test_loss" in progress: - tag = "iter%04d-testloss%0.3f" % (iter_n, progress["test_loss"]) - else: - tag = "iter%04d-trainloss%0.3f" % (iter_n, progress["loss"]) - traces.log("saving checkpoint %s" % tag) - save_model_state(model, save_path=save_path, tag=tag) - - model_config = supported_models.config[model_name] - save_path = os.path.join(traces.context().path, "checkpoints") - model.train() - test_ds_fn = partial(BatchIterator, dataopts=dict( - batch_size=1, - drop_last=False - )) - micro_bs = cfg['micro_batch_size'] - backend = cfg['model_info']['backend'] - tokens_n = 0 - iter_time_last 
= None - t0 = time.time() - # Each checkpoint must be tested: - assert cfg['train_iters'] % cfg['test_every'] == 0 - assert cfg['save_every'] % cfg['test_every'] == 0 - plot_process: Optional[subprocess.Popen] = None - save_status_json(status_dict, "working") - low_gpu_mem_mode = cfg['low_gpu_mem_mode'] or model_config['force_enable_checkpointing'] - forward = partial(model_forward, model=model, backend=backend) - early_stop = EarlyStopper(patience=int(cfg['train_iters'] * 0.2)) - for iter_n in range(cfg['train_iters'] + 1): # +1 so we can save 100 (not 99) - t0_iter = time.time() - traces.progress("iteration", iter_n) - data = next(train_ds, None) - if data is None: - break - batch, ds_stats = data - - if cfg['debug']: - data_path = Path(traces.context().path) / ('debug_data/iter%04d' % iter_n) - data_path.mkdir(exist_ok=True, parents=True) - traces.log( - f"iter {iter_n}/{cfg['train_iters']} tokens {tokens_n / 1e9:0.3f} " - f"input={traces.p(batch['input'])} mask={traces.p(batch['mask'])} " - f"({batch['mask'].sum()}/{batch['mask'].numel()})" - ) - - for b0 in range(0, cfg.get("train_batch_size"), cfg.get("micro_batch_size")): - try: - input = batch['input'][b0:b0 + micro_bs].contiguous() - logits = forward(input=input, low_gpu_mem_mode=low_gpu_mem_mode) - loss = loss_function( - logits=logits, - labels=batch['labels'][b0:b0 + micro_bs].contiguous(), - mask=batch['mask'][b0:b0 + micro_bs].contiguous(), - ) - model.backward(loss) - except th.cuda.OutOfMemoryError as e: - if low_gpu_mem_mode: - raise e - else: - model.optimizer.zero_grad() - th.cuda.empty_cache() - low_gpu_mem_mode = True - traces.log("switching to low GPU memory mode") - continue - - model.step() - tokens_n += input.shape[0] * input.shape[1] - traces.progress('loss', loss) - - if cfg['debug']: - with open(data_path / ('%d_%0.3f.txt' % (b0, loss.item())), 'w') as f: - f.write(model.encoding.decode(input[0].cpu().numpy())) - - if test_ds is not None and cfg["test_every"] > 0 and iter_n % cfg["test_every"] == 0: - model.eval() - with th.inference_mode(): - test_losses = [] - for batch, _ in test_ds_fn(test_ds): - logits = forward(input=batch['input'], low_gpu_mem_mode=low_gpu_mem_mode) - test_loss = loss_function( - logits=logits, - labels=batch['labels'], - mask=batch['mask'], - ) - traces.progress('test_loss', test_loss) - test_losses.append(test_loss) - if len(test_losses) > 0 and early_stop(sum(test_losses) / len(test_losses)): - traces.log(f"Stopping the training due to " - f"test loss was above minimum {early_stop.counter} times") - _save_checkpoint(force=True) - break - model.train() - - for k, v in ds_stats.items(): - traces.progress(f'ds/{k}', v) - traces.progress("gtokens", tokens_n / 1e9) - traces.progress("lr", optimizer.param_groups[-1]['lr']) - traces.progress("gpumem_p0", th.cuda.max_memory_allocated()) - traces.progress("num_skipped_updates", model.skipped_steps) - traces.progress("scale", model.optimizer.cur_scale) - traces.progress("tokens_num", tokens_n) - traces.progress("time_elapsed", time.time() - t0) - iter_time = time.time() - t0_iter - if iter_time_last is None: - eta = (cfg['train_iters'] + 1 - iter_n) * iter_time - else: - eta = (cfg['train_iters'] + 1 - iter_n) * ((iter_time + iter_time_last) / 2) - traces.progress("eta_minutes", int(round(eta / 60))) - iter_time_last = iter_time - progress = traces.progress_dump(step=iter_n) - - if plot_process is not None: - plot_process.communicate() - plot_process = subprocess.Popen([ - sys.executable, - os.path.join(os.path.dirname(__file__), 
"traces_plot.py"), - "progress.jsonl", - "%d" % (cfg['train_iters'] + 50), - ], cwd=traces.context().path) - _save_checkpoint(force=False) - status_dict["worked_steps"] = iter_n - status_dict["worked_minutes"] = int((time.time() - t0) / 60) - status_dict["eta_minutes"] = int(round(eta / 60)) - save_status_json(status_dict, "working") - if "test_loss" in progress: - logging.info("finished iteration %d, train_loss=%0.3f, test_loss=%0.3f" - % (iter_n, progress["loss"], progress["test_loss"])) - else: - logging.info("finished iteration %d, train_loss=%0.3f" % (iter_n, progress["loss"])) - - -def finetune(status_dict, models_db: Dict[str, Any]): - logging.info("starting finetune at %s" % traces.context().path) - cfg = load_finetune_config(models_db) - traces.log("creating model...") - t0 = time.time() - model = make_model( - model_name=cfg['model_name'], - weights_path=cfg['model_info']['weight_path'], - repo_id=cfg['model_info']['repo_id'], - backend=cfg['model_info']['backend'], - freeze_exceptions=cfg['model_info']['freeze_exceptions'], - lora_target_modules=cfg['model_info']['lora']['lora_target_modules'], - lora_r=cfg['model_info']['lora']['lora_r'], - lora_alpha=cfg['model_info']['lora']['lora_alpha'], - lora_dropout=cfg['model_info']['lora']['lora_dropout'], - lora_init_scale=cfg['model_info']['lora']['lora_init_scale'], - dtype=th.bfloat16 if 'bf16' in cfg and cfg['bf16']['enabled'] else th.float16, - init_device="cuda", - device="cuda", - ) - t1 = time.time() - traces.log("/model %0.1fms" % ((t1 - t0) * 1000)) - if cfg['debug']: - summary(model, depth=4, col_names=['num_params', 'params_percent', 'trainable']) - model, optimizer, _, _ = deepspeed.initialize( - config=cfg, - model=model, - model_parameters=[p for p in model.parameters() if p.requires_grad], - dist_init_required=True - ) - train_ds, test_ds = create_data(cfg['model_name'], cfg, model.encoding) - loop( - cfg=cfg, - model=model, - optimizer=optimizer, - loss_function=partial( - masked_loss, average_elements=cfg['model_info']['loss_average_elements'], - enc=model.encoding - ), - model_name=cfg['model_name'], - train_ds=train_ds, - test_ds=test_ds, - status_dict=status_dict - ) - logging.info("finished finetune at %s" % traces.context().path) - - -def main(models_db: Dict[str, Any]): - status_dict = { - "started_ts": time.time(), - "worked_steps": 0, - "worked_minutes": 0, - "status": "starting", - "quality": "unknown" - } - save_status_json(status_dict, "working") - - def catch_sigusr1(signum, frame): - logging.info("catched SIGUSR1, interrupted") - traces.log("Interrupted") - status_dict["error"] = "interrupted" - save_status_json(status_dict, "interrupted") - exit(99) - - signal.signal(signal.SIGUSR1, catch_sigusr1) - try: - finetune(status_dict, models_db) - save_status_json(status_dict, "finished") - except SystemExit: - # catched sigusr1, interrupt by watchdog - exit(99) # this has to be there, even if catch_sigusr1() already called exit with 99, otherwise exit code is zero - except BaseException as e: # BaseException includes KeyboardInterrupt - if "error" not in status_dict: # if there is, a more detailed error is already in place - t = str(e) or str(type(e)) - status_dict["error"] = t - logging.error("FAILED: %s" % t) - traces.log("FAILED: %s" % t) - save_status_json(status_dict, "failed") - logging.error("FAILED finetune at %s" % traces.context().path) - logging.error("Error was: %s" % status_dict["error"]) - raise e - - -if __name__ == "__main__": - from known_models_db.refact_known_models import models_mini_db - 
- YMD_hms = os.environ.get("LORA_LOGDIR", "") or time.strftime("lora-%Y%m%d-%H%M%S") - traces.configure(task_dir="loras", task_name=YMD_hms, work_dir=env.PERMDIR) - main(models_mini_db) diff --git a/refact_data_pipeline/finetune/model_handling.py b/refact_data_pipeline/finetune/model_handling.py deleted file mode 100644 index 651007ab..00000000 --- a/refact_data_pipeline/finetune/model_handling.py +++ /dev/null @@ -1,262 +0,0 @@ -import importlib -from collections import deque -from functools import partial -from pathlib import Path - -from transformers import AutoModelForCausalLM, AutoTokenizer - -from refact_data_pipeline.finetune import supported_models -from refact_encoding import RefactEncoding - -import torch as th -import torch.nn.functional as F - -from refact_models.codify_model import CodifyModel -from refact_models.checkpoint_loader import load_config - -from typing import List, Tuple, Optional - -from refact_models.lora import LoraMixin - -unmasked_avg_buf = None - - -def masked_loss( - logits: th.Tensor, - labels: th.Tensor, - mask: Optional[th.Tensor] = None, - *, - average_elements: int, - enc: RefactEncoding, - debug_dump: Optional[List[str]] = None -) -> th.Tensor: - def _average(one_d_tensor: th.Tensor) -> th.Tensor: - global unmasked_avg_buf - if unmasked_avg_buf is None: - unmasked_avg_buf = deque(maxlen=average_elements) - if th.is_grad_enabled(): - for x in one_d_tensor: - unmasked_avg_buf.append(float(x)) - return sum(unmasked_avg_buf) / len(unmasked_avg_buf) - else: - return one_d_tensor.to(th.float32).mean().item() - - _mb, _T = labels.shape - mb, T, U = logits.shape - assert _T == T - assert _mb == mb - - ce = F.cross_entropy( - logits.reshape(mb * T, U), - labels.reshape(mb * T), - reduction="none" - ).reshape(mb, T) - avg_mask_sum = _average(mask.sum(dim=1)) - loss_ce = ((ce * mask).sum(dim=1) / avg_mask_sum).mean() - - if debug_dump is not None: - import termcolor - def token_str(x, cond, color): - t = "\"" + enc.decode([x]).replace("\n", "\\n") + "\"" - if cond: - return termcolor.colored(t, color) - else: - return t - - with th.no_grad(): - b = 0 - for ti in range(T): - if b == -1: - continue - if ti & 15 == 0: - debug_dump.append("-----") - largest_logit_n = logits[b, ti].argmax().item() - debug_dump.append(" ".join([ - "%04i" % (ti,), - "ce=%5.2f" % ce[b, ti].item(), - "label=%-20s" % token_str(labels[b, ti].item(), mask[b, ti].item(), "green"), - "mask=%i" % mask[b, ti].item(), - "largest_logit=%05i" % largest_logit_n, - "modelthinks=%-10s" % token_str(largest_logit_n, - (mask[b, ti].item() and labels[b, ti].item() != largest_logit_n), - "red"), - ])) - debug_dump.append("-- (ce * mask).sum(dim=1) = %s" % (ce * mask).sum(dim=1)) - debug_dump.append("-- avg_mask_sum = %s" % avg_mask_sum) - debug_dump.append("-- this example loss_ce = %5.3f" % loss_ce.item()) - - return loss_ce - - -def freeze_model( - model: th.nn.Module, - freeze_exceptions: List[str] -) -> th.nn.Module: - for name, p in model.named_parameters(): - if any([e in name for e in freeze_exceptions]): - p.requires_grad_(True) - else: - p.requires_grad_(False) - return model - - -def save_model_state(model, save_path, tag): - keys_white_list = { - 'module', 'buffer_names', 'optimizer', 'param_shapes', 'frozen_param_shapes', - 'lr_scheduler', 'data_sampler', 'random_ltd', 'sparse_tensor_module_names', - 'skipped_steps', 'global_steps', 'global_samples', 'dp_world_size', 'mp_world_size', - 'ds_config', 'ds_version' - } - - model.save_checkpoint(save_path, tag=tag) - cp_path = Path(save_path) / tag 
- model_cps = [p for p in cp_path.iterdir() if 'model_states' in p.name] - _ = [p.unlink() for p in cp_path.iterdir() if 'model_states' not in p.name] - for cp_path in model_cps: - cp = th.load(str(cp_path), map_location='cpu') - cp = {k: v for k, v in cp.items() if k in keys_white_list} - th.save(cp, str(cp_path)) - - -def setup_encoding( - model_name: str, - weights_path: str, - repo_id: str -): - model_config = supported_models.config[model_name] - if "tokenizer" in model_config: - encoding = AutoTokenizer.from_pretrained( - repo_id, cache_dir=weights_path, - trust_remote_code=True - ) - encoding.encode_stochastic = lambda x, *args, **kwargs: (encoding.encode(x), None) - encoding.decode_utf8 = lambda x, *args, **kwargs: encoding.decode(x) - else: - encoding = RefactEncoding( - load_config(root_path=weights_path, repo_id=repo_id).enc_name - ) - encoding.EOT = model_config["tokenizer"]["eot_idx"] - encoding.DIAMOND = model_config["tokenizer"]["padding_idx"] - encoding.PREFIX = model_config["tokenizer"]["fim_prefix"] - encoding.INFILL = model_config["tokenizer"]["fim_middle"] - encoding.SUFFIX = model_config["tokenizer"]["fim_suffix"] - encoding.ESCAPE = model_config["tokenizer"]["escape"] - return encoding - - -def model_forward( - model: th.nn.Module, - input: th.Tensor, - low_gpu_mem_mode: bool, - backend: str -) -> th.Tensor: - if backend == "transformers": - if low_gpu_mem_mode: - model.gradient_checkpointing_enable() - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - else: - model.gradient_checkpointing_disable() - logits = model.forward( - input, - return_dict=False, output_attentions=False, output_hidden_states=False - )[0] - else: - if low_gpu_mem_mode: - logits = model.forward_train_cp(input) - else: - logits = model.lm_forward(model(input, attention_mask=None)[0]) - return logits - - -def _lora_state_dict(model, *args, destination=None, prefix='', keep_vars=False, layer_names): - return { - name: p - for name, p in model.old_state_dict( - *args, destination=destination, prefix=prefix, keep_vars=keep_vars - ).items() - if any(n in name for n in layer_names) - } - - -def setup_model_specific_params( - model_name: str, - freeze_exceptions: List[str], - lora_target_modules: List[str] -) -> Tuple[List[str], List[str]]: - assert model_name in supported_models.config - model_config = supported_models.config[model_name] - freeze_exceptions = [model_config["freeze_exceptions_mapping"][e] for e in freeze_exceptions] - lora_target_modules_mapping = [m for modules in lora_target_modules - for m in model_config["lora_target_modules_mapping"][modules]] - return list(set(freeze_exceptions)), list(set(lora_target_modules_mapping)) - - -def _apply_model_modifiers(model: th.nn.Module, modifiers: List[str]): - for modifier in modifiers: - path, modifier_name = modifier.rsplit('.', maxsplit=1) - mod_path = importlib.import_module(f"refact_data_pipeline.finetune.{path}") - mod = getattr(mod_path, modifier_name) - mod(model) - - -def make_model( - model_name: str, - weights_path: str, - repo_id: str, - backend: str, - *, - freeze_exceptions: List[str], - lora_target_modules: List[str], - lora_r: int, - lora_alpha: int, - lora_dropout: float, - lora_init_scale: float, - dtype: th.dtype, - init_device: str = "cpu", - device: str = "cuda", -) -> th.nn.Module: - # init_device CPU is to save memory - encoding = setup_encoding(model_name, weights_path, repo_id) - freeze_exceptions, 
lora_target_modules = setup_model_specific_params( - model_name, freeze_exceptions, lora_target_modules - ) - if backend == "legacy": - model = CodifyModel.from_pretrained( - weights_path, device=init_device, repo_id=repo_id - ).to(dtype) - _apply_model_modifiers(model, supported_models.config[model_name]['train_model_modifiers']) - elif backend == "transformers": - model = AutoModelForCausalLM.from_pretrained( - repo_id, cache_dir=weights_path, - device_map=init_device, torch_dtype=dtype, - trust_remote_code=True - ) - model.encoding = encoding - _apply_model_modifiers(model, supported_models.config[model_name]['train_model_modifiers']) - else: - raise ValueError("Unknown backend") - - LoraMixin.apply_lora( - model.to(device), - lora_target_modules=lora_target_modules, - lora_r=int(lora_r), - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - lora_init_scale=lora_init_scale - ) - model = freeze_model( - model, - freeze_exceptions=freeze_exceptions - ) - model.old_state_dict = model.state_dict - model.state_dict = partial( - _lora_state_dict.__get__(model, type(model)), - layer_names=freeze_exceptions - ) - model = model.to(dtype) - model = model.cuda() - return model diff --git a/refact_data_pipeline/finetune/supported_models.py b/refact_data_pipeline/finetune/supported_models.py deleted file mode 100644 index d8c3e131..00000000 --- a/refact_data_pipeline/finetune/supported_models.py +++ /dev/null @@ -1,155 +0,0 @@ -config = { - "Refact/1.6B": { - "lora_target_modules_mapping": { - "qkv": ["attn.q", "attn.kv"], - "out": ["attn.c_proj"], - "backproj": ["attn.c_proj"], - "mlp": ["mlp.gate_up_proj", "mlp.c_proj"], - }, - "freeze_exceptions_mapping": { - "wte": "wte", - "lm_head": "lm_head", - "lora": "lora" - }, - "tokenizer": { - "eot_idx": 0, - "padding_idx": 4, - "fim_prefix": 1, - "fim_middle": 2, - "fim_suffix": 3, - "escape": 14 - }, - "train_ds_pipeline": { - "ds_opts": "n_ctx={n_ctx},fim_probability=0.9,fim_drop_residual=1," - "tkr_stochastic_tokens=3,shuffle_depth=3000,debug=0," - "random_trim_context_prob=0.01,fim_random_seed=42", - "pipeline_name": "local_fim" - }, - "test_ds_pipeline": { - "ds_opts": "n_ctx={n_ctx},fim_probability=0.9,fim_drop_residual=1," - "tkr_stochastic_tokens=3,shuffle_depth=3000,debug=0," - "random_trim_context_prob=0.01,fim_random_seed=42," - "pack_single=1,pack_complete=0,pack_buffer_size=25," - "quit_on_epoch=1,seed=42", - "pipeline_name": "local_fim" - }, - "train_model_modifiers": [ - "sa.apply_flash_mha_to_refact_model" - ], - "force_enable_checkpointing": False - }, - - "starcoder/1b/base": { - "lora_target_modules_mapping": { - "qkv": ["attn.q_attn", "attn.c_attn"], - "out": ["attn.c_proj"], - "backproj": ["attn.c_proj"], - "mlp": ["mlp.c_fc", "mlp.c_proj"], - }, - "freeze_exceptions_mapping": { - "wte": "wte", - "lm_head": "lm_head", - "lora": "lora" - }, - "tokenizer": { - "eot_idx": 0, - "padding_idx": 4, - "fim_prefix": 1, - "fim_middle": 2, - "fim_suffix": 3, - "escape": 14 - }, - "train_ds_pipeline": { - "ds_opts": "n_ctx={n_ctx},fim_probability=0.9,fim_drop_residual=1," - "tkr_stochastic_tokens=3,shuffle_depth=3000,debug=0," - "random_trim_context_prob=0.01,fim_random_seed=42", - "pipeline_name": "local_fim" - }, - "test_ds_pipeline": { - "ds_opts": "n_ctx={n_ctx},fim_probability=0.9,fim_drop_residual=1," - "tkr_stochastic_tokens=3,shuffle_depth=3000,debug=0," - "random_trim_context_prob=0.01,fim_random_seed=42," - "pack_single=1,pack_complete=0,pack_buffer_size=25," - "quit_on_epoch=1,seed=42", - "pipeline_name": "local_fim" - }, - 
"train_model_modifiers": [], - "force_enable_checkpointing": True - }, - - "starcoder/3b/base": { - "lora_target_modules_mapping": { - "qkv": ["attn.q_attn", "attn.c_attn"], - "out": ["attn.c_proj"], - "backproj": ["attn.c_proj"], - "mlp": ["mlp.c_fc", "mlp.c_proj"], - }, - "freeze_exceptions_mapping": { - "wte": "wte", - "lm_head": "lm_head", - "lora": "lora" - }, - "tokenizer": { - "eot_idx": 0, - "padding_idx": 4, - "fim_prefix": 1, - "fim_middle": 2, - "fim_suffix": 3, - "escape": 14 - }, - "train_ds_pipeline": { - "ds_opts": "n_ctx={n_ctx},fim_probability=0.9,fim_drop_residual=1," - "tkr_stochastic_tokens=3,shuffle_depth=3000,debug=0," - "random_trim_context_prob=0.01,fim_random_seed=42,seed=42", - "pipeline_name": "local_fim" - }, - "test_ds_pipeline": { - "ds_opts": "n_ctx={n_ctx},fim_probability=0.9,fim_drop_residual=1," - "tkr_stochastic_tokens=3,shuffle_depth=3000,debug=0," - "random_trim_context_prob=0.01,fim_random_seed=42," - "pack_single=1,pack_complete=0,pack_buffer_size=25," - "quit_on_epoch=1,seed=42", - "pipeline_name": "local_fim" - }, - "train_model_modifiers": [], - "force_enable_checkpointing": True - }, - - "starcoder/7b/base": { - "lora_target_modules_mapping": { - "qkv": ["attn.q_attn", "attn.c_attn"], - "out": ["attn.c_proj"], - "backproj": ["attn.c_proj"], - "mlp": ["mlp.c_fc", "mlp.c_proj"], - }, - "freeze_exceptions_mapping": { - "wte": "wte", - "lm_head": "lm_head", - "lora": "lora" - }, - "tokenizer": { - "eot_idx": 0, - "padding_idx": 4, - "fim_prefix": 1, - "fim_middle": 2, - "fim_suffix": 3, - "escape": 14 - }, - "train_ds_pipeline": { - "ds_opts": "n_ctx={n_ctx},fim_probability=0.9,fim_drop_residual=1," - "tkr_stochastic_tokens=3,shuffle_depth=3000,debug=0," - "random_trim_context_prob=0.01,fim_random_seed=42", - "pipeline_name": "local_fim" - }, - "test_ds_pipeline": { - "ds_opts": "n_ctx={n_ctx},fim_probability=0.9,fim_drop_residual=1," - "tkr_stochastic_tokens=3,shuffle_depth=3000,debug=0," - "random_trim_context_prob=0.01,fim_random_seed=42," - "pack_single=1,pack_complete=0,pack_buffer_size=25," - "quit_on_epoch=1,seed=42", - "pipeline_name": "local_fim" - }, - "train_model_modifiers": [], - "force_enable_checkpointing": True - } -} diff --git a/refact_data_pipeline/finetune_datasource.py b/refact_data_pipeline/finetune_datasource.py index 963c14a4..7cf25933 100644 --- a/refact_data_pipeline/finetune_datasource.py +++ b/refact_data_pipeline/finetune_datasource.py @@ -1,45 +1,48 @@ import os -import jsonlines import random +from pathlib import Path +from typing import Iterable, Dict, Any, List + +import jsonlines +import numpy as np +import torch.utils.data -from refact_data_pipeline.filters_fim_v2 import FIMv2 -from refact_encoding import RefactEncoding -from refact_encoding import hlprint -from refact_data_pipeline import filters_synthetic from refact_data_pipeline import DatasetOpts from refact_data_pipeline import pipeline_pieces as pp +from refact_data_pipeline.filters_fim_v2 import FIMv2 from self_hosting_machinery import env -from typing import Union, List - - -def cut_zip_name(j): - p = j["path"] - slash_pos = p.find("/") - if slash_pos != -1: - p = p[slash_pos+1:] - return p +__all__ = [ + 'RefactDataset', 'RefactPlainCodeDataset', 'RefactFIMCodeDataset' +] class ReadFileByFile: def __init__( - self, - js, - dataopts: DatasetOpts, + self, + inner_filter: Iterable[Dict[str, Any]], + dataopts: DatasetOpts, ): - self.js = js + self.inner_filter = inner_filter self.dataopts = dataopts self.quit_on_epoch = dataopts.get("quit_on_epoch", 0) + 
@staticmethod + def _cut_zip_name(j): + p = j["path"] + slash_pos = p.find("/") + if slash_pos != -1: + p = p[slash_pos + 1:] + return p + def __iter__(self): file_num = 0 epoch = 0 while 1: - for j in self.js: - # print("READING", j["path"]) + for j in self.inner_filter: code = open(os.path.join(env.DIR_UNPACKED, j["path"]), encoding="utf-8").read() yield { - "path": cut_zip_name(j), + "path": ReadFileByFile._cut_zip_name(j), "code": code, "text": code, "size": len(code), @@ -55,11 +58,13 @@ def __iter__(self): class CodeToPrefixCompletion: - def __init__(self, - inner_filter, - dataopts: DatasetOpts, - ): + def __init__( + self, + inner_filter: Iterable[Dict[str, Any]], + dataopts: DatasetOpts, + ): self.inner_filter = inner_filter + self.dataopts = dataopts def __iter__(self): for j in self.inner_filter: @@ -70,91 +75,73 @@ def __iter__(self): } -def local_infill(fn_set_jsonl, dataopts): - rank = 0 - size = 1 - js = list(jsonlines.open(os.path.join(env.DIR_UNPACKED, fn_set_jsonl))) - fixed_seed_random = random.Random(42) - fixed_seed_random.shuffle(js) - ds = ReadFileByFile(js, dataopts) - ds = pp.SplitRanks(ds, dataopts, commrank=rank, commsize=size) - ds = filters_synthetic.InfillDiff(ds, dataopts) - ds = pp.Packer(ds, dataopts, keys=["tokens", "mask", "first"], force16=True, force_pack_complete=True) - ds = pp.Shuffle(ds, dataopts) - return ds - - -def local_plain(fn_set_jsonl: Union[str, List[str]], dataopts): - rank = 0 - size = 1 - if isinstance(fn_set_jsonl, str): - js = list(jsonlines.open(os.path.join(env.DIR_UNPACKED, fn_set_jsonl))) - else: - js = fn_set_jsonl - fixed_seed_random = random.Random(43) - fixed_seed_random.shuffle(js) - ds = ReadFileByFile(js, dataopts) - ds = pp.SplitRanks(ds, dataopts, commrank=rank, commsize=size) # this drops some of the data {"code": ...} at each rank - ds = CodeToPrefixCompletion(ds, dataopts) - ds = pp.Tokenizer(ds, dataopts) - ds = pp.PromptCompletionToTokensMask(ds, dataopts) - ds = pp.Packer(ds, dataopts, keys=["tokens", "mask", "first"]) - ds = pp.Shuffle(ds, dataopts) - return ds - - -def local_fim(fn_set_jsonl, dataopts): - rank = 0 - size = 1 - if isinstance(fn_set_jsonl, str): - js = list(jsonlines.open(os.path.join(env.DIR_UNPACKED, fn_set_jsonl))) - else: - js = fn_set_jsonl - fixed_seed_random = random.Random(43) - fixed_seed_random.shuffle(js) - ds = ReadFileByFile(js, dataopts) - ds = pp.SplitRanks(ds, dataopts, commrank=rank, commsize=size) # this drops some of the data {"code": ...} at each rank - ds = FIMv2(ds, dataopts) - ds = pp.DensePacker(ds, dataopts) - ds = pp.Shuffle(ds, dataopts) - return iter(ds) - - -def local_mix_plain_infill(fn_set_jsonl, dataopts): - return pp.Mix([ - local_plain(fn_set_jsonl, dataopts), - local_infill(fn_set_jsonl, dataopts), - ], (0.75, 0.25)) - - -def local_sequence_plain_infill(fn_set_jsonl, dataopts): - ds1 = local_plain(fn_set_jsonl, dataopts) - ds2 = local_infill(fn_set_jsonl, dataopts) - def _iter(): - for ex1 in ds1: - yield ex1 - for ex2 in ds2: - yield ex2 - return _iter() - - -def print_data_feed(is_test_set): - enc = RefactEncoding("openai_programming_v2") - if is_test_set: - dataopts = DatasetOpts("n_ctx=2049,quit_on_epoch=1,seed=1337") - dataopts.set_encoding(enc) - ds = local_sequence_plain_infill("test_set.jsonl", dataopts) - else: - dataopts = DatasetOpts("n_ctx=2049,seed=1337") - dataopts.set_encoding(enc) - ds = local_mix_plain_infill("train_set.jsonl", dataopts) - cnt = 0 - for ex in ds: - print(hlprint(enc, ex["tokens"], ex["mask"])) - cnt += 1 - if cnt == 10: - break - - 
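ReadFileByFile above loops over its file list indefinitely, counting epochs and stopping once quit_on_epoch is reached. A minimal sketch of that epoch-cycling pattern with plain items instead of files on disk (illustrative only, not the class itself):

from typing import Any, Dict, Iterable, Iterator

def cycle_with_epochs(items: Iterable[Any], quit_on_epoch: int = 0) -> Iterator[Dict[str, Any]]:
    epoch = 0
    while True:
        for item in items:
            yield {"item": item, "stats": {"epoch": epoch}}
        epoch += 1
        if quit_on_epoch and epoch >= quit_on_epoch:
            break  # finished the requested number of passes, like quit_on_epoch in the pipeline

print([r["item"] for r in cycle_with_epochs("ab", quit_on_epoch=2)])  # ['a', 'b', 'a', 'b']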
-if __name__ == '__main__': - print_data_feed(is_test_set=False) +class RefactDataset(torch.utils.data.IterableDataset): + def __init__( + self, + files: List[Dict[str, Any]], + dataset_options: str, + encoding: 'Encoding' + ): + self._files = files + self._ds_options = DatasetOpts(dataset_options) + self._encoding = encoding + self._ds_options.set_encoding(self._encoding) + + @staticmethod + def from_a_single_file( + cls, + file: Dict[str, Any], + dataset_options: str, + encoding: 'Encoding' + ) -> 'RefactDataset': + return cls([file], dataset_options, encoding) + + @staticmethod + def from_a_jsonl( + cls, + jsonl_path: str, + dataset_options: str, + encoding: 'Encoding' + ) -> 'RefactDataset': + files = list(jsonlines.open(Path(env.DIR_UNPACKED) / jsonl_path)) + return cls(files, dataset_options, encoding) + + @property + def files_len(self) -> int: + return len(self._files) + + def _get_files_by_worker(self) -> List[Dict[str, Any]]: + files = self._files + random.Random(self._ds_options.get("seed", 42)).shuffle(files) + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + assert len(files) > 1, "It doesn't work with 1 file in multiprocessing mode" + assert len(files) > worker_info.num_workers, "YO have to have more files to process than processes" + files = np.array_split(files, worker_info.num_workers)[worker_info.id] + return files + + def _build_pipeline(self, files: List[Dict[str, Any]]): + raise NotImplementedError() + + def __iter__(self): + return iter(self._build_pipeline(self._get_files_by_worker())) + + +class RefactPlainCodeDataset(RefactDataset): + def _build_pipeline(self, files: List[Dict[str, Any]]): + ds = ReadFileByFile(files, self._ds_options) + ds = CodeToPrefixCompletion(ds, self._ds_options) + ds = pp.Tokenizer(ds, self._ds_options) + ds = pp.PromptCompletionToTokensMask(ds, self._ds_options) + ds = pp.DensePacker(ds, self._ds_options) + ds = pp.Shuffle(ds, self._ds_options) + return ds + + +class RefactFIMCodeDataset(RefactDataset): + def _build_pipeline(self, files: List[Dict[str, Any]]): + ds = ReadFileByFile(files, self._ds_options) + ds = FIMv2(ds, self._ds_options) + ds = pp.DensePacker(ds, self._ds_options) + ds = pp.Shuffle(ds, self._ds_options) + return ds diff --git a/refact_data_pipeline/pipeline_pieces.py b/refact_data_pipeline/pipeline_pieces.py index d7e16806..bd001d3d 100644 --- a/refact_data_pipeline/pipeline_pieces.py +++ b/refact_data_pipeline/pipeline_pieces.py @@ -1,37 +1,37 @@ -import os -import ujson -import filelock -import itertools import copy -import psutil -import random import datetime +import gzip +import itertools +import os +import random import traceback +from typing import List, Union, Iterable +import blobfile as bf +import filelock +import ujson +import zstandard from mpi4py import MPI -from refact_encoding import RefactEncoding -from refact_data_pipeline.datadef import DatasetOpts from refact_data_pipeline.datadef import DatasetDef, DatasetDumpedDef from refact_data_pipeline.datadef import DatasetMix +from refact_data_pipeline.datadef import DatasetOpts from refact_data_pipeline.filters_hdfs import Hdf5Dataset from refact_data_pipeline.filters_packing import Packer, SinglePacker, DensePacker - -from typing import Dict, List, Union, Iterable, Any - +from refact_encoding import RefactEncoding log = print class JsonlFilesReaderCached: def __init__(self, - dataopts: DatasetOpts, - cloud_path: str, - cloud_files: str, - datarank: int, - cold_restart_key: int, - cold_restart_skip: int, - ): + dataopts: 
DatasetOpts, + cloud_path: str, + cloud_files: str, + datarank: int, + cold_restart_key: int, + cold_restart_skip: int, + ): self.cloud_path = cloud_path self.cloud_files = cloud_files self.datarank = datarank @@ -40,9 +40,6 @@ def __init__(self, self.cold_restart_skip = cold_restart_skip def __iter__(self): - import blobfile as bf - import zstandard - import gzip record_n = 0 stats = {} short_path = "/".join(self.cloud_path.rstrip("/").split("/")[2:]) @@ -59,14 +56,14 @@ def __iter__(self): stats["task_dir"] = cache_dir stats["reading_fn"] = cached_fn stats["file_fn"] = fn - position = epoch*len(self.cloud_files) + i + position = epoch * len(self.cloud_files) + i stats["file_n"] = position stats["file_N"] = len(self.cloud_files) - stats["file_n_over_N"] = (epoch*len(self.cloud_files) + i) / len(self.cloud_files) + stats["file_n_over_N"] = (epoch * len(self.cloud_files) + i) / len(self.cloud_files) ymd_hms = datetime.datetime.now().strftime("%Y%m%d %H:%M:%S") log(ymd_hms, "epoch %i reading %i/%i %s" % (epoch, i, len(self.cloud_files), cached_fn)) skipped += 1 - if self.cold_restart_skip > 0 and skipped < self.cold_restart_skip + 2: # one because it's the same we were reading, and another one for good measure + if self.cold_restart_skip > 0 and skipped < self.cold_restart_skip + 2: # one because it's the same we were reading, and another one for good measure log("skipped %i" % skipped) continue stats["restart%02d" % self.cold_restart_key] = position @@ -76,7 +73,7 @@ def __iter__(self): if os.path.exists(cached_fn): pass # This is useful to understand which files are being processed: - #log("using cached '%s'" % cached_fn) + # log("using cached '%s'" % cached_fn) else: log("downloading '%s' from '%s'" % (cached_fn, self.cloud_path + fn)) bf.copy(self.cloud_path + fn, cached_fn + ".tmp") @@ -103,6 +100,7 @@ def bin2str(buffer_bytes): buffer = b"" yield line.decode("utf8") + "\n" buffer += lines[-1] + it = bin2str(1 << 20) else: it = open(cached_fn) @@ -125,11 +123,12 @@ def bin2str(buffer_bytes): class SplitRanks: - def __init__(self, - inner_filter, - dataopts: DatasetOpts, - commrank: int, - commsize: int, + def __init__( + self, + inner_filter, + dataopts: DatasetOpts, + commrank: int, + commsize: int, ): self.inner_filter = inner_filter self.commrank = commrank @@ -141,31 +140,25 @@ def __iter__(self): yield rec -def predictable_files_shuffle(lst): - """ - Seed rng to fixed value - """ - fixed_seed_random = random.Random(42) - fixed_seed_random.shuffle(lst) - return lst - - class Tokenizer: def __init__(self, - inner_filter, - dataopts: DatasetOpts, - ): + inner_filter, + dataopts: DatasetOpts, + ): self.inner_filter = inner_filter self.skip_prompt_len: int = dataopts.get("tkr_skip_long_prompt", 0) self.skip_completion_len: int = dataopts.get("tkr_skip_completion_len", 0) self.skip_total_len: int = dataopts.get("tkr_skip_total_len", -1) if self.skip_total_len == -1: - self.skip_total_len = 2**31 + self.skip_total_len = 2 ** 31 self.fatal_skip: bool = dataopts.get("tkr_fatal_skip", 0) == 1 self.append_eot: bool = dataopts.get("tkr_append_eot", 1) == 1 self.tkr_stochastic_tokens = dataopts.get("tkr_stochastic_tokens", 0) self.tkr_rm_bos_in_completion: int = dataopts.get("tkr_rm_bos_in_completion", 0) + self.random_seed: int = dataopts.get("seed", 42) self.enc = dataopts.encoding + if hasattr(self.enc, "set_random_seed"): + self.enc.set_random_seed(self.random_seed) self.stats = { "tkr_skip_prompt_len": 0, "tkr_skip_completion_len": 0, @@ -176,8 +169,13 @@ def __init__(self, def 
__iter__(self): for ex in self.inner_filter: if self.tkr_stochastic_tokens > 0: - prompt_tokens, _ = self.enc.encode_stochastic(ex["prompt"], [], 0.01*self.tkr_stochastic_tokens) - completion_tokens, _ = self.enc.encode_stochastic(ex["completion"], [], 0.01*self.tkr_stochastic_tokens) + if hasattr(self.enc, 'encode_stochastic'): + prompt_tokens, _ = self.enc.encode_stochastic(ex["prompt"], [], 0.01 * self.tkr_stochastic_tokens) + completion_tokens, _ = self.enc.encode_stochastic(ex["completion"], [], + 0.01 * self.tkr_stochastic_tokens) + else: + prompt_tokens = self.enc.encode(ex["prompt"]) + completion_tokens = self.enc.encode_stochastic(ex["completion"]) else: prompt_tokens = self.enc.encode(ex["prompt"]) completion_tokens = self.enc.encode(ex["completion"]) @@ -205,9 +203,9 @@ def __iter__(self): class PromptCompletionToTokensMask: def __init__(self, - inner_filter, - dataopts: DatasetOpts, - ): + inner_filter, + dataopts: DatasetOpts, + ): self.inner_filter = inner_filter def __iter__(self): @@ -215,23 +213,23 @@ def __iter__(self): ln = len(rec["prompt_tokens"]) + len(rec["completion_tokens"]) yield { "tokens": rec["prompt_tokens"] + rec["completion_tokens"], - "mask": [0]*len(rec["prompt_tokens"]) + [1]*len(rec["completion_tokens"]), - "first": [1] + [0]*(ln - 1), - "diffhlpoint": [0]*ln, # first position decision of a diff (no such thing for plain text) - "diffedits": [0]*ln, # 0 don't learn (1 no edit, 2 edit) + "mask": [0] * len(rec["prompt_tokens"]) + [1] * len(rec["completion_tokens"]), + "first": [1] + [0] * (ln - 1), + "diffhlpoint": [0] * ln, # first position decision of a diff (no such thing for plain text) + "diffedits": [0] * ln, # 0 don't learn (1 no edit, 2 edit) "stats": rec["stats"], } class Shuffle: - def __init__(self, - inner_filter, - dataopts: DatasetOpts, + def __init__( + self, + inner_filter, + dataopts: DatasetOpts, ): self.inner_filter = inner_filter self.shuffle_depth: int = dataopts.get("shuffle_depth", 1000) - self.seed = dataopts.get("seed", 0) - self.random_state = random.Random(self.seed if self.seed else None) + self.random_state = random.Random(dataopts.get("seed", 42)) def __iter__(self): buf = [] @@ -246,12 +244,18 @@ def __iter__(self): class Mix: - def __init__(self, src: List[Iterable], proportions: List[float], seed: int = 42, shuffle_depth : int = 1000): + def __init__( + self, + src: List[Iterable], + proportions: List[float], + seed: int, + shuffle_depth: int = 1000, + ): self.src = src - self.proportions = proportions if len(proportions) == len(src) else [1/len(src)]*len(src) + self.proportions = proportions if len(proportions) == len(src) else [1 / len(src)] * len(src) self.seed = seed self.shuffle_depth: int = shuffle_depth - self.random_state = random.Random(self.seed if self.seed else None) + self.random_state = random.Random(self.seed) assert abs(sum(self.proportions) - 1) < 0.0000001 def __iter__(self): @@ -275,26 +279,35 @@ def __iter__(self): def build_filter_stack( - datadef: Union[DatasetDef, DatasetMix], - dataopts: DatasetOpts, - enc: RefactEncoding, - comm: MPI.Comm, - cold_restart: List[int] = [], - cold_restart_offset = 0, - skip_assert_flag: bool = False + datadef: Union[DatasetDef, DatasetMix], + dataopts: DatasetOpts, + enc: RefactEncoding, + comm: MPI.Comm, + cold_restart: List[int] = [], + cold_restart_offset: int = 0, + skip_assert_flag: bool = False ): dataopts.set_encoding(enc) if isinstance(datadef, DatasetMix): if len(cold_restart) == 0: - cold_restart = [0]*comm.size*len(datadef.dataset_defs) + cold_restart = [0] * 
comm.size * len(datadef.dataset_defs) sources = [] for i, dsdef in enumerate(datadef.dataset_defs): - cold_restart_offset = i*comm.size - src = build_filter_stack(dsdef, dataopts, enc, comm, cold_restart, cold_restart_offset, skip_assert_flag=True) + cold_restart_offset = i * comm.size + src = build_filter_stack( + datadef=dsdef, + dataopts=dataopts, + enc=enc, + comm=comm, + cold_restart=cold_restart, + cold_restart_offset=cold_restart_offset, + skip_assert_flag=True + ) sources.append(src) - return Mix(sources, datadef.proportions) + return Mix(sources, datadef.proportions, seed=dataopts.get("seed", 42)) + if len(cold_restart) == 0: - cold_restart = [0]*comm.size + cold_restart = [0] * comm.size path = datadef.cloud_path files_len = len(datadef.cloud_files) @@ -313,14 +326,21 @@ def build_filter_stack( ds = None for filt in datadef.to_apply: if ds is None and filt == "jsonl": - ds = JsonlFilesReaderCached(dataopts, path, my_files, datarank=comm.rank, + ds = JsonlFilesReaderCached( + dataopts, + path, + my_files, + datarank=comm.rank, cold_restart_key=cold_restart_offset + comm.rank, cold_restart_skip=cold_restart[cold_restart_offset + comm.rank], - ) + ) elif ds is None and filt == 'hdfs': - ds = Hdf5Dataset(dataopts, my_files, comm=comm, + ds = Hdf5Dataset( + dataopts, + my_files, + comm=comm, cold_restart_skip=cold_restart[cold_restart_offset + comm.rank], - ) + ) elif filt == "splitranks": ds = SplitRanks(ds, dataopts, commrank=comm.rank, commsize=comm.size) elif ds and filt == "tokenize": diff --git a/refact_encoding/encoding.py b/refact_encoding/encoding.py index 3da24fd9..a00ee5ff 100644 --- a/refact_encoding/encoding.py +++ b/refact_encoding/encoding.py @@ -13,7 +13,7 @@ class RefactEncoding: - def __init__(self, name: str): + def __init__(self, name: str, random_seed: int = 42): self.DIAMOND = 0 self.INFILL = 0 self.ESCAPE = 0 @@ -31,6 +31,7 @@ def __init__(self, name: str): self._sentencepiece_tokenizer = None self._allowed_special = set() self._slash_n_banlist = set() + self._random = random.Random(random_seed) if name in ["openai_reversible50000"]: self.EOT = 50256 @@ -191,6 +192,9 @@ def __init__(self, name: str): for t in range(self.n_vocab): self._token2bytes[t] = self._tik.decode_bytes([t]) + def set_random_seed(self, random_seed: int): + self._random = random.Random(random_seed) + def decode_utf8(self, tokens) -> str: if self._tokenizer: if len(tokens) == 1: @@ -254,7 +258,7 @@ def encode_stochastic(self, sequence, bounds_at: List[int], prob: float) -> Tupl bounds_at = list(set(bounds_at)) bounds_at.sort() else: - bounds_set = set([random.randint(0, len(sequence) - 1) + bounds_set = set([self._random.randint(0, len(sequence) - 1) for _ in range(bounds_n)]) bounds_set.add(len(sequence)) bounds_set.add(0) diff --git a/refact_models/checkpoint_loader.py b/refact_models/checkpoint_loader.py index c9be385d..6669ea93 100644 --- a/refact_models/checkpoint_loader.py +++ b/refact_models/checkpoint_loader.py @@ -111,11 +111,11 @@ def load_checkpoint(model, root_path: str, repo_id: Optional[str] = None): def load_finetune_checkpoint(model, model_name: str, root_path: str, repo_id: Optional[str] = None): - from refact_data_pipeline.finetune.model_handling import setup_model_specific_params + from self_hosting_machinery.finetune.modelling.model_handling import map_model_specific_params finetune_cp = _load_filename(root_path, 'mp_rank_00_model_states.pt', repo_id) lora_cfg = finetune_cp['ds_config']['model_info']['lora'] - _, lora_target_modules = setup_model_specific_params( + _, 
lora_target_modules = map_model_specific_params( model_name, lora_target_modules=lora_cfg.pop('lora_target_modules'), freeze_exceptions=[] ) LoraMixin.apply_lora( diff --git a/refact_scratchpads/scratchpad_2022q4_diff.py b/refact_scratchpads/scratchpad_2022q4_diff.py index 264d81b6..cfb9f1e3 100644 --- a/refact_scratchpads/scratchpad_2022q4_diff.py +++ b/refact_scratchpads/scratchpad_2022q4_diff.py @@ -467,7 +467,10 @@ def decodable_text_steps(token_idx: int) -> Tuple[int, int]: token_jdx = token_idx + 1 while token_jdx <= len(tokens): try: - text = self.enc.decode_utf8(tokens[token_idx:token_jdx]) + if hasattr(self.enc, 'decode_utf8'): + text = self.enc.decode_utf8(tokens[token_idx:token_jdx]) + else: + text = self.enc.decode(tokens[token_idx:token_jdx]) return token_jdx - token_idx, len(text) except UnicodeDecodeError: token_jdx += 1 diff --git a/refact_data_pipeline/finetune/__init__.py b/self_hosting_machinery/finetune/__init__.py similarity index 100% rename from refact_data_pipeline/finetune/__init__.py rename to self_hosting_machinery/finetune/__init__.py diff --git a/self_hosting_machinery/finetune/configuration/__init__.py b/self_hosting_machinery/finetune/configuration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/refact_data_pipeline/finetune/finetune_config.py b/self_hosting_machinery/finetune/configuration/finetune_config.py similarity index 98% rename from refact_data_pipeline/finetune/finetune_config.py rename to self_hosting_machinery/finetune/configuration/finetune_config.py index b6de1ed0..af0f1bc8 100644 --- a/refact_data_pipeline/finetune/finetune_config.py +++ b/self_hosting_machinery/finetune/configuration/finetune_config.py @@ -1,7 +1,7 @@ import math import torch -from refact_data_pipeline.finetune import traces +from self_hosting_machinery.finetune.utils import traces from self_hosting_machinery import env from typing import Any, Dict, List @@ -37,7 +37,7 @@ def base_config(model_name: str, models_db: Dict[str, Any]): ), debug=False, limit_time_seconds=48 * 60 * 60, - low_gpu_mem_mode=True, + low_gpu_mem_mode=False, save_every=10, test_every=1, train_iters=5, diff --git a/self_hosting_machinery/finetune/configuration/finetune_filtering_defaults.py b/self_hosting_machinery/finetune/configuration/finetune_filtering_defaults.py new file mode 100644 index 00000000..b4d34c6f --- /dev/null +++ b/self_hosting_machinery/finetune/configuration/finetune_filtering_defaults.py @@ -0,0 +1,5 @@ +finetune_filtering_defaults = { + "autoselect_test_files_num": 3, + "filter_loss_threshold": 3.0, + "debug": False +} diff --git a/refact_data_pipeline/finetune/finetune_train_defaults.py b/self_hosting_machinery/finetune/configuration/finetune_train_defaults.py similarity index 100% rename from refact_data_pipeline/finetune/finetune_train_defaults.py rename to self_hosting_machinery/finetune/configuration/finetune_train_defaults.py diff --git a/self_hosting_machinery/finetune/configuration/supported_models.py b/self_hosting_machinery/finetune/configuration/supported_models.py new file mode 100644 index 00000000..19176b14 --- /dev/null +++ b/self_hosting_machinery/finetune/configuration/supported_models.py @@ -0,0 +1,74 @@ +__all__ = ['config'] + +_fim_train_ds_pipeline = { + "ds_opts": "n_ctx={n_ctx},debug=0,seed=42,shuffle_depth=256," + "fim_probability=0.9,fim_drop_residual=1,random_trim_context_prob=0.01", + "ds_name": "RefactFIMCodeDataset" +} + +_fim_test_ds_pipeline = { + "ds_opts": "n_ctx={n_ctx},debug=0,seed=42,shuffle_depth=0,quit_on_epoch=1," + 
"fim_probability=0.9,fim_drop_residual=1,random_trim_context_prob=0.01," + "pack_single=1,pack_complete=0,pack_buffer_size=50", + "ds_name": "RefactFIMCodeDataset" +} +_bigcode_tokenizer_mapping = { + "eot_idx": 0, + "padding_idx": 4, + "fim_prefix": 1, + "fim_middle": 2, + "fim_suffix": 3, + "escape": 14 +} +_starcoder_base = { + "lora_target_modules_mapping": { + "qkv": ["attn.q_attn", "attn.c_attn"], + "out": ["attn.c_proj"], + "backproj": ["attn.c_proj"], + "mlp": ["mlp.c_fc", "mlp.c_proj"], + }, + "freeze_exceptions_mapping": { + "wte": "wte", + "lm_head": "lm_head", + "lora": "lora" + }, + "tokenizer": _bigcode_tokenizer_mapping, + "train_ds_pipeline": _fim_train_ds_pipeline, + "test_ds_pipeline": _fim_test_ds_pipeline, + "train_model_modifiers": [ + "flash_sa.apply_flash_mha_to_starcoder_model" + ], + "force_enable_checkpointing": False +} + +config = { + "Refact/1.6B": { + "lora_target_modules_mapping": { + "qkv": ["attn.q", "attn.kv"], + "out": ["attn.c_proj"], + "backproj": ["attn.c_proj"], + "mlp": ["mlp.gate_up_proj", "mlp.c_proj"], + }, + "freeze_exceptions_mapping": { + "wte": "wte", + "lm_head": "lm_head", + "lora": "lora" + }, + "tokenizer": _bigcode_tokenizer_mapping, + "train_ds_pipeline": _fim_train_ds_pipeline, + "test_ds_pipeline": _fim_test_ds_pipeline, + "train_model_modifiers": [ + "flash_sa.apply_flash_mha_to_refact_model" + ], + "force_enable_checkpointing": False + }, + + "starcoder/1b/base": _starcoder_base, + + "starcoder/3b/base": _starcoder_base, + + "starcoder/7b/base": { + **_starcoder_base, + "force_enable_checkpointing": True + } +} diff --git a/self_hosting_machinery/finetune/modelling/__init__.py b/self_hosting_machinery/finetune/modelling/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/self_hosting_machinery/finetune/modelling/flash_sa.py b/self_hosting_machinery/finetune/modelling/flash_sa.py new file mode 100644 index 00000000..9e82b4bf --- /dev/null +++ b/self_hosting_machinery/finetune/modelling/flash_sa.py @@ -0,0 +1,138 @@ +import functools +import logging +import math + +import einops +import torch +from typing import Tuple, Optional + + +@functools.lru_cache(maxsize=2) +def generate_alibi( + max_seq_len: int, + num_attention_heads: int, + batch_size: Optional[int] = None, + use_flash_attn: bool = True, + tp_world_size: int = 1, + tp_index: int = 0 +) -> Tuple[torch.Tensor, float, float]: + def get_slopes(n): + def get_slopes_power_of_2(n): + start = (2 ** (-2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio ** i for i in range(n)] + + assert math.log2(n).is_integer( + ), "it works only when num_attention_heads is power of 2" + return get_slopes_power_of_2(n) + + slopes = torch.Tensor(get_slopes(num_attention_heads)) + alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0).expand( + num_attention_heads, -1, -1) + + # Select the part of the tensor that corresponds to our tensor parallel index. 
+ alibi = alibi.reshape((tp_world_size, -1, *alibi.shape[1:]))[tp_index] + + if use_flash_attn: + alibi = alibi.unsqueeze(0).contiguous() + # (1, nheads, 1, seqlen_k) + else: + alibi = alibi.repeat(batch_size, 1, 1).contiguous() + + assert (num_attention_heads / tp_world_size).is_integer( + ), "it works only when (num_attention_heads/tp_world_size) is integer" + nh_tp = num_attention_heads // tp_world_size + alibi_ratio = (2 ** (-2 ** -(math.log2(num_attention_heads) - 3))) + alibi_start = (2 ** (-2 ** -(math.log2(num_attention_heads) - 3))) * alibi_ratio ** (nh_tp * tp_index) + + return alibi, alibi_start, alibi_ratio + + +def _prerequisites_are_ok(model): + try: + from flash_attn import flash_attn_func + return True + except ImportError: + logging.warning("Original flash attention is not installed, trying to use triton implementation...") + from self_hosting_machinery.finetune.modelling.triton_flash_sa import (apply_flash_mha_to_refact_model + as apply_triton_flash) + apply_triton_flash(model) + return False + + +def apply_flash_mha_to_refact_model(model): + if not _prerequisites_are_ok(model): + return + + from flash_attn import flash_attn_func + + def _forward( + self, + x: torch.Tensor, + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + alibi: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False + ): + q = einops.rearrange(self.q(x), "b t (h d) -> b t h d", h=self.num_heads) + kv = einops.rearrange(self.kv(x), "b t (h d) -> b t h d", h=2) + k, v = kv.chunk(2, dim=2) + + _, alibi_start, alibi_ratio = generate_alibi(q.shape[1], self.num_heads) + attn_output = flash_attn_func( + q, k, v, softmax_scale=self.scale_factor, causal=True, + alibi=True, alibi_start=alibi_start, alibi_ratio=alibi_ratio + ) + + attn_output = einops.rearrange(attn_output, "b t h d -> b t (h d)") + attn_output = self.c_proj(attn_output) + return attn_output, None + + if torch.cuda.get_device_capability() < (8, 0): + logging.warning("Triton flash attention is not supported on gpus with cuda capability < 8") + return + + for block in model.transformer.h: + block.attn.forward = _forward.__get__(block.attn, type(block.attn)) + + +def apply_flash_mha_to_starcoder_model(model): + if not _prerequisites_are_ok(model): + return + + from flash_attn import flash_attn_func + + def _forward( + self, + x: torch.Tensor, + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ): + qkv = self.c_attn(x) + q = einops.rearrange(qkv[:, :, :self.embed_dim], "b t (h d) -> b t h d", h=self.num_heads) + k = einops.rearrange(qkv[:, :, self.embed_dim:self.embed_dim + self.kv_dim], "b t (h d) -> b t h d", h=1) + v = einops.rearrange(qkv[:, :, self.embed_dim + self.kv_dim:], "b t (h d) -> b t h d", h=1) + + scale_factor = self.head_dim ** -0.5 + attn_output = flash_attn_func( + q, k, v, softmax_scale=scale_factor, causal=True, + ) + + attn_output = einops.rearrange(attn_output, "b t h d -> b t (h d)") + attn_output = self.c_proj(attn_output) + return attn_output, None + + if torch.cuda.get_device_capability() < (8, 0): + model.force_low_gpu_mem_mode = True + logging.warning("Flash attention is not supported on gpus with cuda capability < 8") + return + + 
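generate_alibi above derives per-head ALiBi slopes as a geometric series and asserts that the number of attention heads is a power of two. A small standalone check of that slope formula, independent of flash-attn:

import math
import torch

def alibi_slopes(n_heads: int) -> torch.Tensor:
    # Same formula as get_slopes_power_of_2 in generate_alibi; valid only for power-of-two head counts.
    assert math.log2(n_heads).is_integer()
    start = 2 ** (-2 ** -(math.log2(n_heads) - 3))
    ratio = start
    return torch.tensor([start * ratio ** i for i in range(n_heads)])

print(alibi_slopes(8))  # for 8 heads the slopes halve per head: 0.5, 0.25, 0.125, ..., 0.0039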
logging.warning("Applying flash attention to the model") + for block in model.transformer.h: + block.attn.forward = _forward.__get__(block.attn, type(block.attn)) diff --git a/self_hosting_machinery/finetune/modelling/loss.py b/self_hosting_machinery/finetune/modelling/loss.py new file mode 100644 index 00000000..0e1acd76 --- /dev/null +++ b/self_hosting_machinery/finetune/modelling/loss.py @@ -0,0 +1,76 @@ +from collections import deque +from typing import Optional, List + +import torch +import torch.nn.functional as F + +__all__ = ['masked_loss'] + +unmasked_avg_buf = None + + +def masked_loss( + logits: torch.Tensor, + labels: torch.Tensor, + mask: Optional[torch.Tensor] = None, + *, + average_elements: int, + enc: 'Encoding', + debug_dump: Optional[List[str]] = None +) -> torch.Tensor: + def _average(one_d_tensor: torch.Tensor) -> torch.Tensor: + global unmasked_avg_buf + if unmasked_avg_buf is None: + unmasked_avg_buf = deque(maxlen=average_elements) + if torch.is_grad_enabled(): + for x in one_d_tensor: + unmasked_avg_buf.append(float(x)) + return sum(unmasked_avg_buf) / len(unmasked_avg_buf) + else: + return one_d_tensor.to(torch.float32).mean().item() + + _mb, _T = labels.shape + mb, T, U = logits.shape + assert _T == T + assert _mb == mb + + ce = F.cross_entropy( + logits.reshape(mb * T, U), + labels.reshape(mb * T), + reduction="none" + ).reshape(mb, T) + avg_mask_sum = _average(mask.sum(dim=1)) + loss_ce = ((ce * mask).sum(dim=1) / avg_mask_sum).mean() + + if debug_dump is not None: + import termcolor + def token_str(x, cond, color): + t = "\"" + enc.decode([x]).replace("\n", "\\n") + "\"" + if cond: + return termcolor.colored(t, color) + else: + return t + + with torch.no_grad(): + b = 0 + for ti in range(T): + if b == -1: + continue + if ti & 15 == 0: + debug_dump.append("-----") + largest_logit_n = logits[b, ti].argmax().item() + debug_dump.append(" ".join([ + "%04i" % (ti,), + "ce=%5.2f" % ce[b, ti].item(), + "label=%-20s" % token_str(labels[b, ti].item(), mask[b, ti].item(), "green"), + "mask=%i" % mask[b, ti].item(), + "largest_logit=%05i" % largest_logit_n, + "modelthinks=%-10s" % token_str(largest_logit_n, + (mask[b, ti].item() and labels[b, ti].item() != largest_logit_n), + "red"), + ])) + debug_dump.append("-- (ce * mask).sum(dim=1) = %s" % (ce * mask).sum(dim=1)) + debug_dump.append("-- avg_mask_sum = %s" % avg_mask_sum) + debug_dump.append("-- this example loss_ce = %5.3f" % loss_ce.item()) + + return loss_ce diff --git a/refact_data_pipeline/finetune/sa.py b/self_hosting_machinery/finetune/modelling/triton_flash_sa.py similarity index 99% rename from refact_data_pipeline/finetune/sa.py rename to self_hosting_machinery/finetune/modelling/triton_flash_sa.py index 642e5497..aa0091b9 100644 --- a/refact_data_pipeline/finetune/sa.py +++ b/self_hosting_machinery/finetune/modelling/triton_flash_sa.py @@ -1,4 +1,5 @@ import functools +import logging import math import torch as th @@ -616,7 +617,10 @@ def _forward( return attn_output, None if th.cuda.get_device_capability() < (8, 0): + model.force_low_gpu_mem_mode = True + logging.warning("Triton flash attention is not supported on gpus with cuda capability < 8") return + logging.warning("Applying triton flash attention to the model") for block in model.transformer.h: block.attn.forward = _forward.__get__(block.attn, type(block.attn)) diff --git a/self_hosting_machinery/finetune/modelling/utils.py b/self_hosting_machinery/finetune/modelling/utils.py new file mode 100644 index 00000000..0313d202 --- /dev/null +++ 
b/self_hosting_machinery/finetune/modelling/utils.py @@ -0,0 +1,16 @@ +from typing import List, Tuple + +from self_hosting_machinery.finetune.configuration import supported_models + + +def map_model_specific_params( + model_name: str, + freeze_exceptions: List[str], + lora_target_modules: List[str] +) -> Tuple[List[str], List[str]]: + assert model_name in supported_models.config + model_config = supported_models.config[model_name] + freeze_exceptions = [model_config["freeze_exceptions_mapping"][e] for e in freeze_exceptions] + lora_target_modules_mapping = [m for modules in lora_target_modules + for m in model_config["lora_target_modules_mapping"][modules]] + return list(set(freeze_exceptions)), list(set(lora_target_modules_mapping)) diff --git a/self_hosting_machinery/finetune/scripts/__init__.py b/self_hosting_machinery/finetune/scripts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/self_hosting_machinery/finetune/scripts/finetune_filter.py b/self_hosting_machinery/finetune/scripts/finetune_filter.py new file mode 100644 index 00000000..e34a7b12 --- /dev/null +++ b/self_hosting_machinery/finetune/scripts/finetune_filter.py @@ -0,0 +1,199 @@ +import copy +import json +import logging +import math +import os +import signal +import textwrap +import time +from typing import Dict, Any + +import torch + +import self_hosting_machinery.finetune.utils.traces as traces +from self_hosting_machinery import env +from self_hosting_machinery.finetune.configuration.finetune_config import base_config +from self_hosting_machinery.finetune.scripts.process_uploaded_files import make_matcher +from self_hosting_machinery.finetune.scripts.script_aux.dataset import create_finetune_filter_dataloader, to_cuda +from self_hosting_machinery.finetune.scripts.script_aux.file_sets_context import FileSetsContext +from self_hosting_machinery.finetune.scripts.script_aux.file_status_context import FilesStatusContext +from self_hosting_machinery.finetune.scripts.script_aux.finetune_filter_status_tracker import \ + FinetuneFilterStatusTracker +from self_hosting_machinery.finetune.scripts.script_aux.model import ModelContext +from self_hosting_machinery.finetune.utils.finetune_utils import (get_finetune_config, get_finetune_filter_config) + + +def _log_everywhere(message): + logging.info(message) + traces.log(message) + + +def force_include_exclude_filter( + files_status: FilesStatusContext +): + fcfg = { + "filetypes_finetune": {}, + "filetypes_db": {} + } + if os.path.exists(env.CONFIG_HOW_TO_FILETYPES): + _log_everywhere("Reading %s" % env.CONFIG_HOW_TO_FILETYPES) + with open(env.CONFIG_HOW_TO_FILETYPES, "r") as f: + fcfg.update(**json.load(f)) + + is_force_included, _ = make_matcher(fcfg.get('force_include', '')) + is_force_excluded, _ = make_matcher(fcfg.get('force_exclude', '')) + + for file in files_status.no_status_train_files(): + if is_force_included(file['path']): + files_status.accept_file(file, reason="FORCE_INCLUDED") + elif is_force_excluded(file['path']): + files_status.reject_file(file, reason="FORCE_REJECTED") + + +@torch.inference_mode() +def loss_based_filter( + model_context: ModelContext, + files_status_context: FilesStatusContext, + status_tracker: FinetuneFilterStatusTracker, + *, + filter_loss_threshold +): + def _get_file_loss(file) -> float: + file_losses = [] + ds = create_finetune_filter_dataloader( + file=file, + dataset_options=f"n_ctx={model_context.finetune_cfg['model_info']['ctx_size'] + 1}," + "quit_on_epoch=1,pack_single=1,pack_complete=0", + 
encoding=model_context.encoding + ) + for data in map(to_cuda, ds): + logits = model_context.forward(input=data['input']) + loss = model_context.loss( + logits=logits.to(torch.float32), + labels=data['labels'], + mask=data['mask'], + ).item() + if not (math.isnan(loss) or math.isinf(loss)): + file_losses.append(loss) + + if len(file_losses) == 0: + raise Exception("small file") + + return sum(file_losses) / len(file_losses) + + model_context.eval() + all_losses = [] + train_files = files_status_context.no_status_train_files() + with status_tracker(total_steps=len(train_files)) as stats_tracker: + for file in train_files: + try: + file_loss = _get_file_loss(file) + except Exception as e: + files_status_context.reject_file(file, reason=str(e)) + continue + + if file_loss > filter_loss_threshold: + files_status_context.reject_file(file, reason=f"loss {file_loss:.3f}") + else: + files_status_context.accept_file(file, reason=f"loss {file_loss:.3f}") + all_losses.append(file_loss) + + stats_tracker.step() + status_tracker.add_stats(avg_loss=sum(all_losses) / (len(all_losses) + 0.001)) + + +def finetune_filter( + status_tracker: FinetuneFilterStatusTracker, + dataset_context: FileSetsContext, + finetune_cfg: Dict[str, Any], + finetune_filter_cfg: Dict[str, Any], +): + _log_everywhere("Loading files statuses...") + file_status_context = FilesStatusContext( + train_files=dataset_context.train_files, + test_files=dataset_context.test_files, + status_tracker=status_tracker + ) + + _log_everywhere("Loading model...") + finetune_cfg['model_info']['lora']['lora_dropout'] = 0.0 + finetune_cfg['model_info']['lora']['lora_init_scale'] = 1e-5 + finetune_cfg['model_info']['loss_average_elements'] = 1 + model_context = ModelContext( + finetune_cfg=finetune_cfg, + ) + + _log_everywhere("Running force include/exclude filter...") + force_include_exclude_filter( + files_status=file_status_context + ) + _log_everywhere("Running perplexity based filter...") + loss_based_filter( + model_context=model_context, + files_status_context=file_status_context, + status_tracker=status_tracker, + filter_loss_threshold=finetune_filter_cfg['filter_loss_threshold'] + ) + + _log_everywhere("Dumping filtered results...") + dataset_context.dump_filtered( + files=file_status_context.accepted_train_files(), + ) + + +def main(models_db: Dict[str, Any]): + _log_everywhere("Loading status tracker...") + status_tracker = FinetuneFilterStatusTracker() + + def catch_sigusr1(signum, frame): + _log_everywhere("catched SIGUSR1, interrupted") + status_tracker.update_status("interrupted", error_message="catched SIGUSR1, interrupted") + exit(99) + + signal.signal(signal.SIGUSR1, catch_sigusr1) + + _log_everywhere("Loading finetune configs...") + finetune_filter_cfg = get_finetune_filter_config(logger=traces.log) + model_name = get_finetune_config(models_db, logger=traces.log)["model_name"] + finetune_cfg = copy.deepcopy(base_config(model_name, models_db)) + + _log_everywhere("Loading file sets context...") + file_sets_context = FileSetsContext( + autoselect_test_files_num=finetune_filter_cfg.get("autoselect_test_files_num", 3) + ) + if file_sets_context.is_up_to_date(): + logging.info("Train set filtering: nothing changed since last time, quit") + return + + traces.log(textwrap.fill( + f"This filter calculates perplexity for each file and filters out " + f"files with perplexity larger than {finetune_filter_cfg['filter_loss_threshold']:.3f}.\n" + f"Those files likely don't have meaningful content to train on", width=100 + )) + try: + 
status_tracker.update_status("starting") + finetune_filter( + status_tracker=status_tracker, + dataset_context=file_sets_context, + finetune_cfg=finetune_cfg, + finetune_filter_cfg=finetune_filter_cfg, + ) + status_tracker.update_status("finished") + + # finetune_sequence relies on exit code to continue or stop + except (SystemExit, KeyboardInterrupt): + # caught sigusr1, interrupt by watchdog or by user + # this has to be there, even if catch_sigusr1() already called exit with 99, otherwise exit code is zero + exit(99) + except Exception as e: + _log_everywhere(f"Finetune gpu filter is failed\nException: {e}") + status_tracker.update_status("failed", error_message=str(e) or str(type(e))) + raise e + + +if __name__ == "__main__": + from known_models_db.refact_known_models import models_mini_db + + task_name = os.environ.get("LORA_LOGDIR", "") or time.strftime("lora-%Y%m%d-%H%M%S") + traces.configure(task_dir="loras", task_name=task_name, work_dir=env.PERMDIR) + main(models_mini_db) diff --git a/refact_data_pipeline/finetune/finetune_sequence.py b/self_hosting_machinery/finetune/scripts/finetune_sequence.py similarity index 67% rename from refact_data_pipeline/finetune/finetune_sequence.py rename to self_hosting_machinery/finetune/scripts/finetune_sequence.py index 98e25aa7..926534db 100644 --- a/refact_data_pipeline/finetune/finetune_sequence.py +++ b/self_hosting_machinery/finetune/scripts/finetune_sequence.py @@ -22,10 +22,10 @@ def catch_sigusr1(signum, frame): else: os.environ["LORA_LOGDIR"] = "NO_LOGS" try: - subprocess.check_call([sys.executable, "-m", "refact_data_pipeline.finetune.process_uploaded_files"]) - subprocess.check_call([sys.executable, "-m", "refact_data_pipeline.finetune.finetune_filter"]) + subprocess.check_call([sys.executable, "-m", "self_hosting_machinery.finetune.scripts.process_uploaded_files"]) + subprocess.check_call([sys.executable, "-m", "self_hosting_machinery.finetune.scripts.finetune_filter"]) if not filter_only: - subprocess.check_call([sys.executable, "-m", "refact_data_pipeline.finetune.finetune_train"]) + subprocess.check_call([sys.executable, "-m", "self_hosting_machinery.finetune.scripts.finetune_train"]) except subprocess.CalledProcessError as e: print("finetune_sequence: %s" % e) sys.exit(1) diff --git a/self_hosting_machinery/finetune/scripts/finetune_train.py b/self_hosting_machinery/finetune/scripts/finetune_train.py new file mode 100644 index 00000000..09abd4d7 --- /dev/null +++ b/self_hosting_machinery/finetune/scripts/finetune_train.py @@ -0,0 +1,253 @@ +import copy +import json +import logging +import multiprocessing +import os +import signal +import time +from pathlib import Path +from typing import Dict, Any, Iterable, Tuple, Optional + +import torch as th + +from self_hosting_machinery import env +from self_hosting_machinery.finetune.configuration.finetune_config import base_config, ConfigBuilder +from self_hosting_machinery.finetune.scripts.script_aux.dataset import ( + create_train_dataloader, create_test_dataloader, get_ds_len_per_epoch, to_cuda +) +from self_hosting_machinery.finetune.scripts.script_aux.early_stopper import EarlyStopper +from self_hosting_machinery.finetune.scripts.script_aux.finetune_status_tracker import FinetuneStatusTracker +from self_hosting_machinery.finetune.scripts.script_aux.model import ModelContext +from self_hosting_machinery.finetune.utils import traces +from self_hosting_machinery.finetune.utils.finetune_utils import get_finetune_config + + +def _log_everywhere(message): + logging.info(message) + 
traces.log(message) + + +def _build_finetune_config_by_heuristics(models_db: Dict[str, Any]) -> Dict[str, Any]: + with open(env.CONFIG_FINETUNE_FILTER_STAT, 'r') as f: + initial_loss = json.load(f)["avg_loss"] + + _log_everywhere("Calculating finetune optimal parameters") + user_cfg = get_finetune_config(models_db, logger=traces.log) + cfg_builder = ConfigBuilder(base_config(user_cfg['model_name'], models_db)) + if user_cfg['use_heuristics']: + _log_everywhere("Retrieving dataset length per epoch, it may take a while...") + ds_len = get_ds_len_per_epoch(user_cfg['model_name'], cfg_builder) + traces.log(f"Dataset length per epoch = {ds_len}") + (cfg_builder + .set_lora_quality_by_heuristics(ds_len=ds_len, initial_loss=initial_loss) + .set_schedule_by_heuristics(ds_len=ds_len) + .set_low_gpu_mem_mode_by_heuristics()) + else: + (cfg_builder + .set_train_steps(user_cfg['train_steps']) + .set_lr_decay_steps(user_cfg['lr_decay_steps']) + .set_lora_r(user_cfg['lora_r']) + .set_lora_alpha(user_cfg['lora_alpha']) + .set_lora_init_scale(user_cfg['lora_init_scale']) + .set_lora_dropout(user_cfg['lora_dropout']) + .set_low_gpu_mem_mode(user_cfg['low_gpu_mem_mode'])) + (cfg_builder + .set_lr(user_cfg['lr']) + .set_batch_size(user_cfg['batch_size']) + .set_warmup_steps(user_cfg['warmup_num_steps']) + .set_limit_time_seconds(user_cfg['limit_time_seconds']) + .set_weight_decay(user_cfg['weight_decay'])) + + traces.log(f'Freeze exceptions: {cfg_builder.cfg["model_info"]["freeze_exceptions"]}') + for k, v in cfg_builder.cfg["model_info"]["lora"].items(): + traces.log(f'Lora config: {k:>20} {v}') + + with open(os.path.join(traces.context().path, "config.json"), "w") as f: + json.dump(cfg_builder.cfg, f, indent=4) + + assert cfg_builder.cfg['train_iters'] % cfg_builder.cfg['test_every'] == 0 + assert cfg_builder.cfg['save_every'] % cfg_builder.cfg['test_every'] == 0 + + return cfg_builder.cfg + + +def _train_iteration( + data: Dict[str, Any], + iter_n: int, + model_context: ModelContext, + finetune_cfg: Dict[str, Any], +) -> Tuple[float, int]: + world_size = int(os.environ.get('WORLD_SIZE', 1)) + + if finetune_cfg['debug']: + data_path = Path(traces.context().path) / ('debug_data/iter%04d' % iter_n) + data_path.mkdir(exist_ok=True, parents=True) + + losses, tokens_n = [], 0 + for b0 in range(0, finetune_cfg["train_batch_size"], finetune_cfg["micro_batch_size"]): + input = data['input'][b0:b0 + finetune_cfg["micro_batch_size"]].contiguous() + logits = model_context.forward(input=input) + loss = model_context.loss( + logits=logits, + labels=data['labels'][b0:b0 + finetune_cfg["micro_batch_size"]].contiguous(), + mask=data['mask'][b0:b0 + finetune_cfg["micro_batch_size"]].contiguous(), + ) + model_context.backward(loss) + model_context.step() + tokens_n += (input.shape[0] * input.shape[1]) * world_size + losses.append(loss.item()) + + if finetune_cfg['debug']: + with open(data_path / ('%d_%0.3f.txt' % (b0, loss.item())), 'w') as f: + f.write(model_context.encoding.decode(input[0].cpu().numpy())) + + return sum(losses) / len(losses), tokens_n + + +def _test_iteration( + test_ds: Iterable[Dict[str, Any]], + iter_n: int, + model_context: ModelContext, + finetune_cfg: Dict[str, Any], +) -> float: + if finetune_cfg["test_every"] > 0 and iter_n % finetune_cfg["test_every"] == 0: + model_context.eval() + with th.inference_mode(): + losses = [] + for batch in map(to_cuda, test_ds): + logits = model_context.forward(input=batch['input']) + loss = model_context.loss( + logits=logits, + labels=batch['labels'], + 
mask=batch['mask'], + ) + losses.append(loss.item()) + + model_context.train() + return sum(losses) / len(losses) + + +def loop( + finetune_cfg: Dict[str, Any], + model_context: ModelContext, + status_tracker: FinetuneStatusTracker +): + def _save_checkpoint(iter_n: int, loss: float, force: bool = False): + if force or (iter_n != 0 and iter_n % finetune_cfg['save_every'] == 0): + tag = f"iter{iter_n:04d}-testloss{loss:.3f}" + traces.log("saving checkpoint %s" % tag) + model_context.save_model_state(save_path=save_path, tag=tag) + + save_path = os.path.join(traces.context().path, "checkpoints") + model_context.train() + train_iters = finetune_cfg['train_iters'] + overall_tokens_n = 0 + t0 = time.time() + + train_ds = create_train_dataloader( + model_name=model_context.model_name, + encoding=model_context.encoding, + num_workers=max(multiprocessing.cpu_count() // 2, 1), + batch_size=finetune_cfg['train_batch_size'], + ctx_size=finetune_cfg['model_info']['ctx_size'] + ) + train_ds_iter = iter(train_ds) + test_ds = create_test_dataloader( + model_name=model_context.model_name, + encoding=model_context.encoding, + ctx_size=finetune_cfg['model_info']['ctx_size'] + ) + test_ds = list(map(to_cuda, test_ds)) + + early_stop = EarlyStopper(patience=int(train_iters * 0.2)) + with status_tracker(total_steps=train_iters) as stats_tracker: + for iter_n in range(train_iters): + data = to_cuda(next(train_ds_iter)) + traces.log( + f"iter {iter_n}/{finetune_cfg['train_iters']} tokens {overall_tokens_n / 1e9:0.3f} " + f"input={traces.p(data['input'])} mask={traces.p(data['mask'])} " + f"({data['mask'].sum()}/{data['mask'].numel()})" + ) + train_loss, tokens_n = _train_iteration( + data=data, + iter_n=iter_n, + model_context=model_context, + finetune_cfg=finetune_cfg + ) + overall_tokens_n += tokens_n + + test_loss = _test_iteration( + test_ds=test_ds, + iter_n=iter_n, + model_context=model_context, + finetune_cfg=finetune_cfg + ) + + stats_tracker.step( + loss=train_loss, + test_loss=test_loss, + **{f'ds/{k}': v for k, v in data.get("stats", dict()).items()}, + **model_context.train_information(), + gtokens=overall_tokens_n / 1e9, + tokens_num=overall_tokens_n, + time_elapsed=time.time() - t0, + ) + + if early_stop(test_loss): + traces.log(f"Stopping the training due to " + f"test loss was above minimum {early_stop.counter} times") + _save_checkpoint(force=True, iter_n=iter_n, loss=test_loss) + break + else: + _save_checkpoint(force=False, iter_n=iter_n, loss=test_loss) + + +def main(models_db: Dict[str, Any]): + _log_everywhere("Loading status tracker...") + status_tracker = FinetuneStatusTracker() + + def catch_sigusr1(signum, frame): + _log_everywhere("catched SIGUSR1, interrupted") + status_tracker.update_status("interrupted", error_message="catched SIGUSR1, interrupted") + exit(99) + + signal.signal(signal.SIGUSR1, catch_sigusr1) + + try: + status_tracker.update_status("working") + _log_everywhere("Loading finetune configs...") + finetune_cfg = copy.deepcopy(_build_finetune_config_by_heuristics(models_db)) + + _log_everywhere(f"Building the model...") + model_context = ModelContext( + finetune_cfg=finetune_cfg, + use_deepspeed=True + ) + + _log_everywhere(f"Starting finetune at {traces.context().path}\n\n") + loop( + finetune_cfg=finetune_cfg, + model_context=model_context, + status_tracker=status_tracker + ) + + _log_everywhere("finished finetune at %s" % traces.context().path) + status_tracker.update_status("finished") + + # finetune_sequence relies on exit code to continue or stop + except 
(SystemExit, KeyboardInterrupt): + # caught sigusr1, interrupt by watchdog or by user + # this has to be there, even if catch_sigusr1() already called exit with 99, otherwise exit code is zero + exit(99) + except Exception as e: + _log_everywhere(f"Finetune is failed\nException: {e}") + status_tracker.update_status("failed", error_message=str(e) or str(type(e))) + raise e + + +if __name__ == "__main__": + from known_models_db.refact_known_models import models_mini_db + + YMD_hms = os.environ.get("LORA_LOGDIR", "") or time.strftime("lora-%Y%m%d-%H%M%S") + traces.configure(task_dir="loras", task_name=YMD_hms, work_dir=env.PERMDIR) + main(models_mini_db) diff --git a/refact_data_pipeline/finetune/process_uploaded_files.py b/self_hosting_machinery/finetune/scripts/process_uploaded_files.py similarity index 99% rename from refact_data_pipeline/finetune/process_uploaded_files.py rename to self_hosting_machinery/finetune/scripts/process_uploaded_files.py index 6fdd2ce6..8b843ae5 100644 --- a/refact_data_pipeline/finetune/process_uploaded_files.py +++ b/self_hosting_machinery/finetune/scripts/process_uploaded_files.py @@ -12,7 +12,7 @@ from fnmatch import fnmatch from self_hosting_machinery import env -import refact_data_pipeline.finetune.traces as traces +import self_hosting_machinery.finetune.utils.traces as traces from typing import List, Dict, Any, Iterable, Tuple @@ -22,7 +22,7 @@ EXE = "smc-linguist" # Rely on PATH CWD = os.getcwd() -GIT_EXE = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../git_command.exp') +GIT_EXE = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../../refact_data_pipeline/git_command.exp') TRUSTED_LANGUAGES = { 'Assembly', 'Batchfile', 'C', 'C#', 'C++', 'CMake', 'CSS', 'Cuda', 'Dockerfile', 'Fortran', diff --git a/self_hosting_machinery/finetune/scripts/script_aux/__init__.py b/self_hosting_machinery/finetune/scripts/script_aux/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/self_hosting_machinery/finetune/scripts/script_aux/dataset.py b/self_hosting_machinery/finetune/scripts/script_aux/dataset.py new file mode 100644 index 00000000..0d91e454 --- /dev/null +++ b/self_hosting_machinery/finetune/scripts/script_aux/dataset.py @@ -0,0 +1,163 @@ +import multiprocessing +import os +from typing import Any, Dict + +import torch +from torch.utils.data import DataLoader +from transformers import AutoTokenizer + +from refact_data_pipeline import finetune_datasource +from refact_data_pipeline.datautils import collate_fn, data_parallel_split_and_collate_fn +from self_hosting_machinery.finetune.configuration import supported_models +from self_hosting_machinery.scripts.env import TRAIN_FILTERED_FILEPATH, TEST_FILTERED_FILEPATH + +__all__ = [ + "create_train_dataloader", + "create_test_dataloader", + "create_finetune_filter_dataloader", + "get_ds_len_per_epoch", + "to_cuda", +] + + +def setup_encoding( + model_name: str, + weights_path: str, + repo_id: str +) -> AutoTokenizer: + model_config = supported_models.config[model_name] + assert "tokenizer" in model_config, "Provided tokenizer is no longer supported" + encoding = AutoTokenizer.from_pretrained( + repo_id, cache_dir=weights_path, + trust_remote_code=True + ) + encoding.decode_utf8 = lambda x, *args, **kwargs: encoding.decode(x) + encoding.EOT = model_config["tokenizer"]["eot_idx"] + encoding.DIAMOND = model_config["tokenizer"]["padding_idx"] + encoding.PREFIX = model_config["tokenizer"]["fim_prefix"] + encoding.INFILL = model_config["tokenizer"]["fim_middle"] + encoding.SUFFIX = 
model_config["tokenizer"]["fim_suffix"] + encoding.ESCAPE = model_config["tokenizer"]["escape"] + return encoding + + +def get_ds_len_per_epoch(model_name, cfg_builder): + encoding = setup_encoding( + model_name=model_name, + weights_path=cfg_builder.cfg['model_info']['weight_path'], + repo_id=cfg_builder.cfg['model_info']['repo_id'] + ) + ds = create_train_dataloader( + model_name=model_name, + encoding=encoding, + num_workers=multiprocessing.cpu_count(), + batch_size=cfg_builder.cfg['micro_batch_size'], + ctx_size=cfg_builder.cfg['model_info']['ctx_size'], + extra_options="quit_on_epoch=1" + ) + return sum(1 for _ in ds) * int(os.environ.get('WORLD_SIZE', 1)) + + +def create_train_dataloader( + model_name: str, + encoding: 'Encoding', + ctx_size: int, + batch_size: int, + num_workers: int, + extra_options: str = "", +) -> DataLoader: + world_size = int(os.environ.get('WORLD_SIZE', 1)) + model_config = supported_models.config[model_name] + ds_name = model_config["train_ds_pipeline"]["ds_name"] + ds_opts = model_config["train_ds_pipeline"]["ds_opts"].format( + n_ctx=ctx_size + 1 + ) + if extra_options: + ds_opts = f"{ds_opts},{extra_options}" + + dataset_cls = getattr(finetune_datasource, ds_name) + dataset = getattr(finetune_datasource, ds_name).from_a_jsonl( + cls=dataset_cls, + jsonl_path=TRAIN_FILTERED_FILEPATH, + dataset_options=ds_opts, + encoding=encoding, + ) + if dataset.files_len == 0: + raise RuntimeError("No train files provided") + + return DataLoader( + dataset, + batch_size=batch_size * world_size, + num_workers=num_workers, + shuffle=False, + drop_last=True, + pin_memory=False, + collate_fn=data_parallel_split_and_collate_fn + ) + + +def create_test_dataloader( + model_name: str, + encoding: 'Encoding', + ctx_size: int, + extra_options: str = "", +) -> DataLoader: + model_config = supported_models.config[model_name] + ds_name = model_config["test_ds_pipeline"]["ds_name"] + ds_opts = model_config["test_ds_pipeline"]["ds_opts"].format( + n_ctx=ctx_size + 1 + ) + if extra_options: + ds_opts = f"{ds_opts},{extra_options}" + + dataset_cls = getattr(finetune_datasource, ds_name) + dataset = getattr(finetune_datasource, ds_name).from_a_jsonl( + cls=dataset_cls, + jsonl_path=TEST_FILTERED_FILEPATH, + dataset_options=ds_opts, + encoding=encoding, + ) + if dataset.files_len == 0: + raise RuntimeError("No test files provided") + + return DataLoader( + dataset, + batch_size=1, + num_workers=0, + shuffle=False, + drop_last=False, + pin_memory=False, + collate_fn=collate_fn + ) + + +def create_finetune_filter_dataloader( + file: Dict[str, Any], + dataset_options: str, + encoding: str, +) -> DataLoader: + dataset = finetune_datasource.RefactDataset.from_a_single_file( + cls=finetune_datasource.RefactPlainCodeDataset, + file=file, + dataset_options=dataset_options, + encoding=encoding + ) + if dataset.files_len == 0: + raise RuntimeError("No files for filtering are provided") + + return DataLoader( + dataset, + batch_size=1, + num_workers=0, + shuffle=False, + drop_last=False, + pin_memory=False, + collate_fn=collate_fn + ) + + +def to_cuda(batch: Dict[str, Any]) -> Dict[str, Any]: + return { + k: (v.cuda() if isinstance(v, torch.Tensor) else v) + for k, v in batch.items() + } diff --git a/self_hosting_machinery/finetune/scripts/script_aux/early_stopper.py b/self_hosting_machinery/finetune/scripts/script_aux/early_stopper.py new file mode 100644 index 00000000..c72807ae --- /dev/null +++ b/self_hosting_machinery/finetune/scripts/script_aux/early_stopper.py @@ -0,0 +1,19 @@ +__all__ = 
['EarlyStopper'] + + +class EarlyStopper: + def __init__(self, patience: int = 1, min_delta: float = 0): + self.patience = patience + self.min_delta = min_delta + self.counter = 0 + self.min_validation_loss = float('inf') + + def __call__(self, validation_loss: float): + if validation_loss < self.min_validation_loss: + self.min_validation_loss = validation_loss + self.counter = 0 + elif validation_loss > (self.min_validation_loss + self.min_delta): + self.counter += 1 + if self.counter >= self.patience: + return True + return False diff --git a/self_hosting_machinery/finetune/scripts/script_aux/file_sets_context.py b/self_hosting_machinery/finetune/scripts/script_aux/file_sets_context.py new file mode 100644 index 00000000..dc27db51 --- /dev/null +++ b/self_hosting_machinery/finetune/scripts/script_aux/file_sets_context.py @@ -0,0 +1,96 @@ +import random +from pathlib import Path +from typing import List, Dict, Any + +import jsonlines + +from self_hosting_machinery.finetune.utils import traces +from self_hosting_machinery.scripts import env + +from self_hosting_machinery.scripts.env import (TRAIN_UNFILTERED_FILEPATH, TEST_UNFILTERED_FILEPATH, + TRAIN_FILTERED_FILEPATH, TEST_FILTERED_FILEPATH) + +__all__ = ['FileSetsContext'] + + +class FileSetsContext: + TRAIN_FILES_MIN_NUMBER_WITH_TEST_SET = 4 + TRAIN_FILES_MIN_NUMBER_WITHOUT_TEST_SET = 7 + TEST_FILES_COUNT_WARNING = 64 + + def __init__(self, autoselect_test_files_num: int): + self._check_prerequisites() + self.autoselect_test_files_num = autoselect_test_files_num + self.train_files: List[Dict[str, Any]] = list(jsonlines.open(TRAIN_UNFILTERED_FILEPATH)) + self.test_files: List[Dict[str, Any]] = list(jsonlines.open(TEST_UNFILTERED_FILEPATH)) + + def _check_prerequisites(self): + if not Path(TRAIN_UNFILTERED_FILEPATH).exists(): + raise RuntimeError("No train files have been provided") + + train_files = list(jsonlines.open(TRAIN_UNFILTERED_FILEPATH)) + test_files = list(jsonlines.open(TEST_UNFILTERED_FILEPATH)) + train_min_number = ( + self.TRAIN_FILES_MIN_NUMBER_WITH_TEST_SET if len(test_files) > 0 else + self.TRAIN_FILES_MIN_NUMBER_WITHOUT_TEST_SET + ) + if len(train_files) < train_min_number: + raise RuntimeError(f"Provided train set is too small ({len(train_files)} files)\n" + f"It should contain at least {train_min_number} files") + + if len(test_files) > self.TEST_FILES_COUNT_WARNING: + traces.log(f"Manually selected test set contains {len(test_files)} files. " + f"This may significantly slow down the training process at the next stage") + + def is_up_to_date(self) -> bool: + unfiltered_train, filtered_train = ( + Path(TRAIN_UNFILTERED_FILEPATH), Path(TRAIN_FILTERED_FILEPATH) + ) + unfiltered_test, filtered_test = ( + Path(TEST_UNFILTERED_FILEPATH), Path(TEST_FILTERED_FILEPATH) + ) + how_to_filter = Path(env.CONFIG_HOW_TO_FILTER) + how_to_filetypes = Path(env.CONFIG_HOW_TO_FILETYPES) + + try: + has_updates = [ + unfiltered_train.lstat().st_mtime > filtered_train.lstat().st_mtime, + unfiltered_test.lstat().st_mtime > filtered_test.lstat().st_mtime, + ] + if how_to_filter.exists(): + has_updates.append(how_to_filter.lstat().st_mtime > filtered_train.lstat().st_mtime) + if how_to_filetypes.exists(): + has_updates.append(how_to_filetypes.lstat().st_mtime > filtered_train.lstat().st_mtime) + except OSError: + return False + return not any(has_updates) + + def dump_filtered( + self, + files: List[Dict[str, Any]] + ): + def _dump(files, filename): + with jsonlines.open(filename, "w") as f: + for file in files: + f.write(file) + + if len(self.test_files) == 0: + test_files_count = min(self.autoselect_test_files_num, len(self.train_files) // 2) + if test_files_count == 0: + raise RuntimeError( + "There are too few files to choose a test set from. " + "It's strongly recommended to select a test set manually to help prevent overfitting" + ) + else: + random.shuffle(files) + test_files = files[:test_files_count] + train_files = files[test_files_count:] + else: + train_files = files + test_files = self.test_files + + _dump(train_files, TRAIN_FILTERED_FILEPATH) + _dump(test_files, TEST_FILTERED_FILEPATH) + traces.log("-" * 40 + "TEST SET" + "-" * 40) + for file in test_files: + traces.log(file["path"])
diff --git a/self_hosting_machinery/finetune/scripts/script_aux/file_status_context.py b/self_hosting_machinery/finetune/scripts/script_aux/file_status_context.py new file mode 100644 index 00000000..bcb8b5d1 --- /dev/null +++ b/self_hosting_machinery/finetune/scripts/script_aux/file_status_context.py @@ -0,0 +1,120 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Any, Dict, List + +from self_hosting_machinery.finetune.scripts.script_aux.finetune_filter_status_tracker import \ + FinetuneFilterStatusTracker +from self_hosting_machinery.finetune.utils import traces +from self_hosting_machinery.finetune.utils.finetune_utils import get_file_digest +from self_hosting_machinery.scripts import env + +__all__ = ['FilesStatusContext', 'FileStatus'] + +DIR_UNPACKED = Path(env.DIR_UNPACKED) + + +@dataclass +class FileStatus: + path: Path + info: Dict[str, Any] + is_train: bool + status: Optional[str] = None + reason: Optional[str] = None + + def hash(self) -> str: + assert self.path.exists(), f"File {self.path} doesn't exist, try to rescan your files" + return get_file_digest(self.path) + + +class FilesStatusContext: + def __init__( + self, + train_files: List[Dict[str, Any]], + test_files: List[Dict[str, Any]], + status_tracker: FinetuneFilterStatusTracker + ): + self.file_statuses: Dict[str, FileStatus] = { + info["path"]: FileStatus(path=Path(DIR_UNPACKED / info["path"]), info=info, is_train=True) + for info in train_files + } + self.file_statuses.update({ + info["path"]: FileStatus(path=Path(DIR_UNPACKED / info["path"]), info=info, is_train=False) + for info in test_files + }) + self.log_files_accepted_ftf = Path(env.LOG_FILES_ACCEPTED_FTF) + self.log_files_rejected_ftf = Path(env.LOG_FILES_REJECTED_FTF) + with self.log_files_accepted_ftf.open('w') as f: + f.write("") + with self.log_files_rejected_ftf.open('w') as f: + f.write("") + self._global_stats = status_tracker + self._check_prerequisites() + + def _check_prerequisites(self): + train_hashes_dict = { + f.hash(): f for f in self.file_statuses.values() if f.is_train + } + train_hashes = set(train_hashes_dict.keys()) + test_hashes = set( + f.hash() for f in self.file_statuses.values() if not f.is_train + ) + inters = train_hashes.intersection(test_hashes) + if len(inters) > 0: + paths = [train_hashes_dict[h].path for h in inters] + raise RuntimeError(f"Train and test sets contain identical files: {paths}") + + def _change_file_status(self, file: Dict[str, Any], status: str, reason: str, log_file: Path): + assert file["path"] in self.file_statuses + file_status = self.file_statuses[file["path"]] + file_status.status = status + file_status.reason = reason + try: + with open(log_file, "a", encoding="utf-8") as f: + f.write(f"{reason} {file['path']}\n") + except Exception as e: + traces.log(f"Couldn't write to the log file {log_file}: {e}") + raise e + + def accept_file(self, file: Dict[str, Any], reason: str): + self._change_file_status(file, "accepted", reason, self.log_files_accepted_ftf) + self._global_stats.set_accepted_num(self.accepted_files_num) + + def reject_file(self, file: Dict[str, Any], reason: str): + traces.log(f"REJECTED FILTER {file['path']:<100} {reason}") + self._change_file_status(file, "rejected", reason, self.log_files_rejected_ftf) + self._global_stats.set_rejected_num(self.rejected_files_num) + + def no_status_train_files(self) -> List[Dict[str, Any]]: + """ + :return: Train files that don't have a status yet + """ + return [ + f.info for f in self.file_statuses.values() + if f.status is None and f.is_train + ] + + def no_status_test_files(self) -> List[Dict[str, Any]]: + """ + :return: Test files that don't have a status yet + """ + return [ + f.info for f in self.file_statuses.values() + if f.status is None and not f.is_train + ] + + def accepted_train_files(self) -> List[Dict[str, Any]]: + """ + :return: Train files with accepted status + """ + return [ + f.info for f in self.file_statuses.values() + if f.status == "accepted" and f.is_train + ] + + @property + def accepted_files_num(self) -> int: + return sum(s.status == "accepted" for s in self.file_statuses.values()) + + @property + def rejected_files_num(self) -> int: + return sum(s.status == "rejected" for s in self.file_statuses.values()) diff --git a/self_hosting_machinery/finetune/scripts/script_aux/finetune_filter_status_tracker.py b/self_hosting_machinery/finetune/scripts/script_aux/finetune_filter_status_tracker.py new file mode 100644 index 00000000..e810e7e4 --- /dev/null +++ b/self_hosting_machinery/finetune/scripts/script_aux/finetune_filter_status_tracker.py @@ -0,0 +1,78 @@ +import json +import os +import time +from typing import Dict, Any, Optional + +from self_hosting_machinery.finetune.utils.eta import EtaTracker +from self_hosting_machinery.finetune.utils.finetune_utils import get_finetune_filter_stat +from self_hosting_machinery.scripts import env + +__all__ = ['FinetuneFilterStatusTracker'] + + +class FinetuneFilterStatusTracker: + class LoopStatusTracker: + def __init__(self, context, total_steps: int): + self.context: FinetuneFilterStatusTracker = context + self.eta_tracker = EtaTracker(total_steps) + self.iter_n = 0 + self.initial_iter_tp = time.time() + self.last_iter_tp = time.time() + + def step(self): +
self.eta_tracker.append(time.time() - self.last_iter_tp) + self.context._stats_dict["eta_minutes"] = int(round(self.eta_tracker.eta() / 60)) + self.context._stats_dict["worked_steps"] = self.iter_n + self.context._stats_dict["worked_minutes"] = int((time.time() - self.initial_iter_tp) / 60) + self.context.dump() + self.iter_n += 1 + self.last_iter_tp = time.time() + + def __init__(self): + self._stats_dict = get_finetune_filter_stat(default=True) + self._tracker_extra_kwargs: Dict[str, Any] = dict() + + def dump(self): + with open(env.CONFIG_FINETUNE_FILTER_STAT + ".tmp", "w") as f: + json.dump(self._stats_dict, f, indent=4) + os.rename(env.CONFIG_FINETUNE_FILTER_STAT + ".tmp", env.CONFIG_FINETUNE_FILTER_STAT) + + def update_status( + self, + status: str, + error_message: Optional[str] = None, + dump: bool = True + ): + env.report_status("filter", status) + self._stats_dict["filtering_status"] = status + if error_message is not None: + assert status in {"failed", "interrupted"} + self._stats_dict["error"] = error_message + if dump: + self.dump() + + def set_accepted_num(self, num: int, dump: bool = True): + self._stats_dict["accepted"] = num + if dump: + self.dump() + + def set_rejected_num(self, num: int, dump: bool = True): + self._stats_dict["rejected"] = num + if dump: + self.dump() + + def __call__(self, **kwargs): + self._tracker_extra_kwargs.clear() + self._tracker_extra_kwargs.update(kwargs) + return self + + def __enter__(self) -> 'FinetuneFilterStatusTracker.LoopStatusTracker': + self.add_stats(**self._tracker_extra_kwargs) + return FinetuneFilterStatusTracker.LoopStatusTracker(context=self, **self._tracker_extra_kwargs) + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def add_stats(self, **kwargs): + self._stats_dict.update(kwargs) + self.dump() diff --git a/self_hosting_machinery/finetune/scripts/script_aux/finetune_status_tracker.py b/self_hosting_machinery/finetune/scripts/script_aux/finetune_status_tracker.py new file mode 100644 index 00000000..1ccbd4f5 --- /dev/null +++ b/self_hosting_machinery/finetune/scripts/script_aux/finetune_status_tracker.py @@ -0,0 +1,113 @@ +import json +import logging +import os +import time +from pathlib import Path +from typing import Dict, Any, Optional + +from self_hosting_machinery.finetune.utils import traces +from self_hosting_machinery.finetune.utils.eta import EtaTracker +from self_hosting_machinery.finetune.utils.traces_plot import AsyncPlotter +from self_hosting_machinery.scripts import env + +__all__ = ['FinetuneStatusTracker'] + + +def get_finetune_status() -> Dict[str, Any]: + return { + "started_ts": time.time(), + "worked_steps": 0, + "worked_minutes": 0, + "status": "starting" + } + + +class FinetuneStatusTracker: + class LoopStatusTracker: + def __init__(self, context, total_steps: int): + self.context: FinetuneStatusTracker = context + self.eta_tracker = EtaTracker(total_steps) + self.plotter = AsyncPlotter( + run_path=Path(traces.context().path), + progress_filename="progress.jsonl", + iters=total_steps + ) + self.iter_n = 0 + self.initial_iter_tp = time.time() + self.last_iter_tp = time.time() + + def step(self, **to_log): + self.eta_tracker.append(time.time() - self.last_iter_tp) + + self.context._stats_dict["eta_minutes"] = int(round(self.eta_tracker.eta() / 60)) + self.context._stats_dict["worked_steps"] = self.iter_n + self.context._stats_dict["worked_minutes"] = int((time.time() - self.initial_iter_tp) / 60) + + if self.context._rank == 0: + traces.progress("iteration", self.iter_n) + 
traces.progress("eta_minutes", self.context._stats_dict["eta_minutes"]) + traces.progress("worked_steps", self.context._stats_dict["worked_steps"]) + traces.progress("worked_minutes", self.context._stats_dict["worked_minutes"]) + for k, v in to_log.items(): + traces.progress(k, v) + progress = traces.progress_dump(step=self.iter_n) + logging.info(f"finished iteration {self.iter_n}, " + f"train_loss={progress['loss']:.3f}, " + f"test_loss={progress['test_loss']:.3f}") + self.plotter.plot_async() + + self.context.dump() + self.iter_n += 1 + self.last_iter_tp = time.time() + + def __init__(self): + self._stats_dict = get_finetune_status() + self._rank = os.environ.get('RANK', 0) + self._tracker_extra_kwargs: Dict[str, Any] = dict() + self._status_filename = Path(traces.context().path) / "status.json" + + def dump(self): + if self._rank != 0: + return + + traces.touch() + if not traces.context(): + return + with open(self._status_filename.with_suffix(".tmp"), "w") as f: + json.dump(self._stats_dict, f, indent=4) + os.rename(self._status_filename.with_suffix(".tmp"), self._status_filename) + + def update_status( + self, + status: str, + error_message: Optional[str] = None, + dump: bool = True + ): + env.report_status("ftune", status) + self._stats_dict["status"] = status + if error_message is not None: + assert status in {"failed", "interrupted"} + self._stats_dict["error"] = error_message + if dump: + self.dump() + + def set_accepted_num(self, num: int, dump: bool = True): + self._stats_dict["accepted"] = num + if dump: + self.dump() + + def __call__(self, **kwargs): + self._tracker_extra_kwargs.clear() + self._tracker_extra_kwargs.update(kwargs) + return self + + def __enter__(self) -> 'FinetuneStatusTracker.LoopStatusTracker': + self.add_stats(**self._tracker_extra_kwargs) + return FinetuneStatusTracker.LoopStatusTracker(context=self, **self._tracker_extra_kwargs) + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def add_stats(self, **kwargs): + self._stats_dict.update(kwargs) + self.dump() diff --git a/self_hosting_machinery/finetune/scripts/script_aux/model.py b/self_hosting_machinery/finetune/scripts/script_aux/model.py new file mode 100644 index 00000000..ee158ee6 --- /dev/null +++ b/self_hosting_machinery/finetune/scripts/script_aux/model.py @@ -0,0 +1,291 @@ +import importlib +import logging +from functools import partial +from pathlib import Path +from typing import Dict, Any, List, Tuple + +import deepspeed +import torch +from torchinfo import summary +from transformers import AutoTokenizer, AutoModelForCausalLM + +from refact_models.lora import LoraMixin +from self_hosting_machinery.finetune.configuration import supported_models +from self_hosting_machinery.finetune.modelling.loss import masked_loss +from self_hosting_machinery.finetune.modelling.utils import map_model_specific_params +from self_hosting_machinery.finetune.utils import traces +from self_hosting_machinery.finetune.utils.timer import Timer + +__all__ = ['ModelContext'] + + +def _lora_state_dict(model, *args, destination=None, prefix='', keep_vars=False, layer_names): + return { + name: p + for name, p in model.old_state_dict( + *args, destination=destination, prefix=prefix, keep_vars=keep_vars + ).items() + if any(n in name for n in layer_names) + } + + +class ModelContext: + def __init__( + self, + finetune_cfg: Dict[str, Any], + use_deepspeed: bool = False, + debug: bool = False + ): + self.model_name = finetune_cfg["model_name"] + self.finetune_cfg = finetune_cfg + self.model_mappings_config = 
supported_models.config[self.model_name] + self.low_gpu_mem_hook = None + with Timer(message="/model load {time_ms:.1f}ms"): + self.model = self._make_model( + weights_path=self.finetune_cfg['model_info']['weight_path'], + repo_id=self.finetune_cfg['model_info']['repo_id'], + freeze_exceptions=self.finetune_cfg['model_info']['freeze_exceptions'], + lora_target_modules=self.finetune_cfg['model_info']['lora']['lora_target_modules'], + lora_r=self.finetune_cfg['model_info']['lora']['lora_r'], + lora_alpha=self.finetune_cfg['model_info']['lora']['lora_alpha'], + lora_dropout=self.finetune_cfg['model_info']['lora']['lora_dropout'], + lora_init_scale=self.finetune_cfg['model_info']['lora']['lora_init_scale'], + dtype=(torch.bfloat16 if 'bf16' in self.finetune_cfg and self.finetune_cfg['bf16']['enabled'] + else torch.float16), + init_device="cuda", + device="cuda", + ) + self._set_low_gpu_mode( + self.finetune_cfg['low_gpu_mem_mode'] + or self.model_mappings_config['force_enable_checkpointing'] + ) + self.encoding = self.model.encoding + + if use_deepspeed: + with Timer(message="/deepspeed initialization {time_ms:.1f}ms"): + self.model, _, _, _ = deepspeed.initialize( + config=self.finetune_cfg, + model=self.model, + model_parameters=[p for p in self.model.parameters() if p.requires_grad], + dist_init_required=True + ) + self.use_deepspeed = True + + if debug: + logging.info("1 gpumem_p0 %0.2fG" % (torch.cuda.max_memory_allocated() / 1e9)) + summary(self.model, depth=4, col_names=['num_params', 'params_percent', 'trainable']) + + self.loss_fn = partial( + masked_loss, + average_elements=self.finetune_cfg['model_info']['loss_average_elements'], + enc=self.encoding + ) + + def _make_model( + self, + weights_path: str, + repo_id: str, + *, + freeze_exceptions: List[str], + lora_target_modules: List[str], + lora_r: int, + lora_alpha: int, + lora_dropout: float, + lora_init_scale: float, + dtype: torch.dtype, + init_device: str = "cpu", + device: str = "cuda", + ) -> torch.nn.Module: + model = AutoModelForCausalLM.from_pretrained( + repo_id, + cache_dir=weights_path, + device_map=init_device, + torch_dtype=dtype, + trust_remote_code=True + ) + model.encoding = self._setup_encoding( + weights_path=self.finetune_cfg['model_info']['weight_path'], + repo_id=self.finetune_cfg['model_info']['repo_id'] + ) + freeze_exceptions, lora_target_modules = self._map_model_specific_params( + freeze_exceptions, lora_target_modules + ) + self._apply_model_modifiers( + model + ) + LoraMixin.apply_lora( + model.to(device), + lora_target_modules=lora_target_modules, + lora_r=int(lora_r), + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + lora_init_scale=lora_init_scale + ) + self._freeze_model( + model, + freeze_exceptions=freeze_exceptions + ) + model.old_state_dict = model.state_dict + model.state_dict = partial( + _lora_state_dict.__get__(model, type(model)), + layer_names=freeze_exceptions + ) + return model + + def forward( + self, + input: torch.Tensor + ) -> torch.Tensor: + logits = self.model.forward( + input, + return_dict=False, + output_attentions=False, + output_hidden_states=False + )[0] + return logits + + def loss( + self, + logits: torch.Tensor, + labels: torch.Tensor, + mask: torch.Tensor + ) -> torch.Tensor: + loss = self.loss_fn( + logits=logits, + labels=labels, + mask=mask, + ) + return loss + + def backward( + self, loss: torch.Tensor + ): + assert self.use_deepspeed + try: + self.model.backward(loss) + except torch.cuda.OutOfMemoryError as e: + if self.low_gpu_mem_mode: + raise e + else: + 
self.model.optimizer.zero_grad() + torch.cuda.empty_cache() + self._set_low_gpu_mode(low_gpu_mode=True) + traces.log("switching to low GPU memory mode") + self.backward(loss) + + def step(self): + assert self.use_deepspeed + self.model.step() + + def train_information(self) -> Dict[str, Any]: + if self.use_deepspeed: + return dict(gpumem_p0=torch.cuda.max_memory_allocated()) + + return dict( + gpumem_p0=torch.cuda.max_memory_allocated(), + lr=self.model.optimizer.param_groups[-1]['lr'], + num_skipped_updates=self.model.skipped_steps, + scale=self.model.optimizer.cur_scale, + ) + + def train(self): + self.model.train() + + def eval(self): + self.model.eval() + + def save_model_state( + self, + save_path: str, + tag: str + ): + keys_white_list = { + 'module', 'buffer_names', 'optimizer', 'param_shapes', 'frozen_param_shapes', + 'lr_scheduler', 'data_sampler', 'random_ltd', 'sparse_tensor_module_names', + 'skipped_steps', 'global_steps', 'global_samples', 'dp_world_size', 'mp_world_size', + 'ds_config', 'ds_version' + } + + self.model.save_checkpoint(save_path, tag=tag) + cp_path = Path(save_path) / tag + model_cps = [p for p in cp_path.iterdir() if 'model_states' in p.name] + _ = [p.unlink() for p in cp_path.iterdir() if 'model_states' not in p.name] + for cp_path in model_cps: + cp = torch.load(str(cp_path), map_location='cpu') + cp = {k: v for k, v in cp.items() if k in keys_white_list} + torch.save(cp, str(cp_path)) + + def _freeze_model( + self, + model: torch.nn.Module, + freeze_exceptions: List[str] + ): + for name, p in model.named_parameters(): + if any([e in name for e in freeze_exceptions]): + p.requires_grad_(True) + else: + p.requires_grad_(False) + + def _setup_encoding( + self, + weights_path: str, + repo_id: str + ) -> AutoTokenizer: + assert "tokenizer" in self.model_mappings_config + encoding = AutoTokenizer.from_pretrained( + repo_id, cache_dir=weights_path, + trust_remote_code=True + ) + encoding.EOT = self.model_mappings_config["tokenizer"]["eot_idx"] + encoding.DIAMOND = self.model_mappings_config["tokenizer"]["padding_idx"] + encoding.PREFIX = self.model_mappings_config["tokenizer"]["fim_prefix"] + encoding.INFILL = self.model_mappings_config["tokenizer"]["fim_middle"] + encoding.SUFFIX = self.model_mappings_config["tokenizer"]["fim_suffix"] + encoding.ESCAPE = self.model_mappings_config["tokenizer"]["escape"] + return encoding + + def _map_model_specific_params( + self, + freeze_exceptions: List[str], + lora_target_modules: List[str] + ) -> Tuple[List[str], List[str]]: + return map_model_specific_params( + model_name=self.model_name, + freeze_exceptions=freeze_exceptions, + lora_target_modules=lora_target_modules + ) + + def _apply_model_modifiers( + self, + model: torch.nn.Module + ): + for modifier in self.model_mappings_config['train_model_modifiers']: + path, modifier_name = modifier.rsplit('.', maxsplit=1) + mod_path = importlib.import_module(f"self_hosting_machinery.finetune.modelling.{path}") + mod = getattr(mod_path, modifier_name) + try: + mod(model) + except Exception as e: + logging.error(f"Applying model modifier {modifier} wasn't successful: {e}") + + def _set_low_gpu_mode( + self, + low_gpu_mode: bool + ): + force_low_gpu_mem_mode = hasattr(self.model, "force_low_gpu_mem_mode") and self.model.force_low_gpu_mem_mode + self.low_gpu_mem_mode = low_gpu_mode or force_low_gpu_mem_mode + logging.warning(f"Setting low_gpu_mem_mode={self.low_gpu_mem_mode} for the model") + + if self.low_gpu_mem_mode: + self.model.gradient_checkpointing_enable() + + def
make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + self.low_gpu_mem_hook = self.model.get_input_embeddings().register_forward_hook( + make_inputs_require_grad + ) + else: + self.model.gradient_checkpointing_disable() + if self.low_gpu_mem_hook is not None: + self.low_gpu_mem_hook.remove() diff --git a/self_hosting_machinery/finetune/utils/__init__.py b/self_hosting_machinery/finetune/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/self_hosting_machinery/finetune/utils/eta.py b/self_hosting_machinery/finetune/utils/eta.py new file mode 100644 index 00000000..89bcbfa6 --- /dev/null +++ b/self_hosting_machinery/finetune/utils/eta.py @@ -0,0 +1,34 @@ +from typing import List + +__all__ = ['EtaTracker'] + + +class EtaTracker: + def __init__(self, total_tasks: int): + self.total_tasks = total_tasks + self.observations: List[float] = [] + + def append(self, time: float): + assert len(self.observations) < self.total_tasks, "EtaTracker is full" + self.observations.append(time) + + def eta(self) -> float: + return self.average_time() * (self.total_tasks - len(self.observations)) + + def average_time(self, window_size: int = 5) -> float: + def _remove_outliers(data): + q1 = sorted(data)[int(len(data) * 0.25)] + q3 = sorted(data)[int(len(data) * 0.75)] + iqr = q3 - q1 + lower_bound = q1 - 1.5 * iqr + upper_bound = q3 + 1.5 * iqr + return [x for x in data if lower_bound <= x <= upper_bound] + + def _running_avg(data, window_size): + return [sum(data[i:i + window_size]) / window_size + for i in range(len(data) - window_size + 1)] + + observations = _remove_outliers(self.observations) + if len(observations) > (window_size * 2): + observations = _running_avg(observations, window_size=window_size) + return sum(observations) / len(observations) diff --git a/refact_data_pipeline/finetune/finetune_utils.py b/self_hosting_machinery/finetune/utils/finetune_utils.py similarity index 84% rename from refact_data_pipeline/finetune/finetune_utils.py rename to self_hosting_machinery/finetune/utils/finetune_utils.py index e65293a4..4fc3e23d 100644 --- a/refact_data_pipeline/finetune/finetune_utils.py +++ b/self_hosting_machinery/finetune/utils/finetune_utils.py @@ -1,14 +1,16 @@ import copy +import hashlib import os import json import time +from pathlib import Path -from refact_data_pipeline.finetune.finetune_train_defaults import finetune_train_defaults +from self_hosting_machinery.finetune.configuration.finetune_filtering_defaults import finetune_filtering_defaults +from self_hosting_machinery.finetune.configuration.finetune_train_defaults import finetune_train_defaults from self_hosting_machinery import env -from typing import Any, Dict, Optional, Callable - +from typing import Any, Dict, Optional, Callable, Union legacy_finetune_model = "CONTRASTcode/3b/multi" default_finetune_model = "Refact/1.6B" @@ -94,6 +96,14 @@ def get_active_lora(model_name: str, model_info: Dict[str, Any]) -> Dict: } +def get_finetune_filter_config(logger: Optional[Callable] = None): + cfg = {**finetune_filtering_defaults} + if os.path.exists(env.CONFIG_HOW_TO_FILTER): + logger("Reading %s" % env.CONFIG_HOW_TO_FILTER) + cfg.update(**json.load(open(env.CONFIG_HOW_TO_FILTER))) + return cfg + + def get_finetune_config(models_db: Dict[str, Any], logger: Optional[Callable] = None) -> Dict[str, Any]: cfg = copy.deepcopy(finetune_train_defaults) if os.path.exists(env.CONFIG_FINETUNE): @@ -155,3 +165,17 @@ def get_prog_and_status_for_ui() -> (str, str): return "prog_ftune", "starting" 
return prog, status + + +def get_file_digest(file_path: Union[Path, str]) -> str: + h = hashlib.sha256() + + with open(file_path, 'rb') as file: + while True: + # Reading is buffered, so we can read smaller chunks. + chunk = file.read(h.block_size) + if not chunk: + break + h.update(chunk) + + return h.hexdigest() diff --git a/self_hosting_machinery/finetune/utils/timer.py b/self_hosting_machinery/finetune/utils/timer.py new file mode 100644 index 00000000..e7b56443 --- /dev/null +++ b/self_hosting_machinery/finetune/utils/timer.py @@ -0,0 +1,21 @@ +import logging +import time + + +class Timer: + def __init__(self, message: str): + self._start_time = None + self._message_template = message + + def __call__(self, message: str): + self._message_template = message + + def __enter__(self): + self._start_time = time.time() + + def __exit__(self, exc_type, exc_val, exc_tb): + elapsed_time = time.time() - self._start_time + logging.info(self._message_template.format( + time_s=elapsed_time, + time_ms=elapsed_time * 1000 + )) diff --git a/refact_data_pipeline/finetune/traces.py b/self_hosting_machinery/finetune/utils/traces.py similarity index 95% rename from refact_data_pipeline/finetune/traces.py rename to self_hosting_machinery/finetune/utils/traces.py index 206f722f..1d9465a8 100644 --- a/refact_data_pipeline/finetune/traces.py +++ b/self_hosting_machinery/finetune/utils/traces.py @@ -10,7 +10,6 @@ from typing import Dict, Optional, TextIO, Any, List - _cx: Optional['TraceContext'] = None @@ -67,9 +66,9 @@ def _except_hook(exctype, value, tb): quit(1) sys.excepthook = _except_hook - # logging.basicConfig(level=logging.WARNING) + # More messages could be found in + # /home/user/.local/lib/python3.9/site-packages/deepspeed/ops/op_builder/builder.py:445 (verbose is True) logging.getLogger("DeepSpeed").setLevel(logging.WARNING) - # More messages in /home/user/.local/lib/python3.9/site-packages/deepspeed/ops/op_builder/builder.py:445 (verbose is True) if task_name == "NO_LOGS": return diff --git a/refact_data_pipeline/finetune/traces_plot.py b/self_hosting_machinery/finetune/utils/traces_plot.py similarity index 58% rename from refact_data_pipeline/finetune/traces_plot.py rename to self_hosting_machinery/finetune/utils/traces_plot.py index 37c76670..9b072a4a 100644 --- a/refact_data_pipeline/finetune/traces_plot.py +++ b/self_hosting_machinery/finetune/utils/traces_plot.py @@ -1,39 +1,40 @@ import collections -import sys -import re import copy import os -import jsonlines +import re +from multiprocessing import Process +from pathlib import Path from typing import List, Dict +import jsonlines +import matplotlib +import numpy as np + +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import io -def smooth(y, radius): - import numpy as np + +__all__ = ['AsyncPlotter'] + +def smooth(y: np.array, radius: int, eps: float = 1e-20): kernel = np.zeros(2 * radius + 1) kernel[: radius + 1] = np.linspace(0, 1, radius + 2)[1:] assert kernel.size % 2 == 1 radius = kernel.size // 2 - EPS = 1e-20 - return ( - np.correlate(y, kernel, mode="full") - / (np.correlate(np.ones_like(y), kernel, mode="full") + EPS) - )[radius:-radius] + num = np.correlate(y, kernel, mode="full") + denom = np.correlate(np.ones_like(y), kernel, mode="full") + eps + return (num / denom)[radius:-radius] def plot( - xaxis: str, - x0: float, - x1: float, - yaxis: str, - jdict: Dict[str, List[Dict[str, float]]], - colors: List[str], + xaxis: str, + x0: float, + x1: float, + yaxis: str, + jdict: Dict[str, List[Dict[str, float]]], + colors: 
List[str], ): - import numpy as np - import matplotlib - matplotlib.use('Agg') - import matplotlib.pyplot as plt - import io - xs = collections.defaultdict(list) ys = collections.defaultdict(list) smoo = 0 @@ -115,26 +116,47 @@ def plot( return buf -if __name__ == "__main__": - jdict = {} - jdict["test"] = [] - jdict["train"] = list(jsonlines.open(sys.argv[1])) - for line in jdict["train"]: - line = copy.deepcopy(line) - if "test_loss" in line: - line["loss"] = line["test_loss"] - jdict["test"].append(line) - if len(jdict["test"]) == 0: - jdict.pop("test") - buf = plot( - "iteration", - 0, - int(sys.argv[2]), - "loss[0,2.6]", #,smooth5 - jdict, - ["#ff0000", "#880000"], - ) - # save - with open("progress.svg.tmp", "wb") as f: - f.write(buf.getvalue()) - os.rename("progress.svg.tmp", "progress.svg") +class AsyncPlotter: + def __init__( + self, + run_path: Path, + progress_filename: str, + iters: int, + ): + self._run_path = Path(run_path) + assert self._run_path.exists() + self._progress_filename = self._run_path / progress_filename + self._output_filename = self._run_path / "progress.svg" + self._iters = iters + int(0.1 * iters) + self.process = None + + def _plot_fn(self): + jdict = { + "test": [], + "train": list(jsonlines.open(self._progress_filename)) + } + for line in jdict["train"]: + line = copy.deepcopy(line) + if "test_loss" in line: + line["loss"] = line["test_loss"] + jdict["test"].append(line) + if len(jdict["test"]) == 0: + jdict.pop("test") + buf = plot( + "iteration", + 0, + self._iters, + "loss[0,2.6]", # ,smooth5 + jdict, + ["#ff0000", "#880000"], + ) + with open(self._output_filename.with_suffix(".tmp"), "wb") as f: + f.write(buf.getvalue()) + os.rename(self._output_filename.with_suffix(".tmp"), self._output_filename) + + def plot_async(self): + if self.process is not None: + self.process.join() + assert self._progress_filename.exists() + self.process = Process(target=self._plot_fn) + self.process.start() diff --git a/self_hosting_machinery/inference/inference_hf.py b/self_hosting_machinery/inference/inference_hf.py index 5a522f24..4b5d4a2f 100644 --- a/self_hosting_machinery/inference/inference_hf.py +++ b/self_hosting_machinery/inference/inference_hf.py @@ -185,7 +185,7 @@ def cache_dir(self) -> str: def _dump_embeddings(self): try: - from refact_data_pipeline.finetune import supported_models + from self_hosting_machinery.finetune.configuration import supported_models except ImportError: raise ImportError("please install refact_data_pipeline") if self._model_name not in supported_models.config: @@ -199,7 +199,7 @@ def _dump_embeddings(self): def load_embeddings(self): try: - from refact_data_pipeline.finetune import supported_models + from self_hosting_machinery.finetune.configuration import supported_models except ImportError: raise ImportError("please install refact_data_pipeline") diff --git a/self_hosting_machinery/inference/lora_loader_mixin.py b/self_hosting_machinery/inference/lora_loader_mixin.py index 86e0baf9..09373442 100644 --- a/self_hosting_machinery/inference/lora_loader_mixin.py +++ b/self_hosting_machinery/inference/lora_loader_mixin.py @@ -8,7 +8,7 @@ from refact_models.checkpoint_loader import load_finetune_checkpoint from refact_models.checkpoint_loader import load_finetune_checkpoint_only -from refact_data_pipeline.finetune.finetune_utils import get_active_loras +from self_hosting_machinery.finetune.utils.finetune_utils import get_active_loras from self_hosting_machinery import env diff --git a/self_hosting_machinery/scripts/best_lora.py 
b/self_hosting_machinery/scripts/best_lora.py index 1032105c..7ba39258 100644 --- a/self_hosting_machinery/scripts/best_lora.py +++ b/self_hosting_machinery/scripts/best_lora.py @@ -3,8 +3,8 @@ import json from self_hosting_machinery import env -from refact_data_pipeline.finetune.finetune_utils import get_run_model_name -from refact_data_pipeline.finetune.finetune_utils import default_finetune_model +from self_hosting_machinery.finetune.utils.finetune_utils import get_run_model_name +from self_hosting_machinery.finetune.utils.finetune_utils import default_finetune_model from typing import Dict, Optional diff --git a/self_hosting_machinery/scripts/env.py b/self_hosting_machinery/scripts/env.py index e19969e3..91947bb5 100644 --- a/self_hosting_machinery/scripts/env.py +++ b/self_hosting_machinery/scripts/env.py @@ -14,6 +14,10 @@ DIR_SSH_KEYS = os.path.join(PERMDIR, "ssh-keys") DIR_UNPACKED = os.path.join(TMPDIR, "unpacked-files") +TRAIN_UNFILTERED_FILEPATH = os.path.join(DIR_UNPACKED, "train_set.jsonl") +TRAIN_FILTERED_FILEPATH = os.path.join(DIR_UNPACKED, "train_set_filtered.jsonl") +TEST_UNFILTERED_FILEPATH = os.path.join(DIR_UNPACKED, "test_set.jsonl") +TEST_FILTERED_FILEPATH = os.path.join(DIR_UNPACKED, "test_set_filtered.jsonl") CONFIG_ENUM_GPUS = os.path.join(DIR_CONFIG, "gpus_enum_result.out") CONFIG_BUSY_GPUS = os.path.join(DIR_CONFIG, "gpus_busy_result.out") diff --git a/self_hosting_machinery/watchdog/watchdog.d/filetune.cfg b/self_hosting_machinery/watchdog/watchdog.d/filetune.cfg index bd3cd54c..8d60d391 100644 --- a/self_hosting_machinery/watchdog/watchdog.d/filetune.cfg +++ b/self_hosting_machinery/watchdog/watchdog.d/filetune.cfg @@ -5,6 +5,6 @@ "save_status": "%CONFIG_FINETUNE_STATUS%", "save_status_nickname": "prog_ftune", "at_night": {}, - "command_line": ["python", "-m", "refact_data_pipeline.finetune.finetune_sequence"], + "command_line": ["python", "-m", "self_hosting_machinery.finetune.scripts.finetune_sequence"], "gpus": [0] } diff --git a/self_hosting_machinery/watchdog/watchdog.d/filetune_filter_only.cfg b/self_hosting_machinery/watchdog/watchdog.d/filetune_filter_only.cfg index c0dd9dc9..3c7b836b 100644 --- a/self_hosting_machinery/watchdog/watchdog.d/filetune_filter_only.cfg +++ b/self_hosting_machinery/watchdog/watchdog.d/filetune_filter_only.cfg @@ -5,6 +5,6 @@ "save_status": "%CONFIG_FINETUNE_STATUS%", "save_status_nickname": "prog_filter", "at_night": {}, - "command_line": ["python", "-m", "refact_data_pipeline.finetune.finetune_sequence", "--filter-only"], + "command_line": ["python", "-m", "self_hosting_machinery.finetune.scripts.finetune_sequence", "--filter-only"], "gpus": [0] } diff --git a/self_hosting_machinery/watchdog/watchdog.d/process_uploaded.cfg b/self_hosting_machinery/watchdog/watchdog.d/process_uploaded.cfg index 07394a3d..637c15de 100644 --- a/self_hosting_machinery/watchdog/watchdog.d/process_uploaded.cfg +++ b/self_hosting_machinery/watchdog/watchdog.d/process_uploaded.cfg @@ -3,6 +3,6 @@ "when_file_appears": "%FLAG_LAUNCH_PROCESS_UPLOADS%", "save_status": "%CONFIG_FINETUNE_STATUS%", "save_status_nickname": "prog_linguist", - "command_line": ["python", "-m", "refact_data_pipeline.finetune.process_uploaded_files"], + "command_line": ["python", "-m", "self_hosting_machinery.finetune.scripts.process_uploaded_files"], "gpus": [] } diff --git a/self_hosting_machinery/webgui/selfhost_model_assigner.py b/self_hosting_machinery/webgui/selfhost_model_assigner.py index 0ffd760e..9872a03e 100644 --- 
a/self_hosting_machinery/webgui/selfhost_model_assigner.py +++ b/self_hosting_machinery/webgui/selfhost_model_assigner.py @@ -5,11 +5,11 @@ from dataclasses import dataclass, field from self_hosting_machinery import env +from self_hosting_machinery.finetune.utils.finetune_utils import get_active_loras from self_hosting_machinery.webgui.selfhost_webutils import log from known_models_db.refact_known_models import models_mini_db from known_models_db.refact_toolbox_db import modelcap_records from self_hosting_machinery.scripts.best_lora import find_best_lora -from refact_data_pipeline.finetune.finetune_utils import get_active_loras from typing import List, Dict, Set, Any diff --git a/self_hosting_machinery/webgui/static/tab-finetune.html b/self_hosting_machinery/webgui/static/tab-finetune.html index 42331f3d..4f9e2d18 100644 --- a/self_hosting_machinery/webgui/static/tab-finetune.html +++ b/self_hosting_machinery/webgui/static/tab-finetune.html @@ -135,9 +135,9 @@