codellama/7b fixes
JegernOUTT committed Nov 1, 2023
1 parent 3428edf · commit d233683
Showing 9 changed files with 136 additions and 93,477 deletions.
refact_data_pipeline/filters_fim_v2.py (84 additions, 9 deletions)
@@ -279,7 +279,6 @@ def _generate_plain_text(self, tokens, cursor, sample, stats) \
         return {
             "tokens": plain,
             "mask": mask,
-            "first": [1] + [0] * (len(plain) - 1),
             "stats": {**sample["stats"], **stats},
         }, cursor
@@ -316,29 +315,105 @@ def _generate_fim(self, tokens, cursor, sample, stats) \
         if hasattr(self.enc, 'encode_stochastic'):
             prefix_toks, _ = self.enc.encode_stochastic(prefix, [], 0.01 * self.tkr_stochastic_tokens)
             suffix_toks, _ = self.enc.encode_stochastic(suffix, [], 0.01 * self.tkr_stochastic_tokens)
+            middle_toks, _ = self.enc.encode_stochastic(middle, [], 0.01 * self.tkr_stochastic_tokens)
         else:
             prefix_toks = self.enc.encode(prefix)
             suffix_toks = self.enc.encode(suffix)
+            middle_toks = self.enc.encode(middle)
+
+        tokens, mask = self._fim_format(
+            prefix_toks=prefix_toks, suffix_toks=suffix_toks, middle_toks=middle_toks
+        )
+
+        stats["fim_out"] += 1
+        return {
+            "tokens": tokens,
+            "mask": mask,
+            "stats": {**sample["stats"], **stats},
+        }, cursor
+
+    def _fim_format(
+            self,
+            prefix_toks: List[int],
+            middle_toks: List[int],
+            suffix_toks: List[int],
+    ):
         if self.random.random() < 0.5:
             tokens_context = [self.enc.PREFIX] + prefix_toks + [self.enc.SUFFIX] + suffix_toks
             mask_context = [0] + [1] * len(prefix_toks) + [0] + [1] * len(suffix_toks)
         else:
             tokens_context = [self.enc.SUFFIX] + suffix_toks + [self.enc.PREFIX] + prefix_toks
             mask_context = [0] + [1] * len(suffix_toks) + [0] + [1] * len(prefix_toks)
-        if hasattr(self.enc, 'encode_stochastic'):
-            middle_toks, _ = self.enc.encode_stochastic(middle, [], 0.01 * self.tkr_stochastic_tokens)
-        else:
-            middle_toks = self.enc.encode(middle)
 
         middle_mask = [1] * len(middle_toks)
-        stats["fim_out"] += 1
         if self.debug:
             print(f'splitter: {splitter}, middle_size: {len(middle)}, middle: {middle}')
             print(termcolor.colored(self.enc.decode(prefix_toks), "red"), end='')
             print(termcolor.colored(self.enc.decode(middle_toks), "green"), end='')
             print(termcolor.colored(self.enc.decode(suffix_toks), "red"))
 
+        tokens = tokens_context + [self.enc.INFILL] + middle_toks + [self.enc.EOT]
+        mask = mask_context + [0] + middle_mask + [1]
+
+        return tokens, mask
+
+
+class FIMv2CodeLlama(FIMv2):
+    def _generate_plain_text(self, tokens, cursor, sample, stats) \
+            -> Tuple[Optional[Dict[str, Union[str, List[str]]]], int]:
+        assert self.enc.BOS is not None
+        plain = tokens[cursor: cursor + self.n_ctx]
+        cursor += len(plain)
+        is_cut_file = len(tokens[cursor:]) > 0
+        mask = [1] * len(plain)
+        plain.append(self.enc.EOT)
+        # If last_chunk then the EOT is real, the model should predict it. If not, it just
+        # acts as a separator, the model should not predict it.
+        # And it's not visible anyway if len(plain) > n_ctx
+        if is_cut_file:
+            mask.append(0)
+        else:
+            mask.append(1)
+
         return {
-            "tokens": tokens_context + [self.enc.INFILL] + middle_toks + [self.enc.EOT],
-            "mask": mask_context + [0] + middle_mask + [1],
-            "first": [1] + [0] * (-1 + len(tokens_context) + 1 + len(middle_toks) + 1),
+            "tokens": [self.enc.BOS] + plain,
+            "mask": [0] + mask,
             "stats": {**sample["stats"], **stats},
         }, cursor
+
+    def _fim_format(
+            self,
+            prefix_toks: List[int],
+            middle_toks: List[int],
+            suffix_toks: List[int],
+    ):
+        assert self.enc.BOS is not None
+        # https://github.com/facebookresearch/codellama/blob/cb51c14ec761370ba2e2bc351374a79265d0465e/llama/generation.py#L380
+        if self.random.random() < 0.5:
+            tokens = (
+                [self.enc.BOS, self.enc.PREFIX] + prefix_toks
+                + [self.enc.SUFFIX] + suffix_toks
+                + [self.enc.INFILL] + middle_toks
+                + [self.enc.EOT]
+            )
+            mask = (
+                [0, 0] + ([1] * len(prefix_toks))
+                + [0] + ([1] * len(suffix_toks))
+                + [0] + ([1] * len(middle_toks))
+                + [1]
+            )
+        else:
+            tokens = (
+                [self.enc.BOS, self.enc.PREFIX, self.enc.SUFFIX]
+                + suffix_toks + [self.enc.INFILL]
+                + prefix_toks + middle_toks
+                + [self.enc.EOT]
+            )
+            mask = (
+                [0, 0, 0]
+                + ([1] * len(suffix_toks)) + [0]
+                + ([1] * len(prefix_toks)) + ([1] * len(middle_toks))
+                + [1]
+            )
+        return tokens, mask
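For context, the two branches of FIMv2CodeLlama._fim_format correspond to the PSM (prefix-suffix-middle) and SPM (suffix-prefix-middle) orderings of the Code Llama infilling recipe, chosen with probability 0.5 each. Below is a minimal sketch of both layouts and their loss masks; the sentinel token IDs are hypothetical placeholders, while the real values come from the tokenizer wrapper (self.enc.BOS, self.enc.PREFIX, self.enc.SUFFIX, self.enc.INFILL, self.enc.EOT).

from typing import List, Tuple

# Hypothetical sentinel IDs for illustration only; real values come from the tokenizer.
BOS, PREFIX, SUFFIX, INFILL, EOT = 1, 32007, 32008, 32009, 2

def psm_layout(prefix_toks: List[int], middle_toks: List[int],
               suffix_toks: List[int]) -> Tuple[List[int], List[int]]:
    # PSM ordering: <BOS><PRE>prefix<SUF>suffix<MID>middle<EOT>
    tokens = [BOS, PREFIX] + prefix_toks + [SUFFIX] + suffix_toks + [INFILL] + middle_toks + [EOT]
    # Sentinels get mask 0; code tokens and the final EOT (which the model must predict) get 1.
    mask = [0, 0] + [1] * len(prefix_toks) + [0] + [1] * len(suffix_toks) + [0] + [1] * len(middle_toks) + [1]
    return tokens, mask

def spm_layout(prefix_toks: List[int], middle_toks: List[int],
               suffix_toks: List[int]) -> Tuple[List[int], List[int]]:
    # SPM ordering: <BOS><PRE><SUF>suffix<MID>prefix middle<EOT>
    tokens = [BOS, PREFIX, SUFFIX] + suffix_toks + [INFILL] + prefix_toks + middle_toks + [EOT]
    mask = [0, 0, 0] + [1] * len(suffix_toks) + [0] + [1] * len(prefix_toks) + [1] * len(middle_toks) + [1]
    return tokens, mask

if __name__ == "__main__":
    tokens, mask = psm_layout([101, 102], [201], [301, 302, 303])
    assert len(tokens) == len(mask)  # tokens and loss mask always stay aligned
    print(tokens)
    print(mask)

In both layouts only real code tokens and the closing EOT carry a loss of 1, so the model learns to produce the middle and then stop; the same convention drives the EOT masking in _generate_plain_text above, where a separator EOT is masked out but a real end-of-file EOT is predicted.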
refact_data_pipeline/filters_packing.py (1 addition, 1 deletion)
@@ -149,7 +149,7 @@ def __init__(
         self.pack_complete: bool = dataopts.get('pack_complete', 1) == 1
         self.drop_less_than_t: int = dataopts.get('pack_drop_less_than_t', 6)
         self.buffer_size: int = dataopts.get('pack_buffer_size', 256)
-        self.keys = dataopts.get('packer_keys', 'tokens;mask;first').split(';')
+        self.keys = dataopts.get('packer_keys', 'tokens;mask').split(';')
         self.max_packing_rounds = 8
         self.do_nothing_keys = ['stats']
         assert len(self.keys) > 0
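Since the "first" field was removed from the sample dicts, the packer's default key list drops it as well. A tiny sketch of how the packer_keys option is parsed, using a plain dict as a stand-in for DatasetOpts (an assumption for illustration; the real options object has its own get semantics):

# Stand-in for DatasetOpts: a plain dict (assumption, illustration only).
dataopts = {}

# 'packer_keys' is a semicolon-separated list of per-sample fields to pack.
keys = dataopts.get('packer_keys', 'tokens;mask').split(';')
assert keys == ['tokens', 'mask']  # 'first' is no longer among the packed keys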
refact_data_pipeline/finetune_datasource.py (10 additions, 1 deletion)
@@ -9,7 +9,7 @@

 from refact_data_pipeline import DatasetOpts
 from refact_data_pipeline import pipeline_pieces as pp
-from refact_data_pipeline.filters_fim_v2 import FIMv2
+from refact_data_pipeline.filters_fim_v2 import FIMv2, FIMv2CodeLlama
 from self_hosting_machinery import env
 
 __all__ = [
@@ -145,3 +145,12 @@ def _build_pipeline(self, files: List[Dict[str, Any]]):
         ds = pp.DensePacker(ds, self._ds_options)
         ds = pp.Shuffle(ds, self._ds_options)
         return ds
+
+
+class CodeLLamaFIMDataset(RefactDataset):
+    def _build_pipeline(self, files: List[Dict[str, Any]]):
+        ds = ReadFileByFile(files, self._ds_options)
+        ds = FIMv2CodeLlama(ds, self._ds_options)
+        ds = pp.DensePacker(ds, self._ds_options)
+        ds = pp.Shuffle(ds, self._ds_options)
+        return ds
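The new CodeLLamaFIMDataset reuses the wrapper-composition pattern of the other datasets: each stage consumes the iterable produced by the previous one and yields transformed samples. A rough, self-contained sketch of that pattern with simplified stand-in stages (the real ReadFileByFile, FIMv2CodeLlama, DensePacker and Shuffle take a DatasetOpts and implement far more logic):

import random
from typing import Iterable, Iterator, List

class ReadFiles:
    # Stands in for ReadFileByFile: yields one raw sample per file.
    def __init__(self, files: List[str]):
        self.files = files
    def __iter__(self) -> Iterator[str]:
        yield from self.files

class Transform:
    # Stands in for a per-sample transform stage such as FIMv2CodeLlama.
    def __init__(self, inner: Iterable[str]):
        self.inner = inner
    def __iter__(self) -> Iterator[str]:
        for sample in self.inner:
            yield sample.upper()

class Shuffle:
    # Buffers the upstream samples, then yields them in random order.
    def __init__(self, inner: Iterable[str]):
        self.inner = inner
    def __iter__(self) -> Iterator[str]:
        buf = list(self.inner)
        random.shuffle(buf)
        yield from buf

ds = ReadFiles(["a.py", "b.py"])
ds = Transform(ds)
ds = Shuffle(ds)
print(list(ds))

Because every stage is just an iterable wrapper, swapping FIMv2 for FIMv2CodeLlama is a one-line change in _build_pipeline, as the diff above shows.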