Finetune section refactoring #182

Closed · wants to merge 10 commits

4 changes: 4 additions & 0 deletions Dockerfile
@@ -44,6 +44,10 @@ ENV TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0;7.5;8.0;8.6;8.9;9.0+PTX"
 COPY . /tmp/app
 RUN pip install /tmp/app && rm -rf /tmp/app
 
+RUN git clone -b feat/alibi https://github.com/smallcloudai/flash-attention.git /tmp/flash-attention \
+    && cd /tmp/flash-attention \
+    && MAX_JOBS=8 python3 setup.py install
+
 ENV REFACT_PERM_DIR "/perm_storage"
 ENV REFACT_TMP_DIR "/tmp"
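The new build step compiles flash-attention from the smallcloudai `feat/alibi` branch, with `MAX_JOBS=8` capping parallel compile jobs so the build does not exhaust RAM. A quick smoke test inside the container can confirm the extension actually built; this is a minimal sketch that assumes the fork keeps the upstream `flash_attn` package name:

```python
# Hedged smoke test: assumes the feat/alibi fork keeps the upstream
# package name `flash_attn` and its __version__ attribute.
try:
    import flash_attn
    print("flash-attention available, version:", flash_attn.__version__)
except ImportError as exc:
    print("flash-attention extension not built:", exc)
```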
12 changes: 6 additions & 6 deletions code_contrast/format_2022q3/contrast.py
@@ -122,7 +122,7 @@ def from_odm_dict(
         if tight_shrink:
             files.reverse()
         else:
-            random.shuffle(files)
+            np_random.shuffle(files)
         file_poi = defaultdict(set)
         file_deltokens = defaultdict(list)
         file_dellines = defaultdict(list)
@@ -135,7 +135,7 @@
         # dest_lines = odm["dest"][fn].replace('\r\n', '\n').replace('\r', '\n').splitlines()
         orig_lines = [x+"\n" for x in odm["orig"][fn].splitlines()]
         dest_lines = [x+"\n" for x in odm["dest"][fn].splitlines()]
-        if len(orig_lines)==0:
+        if len(orig_lines) == 0:
             orig_lines.append("\n")
         if orig_lines[-1][-1] != "\n":
             orig_lines[-1] += "\n"
@@ -228,7 +228,7 @@ def orig_app(line):
             opblocks.append(opblock)
         self.orig_tokens[fn] = orig_all_tokens
         self.dest_tokens[fn] = dest_all_tokens
-        random.shuffle(opblocks)
+        np_random.shuffle(opblocks)
         raw_ops: List[Tuple[str, str, int, int, int, int]] = list()
         for opblock in opblocks:
             raw_ops.extend(opblock)
@@ -359,7 +359,7 @@ def app(t, m):
         self.fn2tstart = dict()
         self.fn2cut = dict()
         tpos_unused = list(self.enc.tpos)
-        random.shuffle(tpos_unused)
+        np_random.shuffle(tpos_unused)
        tpos_unused *= 2
         need_to_cut_main = 0
         need_to_cut_supp = 0
@@ -403,9 +403,9 @@ def app(t, m):
             else:
                 move_r2 = min(cut_step, cut_more, relax2[fn])
         else:
-            if random.random() < 0.5 and relax1[fn] > 1:
+            if np_random.random() < 0.5 and relax1[fn] > 1:
                 move_r1 = random.randint(0, min(cut_more, relax1[fn]))
-            if random.random() < 0.5 and relax2[fn] > 1:
+            if np_random.random() < 0.5 and relax2[fn] > 1:
                 move_r2 = random.randint(0, min(cut_more, relax2[fn]))
         assert move_r1 >= 0 and move_r2 >= 0, f"i1={i1} i2={i2} r1={r1} r2={r2}"
         if SHRINK_DUMP:
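The point of these edits is reproducibility: every shuffle and coin flip in `from_odm_dict` now draws from an explicit `np_random` generator supplied by the caller instead of the process-global `random` module (note that the two `random.randint` calls in the last hunk still hit the global module). A minimal sketch of the pattern, assuming only NumPy:

```python
import numpy as np

def build_plan(items, np_random: np.random.RandomState):
    # Drawing from a generator the caller owns makes the shuffle
    # deterministic for a given seed, independent of global state.
    np_random.shuffle(items)
    return items

print(build_plan(list(range(5)), np.random.RandomState(42)))
print(build_plan(list(range(5)), np.random.RandomState(42)))  # identical order
```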
4 changes: 2 additions & 2 deletions code_contrast/format_2022q3/contrast_stochastic.py
@@ -89,7 +89,7 @@ def poisson():
     for n in range(1, len(result)-1):
         lop, li1, li2, lj1, lj2 = result[n-1]
         mop, mi1, mi2, mj1, mj2 = result[n]
-        if lop == "equal" and mop != "equal" and random.random() < left_prob:
+        if lop == "equal" and mop != "equal" and np_random.random() < left_prob:
             assert li2 == mi1
             if exact_cx_lines0 >= 0:
                 move = exact_cx_lines0
@@ -104,7 +104,7 @@
         mop, mi1, mi2, mj1, mj2 = result[n]
         rop, ri1, ri2, rj1, rj2 = result[n+1]
         # if mop != "equal" and rop == "equal" and (random.random() < right_prob or (mi1==mi2 and disable_insert)):
-        if mop != "equal" and rop == "equal" and random.random() < right_prob:
+        if mop != "equal" and rop == "equal" and np_random.random() < right_prob:
             assert ri1 == mi2
             if exact_cx_lines1 >= 0:
                 move = exact_cx_lines1
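The `(op, i1, i2, j1, j2)` tuples in `result` follow the difflib opcode convention (an assumption based on the `"equal"` tags in these hunks); the two loops stochastically widen a change's left and right boundaries into the neighboring `equal` runs with probability `left_prob` and `right_prob`. A toy illustration of the opcode shape:

```python
import difflib

a = ["one", "two", "three"]
b = ["one", "TWO", "three"]
for op, i1, i2, j1, j2 in difflib.SequenceMatcher(a=a, b=b).get_opcodes():
    print(op, i1, i2, j1, j2)
# equal   0 1 0 1   <- left boundary, widened with probability left_prob
# replace 1 2 1 2   <- the actual change
# equal   2 3 2 3   <- right boundary, widened with probability right_prob
```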
11 changes: 6 additions & 5 deletions code_contrast/format_2023q2/from_orig_dest_message.py
@@ -22,6 +22,7 @@ def from_odm_dict(
         exact_cx_lines1 = -1,
         external_poi_ranges: Optional[DefaultDict[str, List[Tuple[int, int]]]] = None,
         want_cursor_token: bool = False,
+        random_state: np.random.RandomState = np.random.RandomState(42)
 ) -> Tuple[Packer, int]:
     pack = Packer(fmt)
     files1 = list(odm["orig"].keys())
@@ -33,7 +34,7 @@
         # This moves it to the end, more visible to the model
         fns.reverse()
     else:
-        random.shuffle(fns)
+        random_state.shuffle(fns)
     files = []
     chunks: List[ChunkElement] = []
     for fn in fns:
@@ -53,7 +54,7 @@
         if fn not in odm["dest"]:
             continue
         chunks.extend(_run_diff_for_single_file(f, [(x + "\n") for x in odm["dest"][fn].splitlines()], exact_cx_lines0, exact_cx_lines1))
-    random.shuffle(chunks)
+    random_state.shuffle(chunks)
     for chunk in chunks:
         pack.add_to_plan(chunk)
     if want_cursor_token and len(chunks) == 1:
@@ -62,9 +63,9 @@
         thischunk_lines = set(range(chunks[0].line_n, chunks[0].line_n + len(chunks[0].to_del) + 1))
         thischunk_modlines = list(thischunk_lines & modlines)
         if len(thischunk_modlines) > 0:  # Can be zero for whatever reason, cursor appearance is random anyway
-            aim = random.choice(thischunk_modlines)
-            shift = np.random.poisson(2)
-            sign = np.random.choice([-1, 1])
+            aim = random_state.choice(thischunk_modlines)
+            shift = random_state.poisson(2)
+            sign = random_state.choice([-1, 1])
             file0._cursor_token_at_line = aim + shift * sign
     return pack, msg_plan_n
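One caveat with the new signature: the default `np.random.RandomState(42)` is evaluated once, at import time, so every call that omits `random_state` keeps advancing one shared generator rather than restarting from seed 42. That may be intended (calls stay varied but globally reproducible), but it is worth knowing; a minimal sketch of the Python behavior:

```python
import numpy as np

def sample(rs: np.random.RandomState = np.random.RandomState(42)):
    # The default object is created once, when the def statement runs.
    return rs.randint(0, 1000)

print(sample(), sample())        # different values: the shared default advances
print(sample(np.random.RandomState(42)),
      sample(np.random.RandomState(42)))  # identical: fresh generator per call
```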
10 changes: 5 additions & 5 deletions known_models_db/refact_known_models/huggingface.py
@@ -35,8 +35,8 @@
         "diff_scratchpad_class": "refact_scratchpads:ScratchpadPSM",
         "chat_scratchpad_class": None,
         "model_class_kwargs": {},
-        "required_memory_mb": 6000,
-        "T": 4096,
+        "required_memory_mb": 8000,
+        "T": 8192,
         "filter_caps": ["completion", "finetune"],
     },
     "starcoder/3b/base": {
@@ -45,7 +45,7 @@
         "diff_scratchpad_class": "refact_scratchpads:ScratchpadPSM",
         "chat_scratchpad_class": None,
         "model_class_kwargs": {},
-        "required_memory_mb": 9000,
+        "required_memory_mb": 12000,
         "T": 4096,
         "filter_caps": ["completion", "finetune"],
     },
@@ -55,8 +55,8 @@
         "diff_scratchpad_class": "refact_scratchpads:ScratchpadPSM",
         "chat_scratchpad_class": None,
         "model_class_kwargs": {},
-        "required_memory_mb": 18000,
-        "T": 2048,
+        "required_memory_mb": 20000,
+        "T": 4096,
         "filter_caps": ["completion", "finetune"],
     },
     "wizardcoder/15b": {
135 changes: 99 additions & 36 deletions refact_data_pipeline/datautils.py
@@ -1,6 +1,10 @@
+import os
+
 import torch as th
 from collections import defaultdict
-from typing import Iterator, Tuple, Dict, Any, Callable, Sequence
+from typing import Iterator, Tuple, Dict, Any, Callable, Iterable, List
+
+from refact_data_pipeline import DatasetOpts
 
 
 def str2dtype(s: str) -> th.dtype:
@@ -14,6 +18,66 @@ def str2dtype(s: str) -> th.dtype:
     }[s]
 
 
+_prefer_dtypes = {
+    "logits": th.int64,
+    "first": th.bool,
+    "mask": th.bool
+}
+
+
+def _after_collate(result: Dict[str, th.Tensor]) -> Dict[str, th.Tensor]:
+    if 'first' in result:
+        result['first'] = result.pop("first")[:, :-1]
+    if 'mask' in result:
+        result['mask'] = result.pop("mask")[:, 1:]
+    result["labels"] = result["tokens"][:, 1:]
+    result["input"] = result["tokens"][:, :-1]
+    return {
+        k: (v if isinstance(v, th.Tensor) else v)
+        for k, v in result.items()
+    }
+
+
+def collate_fn(records: List[Dict[str, Any]]) -> Dict[str, Any]:
+    output = defaultdict(list)
+    last_stats = None
+    for idx, record in enumerate(records):
+        for k, v in record.items():
+            if k == "stats":
+                last_stats = v
+                continue
+            output[k].append(
+                th.tensor(record[k], dtype=_prefer_dtypes.get(k, th.int64))
+            )
+    return _after_collate({
+        "stats": last_stats,
+        **{k: th.stack(v).contiguous() for k, v in output.items()}
+    })
+
+
+def data_parallel_split_and_collate_fn(records: List[Dict[str, Any]]) -> Dict[str, Any]:
+    rank = int(os.environ.get('RANK', 0))
+    world_size = int(os.environ.get('WORLD_SIZE', 1))
+
+    output = defaultdict(list)
+    last_stats = None
+    for idx, record in enumerate(records):
+        for k, v in record.items():
+            if k == "stats":
+                last_stats = v
+                continue
+            output[k].append(
+                th.tensor(record[k], dtype=_prefer_dtypes.get(k, th.int64))
+            )
+    assert len(records) % world_size == 0, "effective batch size %s" % len(records)
+    effective_bs = len(records) // world_size
+    from_, to = rank * effective_bs, (rank + 1) * effective_bs
+    return _after_collate({
+        "stats": last_stats,
+        **{k: th.stack(v)[from_:to].contiguous() for k, v in output.items()}
+    })
+
+
 def read_and_collate(
         data_iter: Iterator,
         prefer_dtypes: Dict[str, str],
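`collate_fn` stacks each per-record list into a tensor (`th.bool` for `first`/`mask`, `th.int64` otherwise, per `_prefer_dtypes`), and `_after_collate` derives the next-token training pair by shifting: `input` is `tokens[:, :-1]` and `labels` is `tokens[:, 1:]`, with `mask` shifted to align with `labels`. The data-parallel variant stacks the same way, then slices out this rank's share using `RANK` and `WORLD_SIZE`. A small usage sketch with made-up records:

```python
from refact_data_pipeline.datautils import collate_fn

records = [
    {"tokens": [1, 2, 3, 4], "mask": [1, 1, 1, 0], "stats": {"seen": 1}},
    {"tokens": [5, 6, 7, 8], "mask": [1, 1, 1, 1], "stats": {"seen": 2}},
]
batch = collate_fn(records)
print(batch["input"].shape)   # torch.Size([2, 3]), tokens[:, :-1]
print(batch["labels"].shape)  # torch.Size([2, 3]), tokens[:, 1:]
print(batch["mask"].dtype)    # torch.bool, shifted to align with labels
# With RANK/WORLD_SIZE set, data_parallel_split_and_collate_fn returns only
# this rank's effective_bs slice of the same stacked batch.
```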
@@ -58,45 +122,44 @@ def read_and_collate(
 class BatchIterator:
     def __init__(
             self,
-            seq: Sequence,
-            dataopts: Dict[str, Any],
+            inner_filter: Iterable[Any],
+            dataopts: DatasetOpts
     ):
-        self.seq_iter = iter(seq)
+        self.inner_filter = inner_filter
         self.dataopts = dataopts
         self.batch_size = dataopts.get("batch_size", 1)
         self.device = dataopts.get("device", "cuda")
         self.drop_last = dataopts.get("drop_last", False)
 
-    def __next__(self):
-        data, datastats = read_and_collate(
-            data_iter=self.seq_iter,
-            prefer_dtypes=dict(mask='torch.bool', first='torch.bool'),
-            B=self.batch_size,
-            device=self.device,
-            cold_restart_dict=dict(),
-            log_stats=True,
-            progress_callback=lambda *args, **kwargs: None
-        )
-        if len(data) == 0:
-            raise StopIteration()
-
-        if self.drop_last and len(data['tokens']) < self.batch_size:
-            raise StopIteration()
-
-        extra = dict()
-        if 'first' in data:
-            extra['first'] = data.pop("first")[:, :-1]
-        if 'mask' in data:
-            extra['mask'] = data.pop("mask")[:, 1:]
-
-        tokens = data.pop("tokens")
-        batch = dict(
-            labels=tokens[:, 1:],
-            input=tokens[:, :-1],
-            **extra
-        )
-        batch.update({k: v for k, v in data.items() if k not in batch})
-        return batch, datastats
-
     def __iter__(self):
-        return self
+        seq_iter = iter(self.inner_filter)
+        while True:
+            data, datastats = read_and_collate(
+                data_iter=seq_iter,
+                prefer_dtypes=dict(mask='torch.bool', first='torch.bool'),
+                B=self.batch_size,
+                device=self.device,
+                cold_restart_dict=dict(),
+                log_stats=True,
+                progress_callback=lambda *args, **kwargs: None
+            )
+            if len(data) == 0:
+                break
+
+            if self.drop_last and len(data['tokens']) < self.batch_size:
+                break
+
+            extra = dict()
+            if 'first' in data:
+                extra['first'] = data.pop("first")[:, :-1]
+            if 'mask' in data:
+                extra['mask'] = data.pop("mask")[:, 1:]
+
+            tokens = data.pop("tokens")
+            batch = dict(
+                labels=tokens[:, 1:],
+                input=tokens[:, :-1],
+                **extra
+            )
+            batch.update({k: v for k, v in data.items() if k not in batch})
+            yield batch, datastats
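Behaviorally, `BatchIterator` changes from a one-shot iterator (`__next__` plus `__iter__` returning `self`, consuming an iterator created in `__init__`) to an iterable whose `__iter__` is a generator: each `for` loop re-enters it and calls `iter(self.inner_filter)` afresh, and exhaustion ends the loop with `break` instead of raising `StopIteration` by hand. A usage sketch, where `dataset_filter`, `dataopts`, and `step` are placeholders:

```python
from refact_data_pipeline.datautils import BatchIterator

batches = BatchIterator(dataset_filter, dataopts)  # placeholder arguments
for epoch in range(2):
    # Each pass restarts iter(self.inner_filter), so a re-iterable
    # source yields a fresh stream of batches per epoch.
    for batch, datastats in batches:
        step(batch["input"], batch["labels"])      # placeholder training step
```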
8 changes: 5 additions & 3 deletions refact_data_pipeline/filters_chat.py
@@ -20,12 +20,11 @@ def __init__(
         self.inner_filter = inner_filter
         self.n_ctx = dataopts.get("n_ctx", 2048)
         self.no_format_prob = dataopts.get("chat_no_format_prob", 0.0)
-        self.chat_random_seed = dataopts.get("chat_random_seed", 42)
         self.debug = bool(dataopts.get("debug", 0))
         self.tkr_stochastic_tokens = bool(dataopts.get("tkr_stochastic_tokens", 0.0))
         self.enc: RefactEncoding = dataopts.encoding
         self.fmt: Format2023q2 = format.format_2023q2_escape(self.enc)
-        self.random = np.random.RandomState(self.chat_random_seed)
+        self.random = np.random.RandomState(dataopts.get("seed", 42))
 
     def _pack_format(self, plan: List[MsgElement], odm: Dict, stats: Dict):
         try:
@@ -87,7 +86,10 @@ def _pack_plain(self, plan: List[MsgElement], odm: Dict, stats: Dict):
         if self.debug:
             print(f'Chat2023Q2:\n{text}\n\n')
 
-        tokens, _ = self.enc.encode_stochastic(text, [], 0.01 * self.tkr_stochastic_tokens)
+        if hasattr(self.enc, 'encode_stochastic'):
+            tokens, _ = self.enc.encode_stochastic(text, [], 0.01 * self.tkr_stochastic_tokens)
+        else:
+            tokens = self.enc.encode(text)
         tokens += [self.enc.EOT]
         emit = {
             "tokens": tokens,
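The `hasattr` branch lets the filter accept encoders that only expose a plain `encode`, falling back to deterministic tokenization when `encode_stochastic` is missing. The same duck-typing pattern in isolation, with toy encoders standing in for `RefactEncoding`:

```python
class StochasticEnc:
    def encode_stochastic(self, text, spans, prob):
        return [ord(c) for c in text], spans  # toy stand-in

class PlainEnc:
    def encode(self, text):
        return [ord(c) for c in text]

def tokenize(enc, text, stochastic_prob=0.01):
    if hasattr(enc, 'encode_stochastic'):
        tokens, _ = enc.encode_stochastic(text, [], stochastic_prob)
    else:
        tokens = enc.encode(text)
    return tokens

print(tokenize(StochasticEnc(), "hi"))  # stochastic path
print(tokenize(PlainEnc(), "hi"))       # fallback path
```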
13 changes: 9 additions & 4 deletions refact_data_pipeline/filters_diff.py
@@ -1,10 +1,12 @@
+import logging
 import random
 import traceback
 import copy
 
+import numpy as np
+
 from refact_encoding import RefactEncoding
 from refact_data_pipeline import DatasetOpts
-from refact_data_pipeline.finetune import traces
 from code_contrast.format_2022q3 import contrast
 
 from typing import Dict
@@ -18,6 +20,8 @@ def __init__(self,
         self.enc: RefactEncoding = dataopts.encoding
         self.n_ctx = dataopts.get("n_ctx", 2048)
         self.selftest = dataopts.get("selftest", 0)
+        self.random = random.Random(dataopts.get("seed", 42))
+        self.np_random = np.random.RandomState(dataopts.get("seed", 42))
 
     def __iter__(self):
         stats: Dict[str, int] = {
@@ -33,7 +37,7 @@ def __iter__(self):
             if source_files_empty_cnt == len(odm["orig"]):
                 stats["diffskip_onlyadd"] += 1
                 continue
-            make_no_changes = random.random() < 0.05
+            make_no_changes = self.random.random() < 0.05
             if make_no_changes:
                 odm["orig"] = copy.deepcopy(odm["dest"])
             if self.selftest:
@@ -52,6 +56,7 @@
                 diff.from_odm_dict(
                     odm,
                     n_ctx=self.n_ctx,
+                    np_random=self.np_random
                 )
                 if len(diff.edits) == 0 and not make_no_changes:
                     stats["diffskip_noedit"] += 1
@@ -73,8 +78,8 @@
                     stats["diffskip_toobig"] += 1
                     continue
             except Exception as e:
-                traces.log(str(odm))
-                traces.log(traceback.format_exc())
+                logging.error(str(odm))
+                logging.error(traceback.format_exc())
                 stats["diffskip_failed"] += 1
                 continue
             edits_within_context = self.n_ctx - diff.offset_edits
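Two independent cleanups land here: error reporting goes through the standard `logging` module instead of the finetune-specific `traces`, and randomness becomes per-instance, a stdlib `random.Random` for the 5% `make_no_changes` coin flip plus a `np.random.RandomState` threaded into `contrast.from_odm_dict`. Seeding both from the same `seed` option keeps a worker's whole pipeline reproducible; a minimal sketch:

```python
import random
import numpy as np

seed = 42  # dataopts.get("seed", 42) in the filter
rng = random.Random(seed)             # stdlib stream for cheap coin flips
np_rng = np.random.RandomState(seed)  # NumPy stream for downstream shuffles

print(rng.random() < 0.05)            # the make_no_changes decision, reproducible
print(np_rng.random())                # same value on every run with this seed
```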