add deepseek inference and finetuning
JegernOUTT committed Nov 6, 2023
1 parent 122833a commit 49c6d43
Showing 6 changed files with 138 additions and 3 deletions.
22 changes: 22 additions & 0 deletions known_models_db/refact_known_models/huggingface.py
@@ -130,4 +130,26 @@
"T": 2048,
"filter_caps": ["wizardlm"],
},
"deepseek-ai/deepseek-coder-1.3b-base": {
"backend": "transformers",
"model_path": "deepseek-ai/deepseek-coder-1.3b-base",
"diff_scratchpad_class": "refact_scratchpads:ScratchpadDeepSeekCoderFIM",
"chat_scratchpad_class": None,
"model_class_kwargs": {
"load_in_4bit": True,
},
"T": 4096,
"filter_caps": ["completion", "finetune"],
},
"deepseek-ai/deepseek-coder-6.7b-base": {
"backend": "transformers",
"model_path": "deepseek-ai/deepseek-coder-6.7b-base",
"diff_scratchpad_class": "refact_scratchpads:ScratchpadDeepSeekCoderFIM",
"chat_scratchpad_class": None,
"model_class_kwargs": {
"load_in_4bit": True,
},
"T": 4096,
"filter_caps": ["completion", "finetune"],
},
}
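For context, a registry entry of this shape could be consumed by the "transformers" backend roughly as follows. This is a hedged sketch, not the project's actual loader: the load_model helper is hypothetical, and the only things taken from the entry above are model_path and model_class_kwargs (load_in_4bit requires bitsandbytes, and device_map="auto" requires accelerate).

# Hypothetical loader sketch; model_class_kwargs is assumed to be forwarded
# verbatim to from_pretrained.
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model(entry: dict):
    tokenizer = AutoTokenizer.from_pretrained(entry["model_path"])
    model = AutoModelForCausalLM.from_pretrained(
        entry["model_path"],
        device_map="auto",              # shard/offload across available devices
        **entry["model_class_kwargs"],  # here: {"load_in_4bit": True}
    )
    return tokenizer, model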
7 changes: 4 additions & 3 deletions refact_data_pipeline/filters_fim_v2.py
@@ -214,7 +214,8 @@ def __init__(
        self.fim_probability = dataopts.get("fim_probability", 0.5)
        self.tkr_stochastic_tokens = dataopts.get("tkr_stochastic_tokens", 3)
        self.fim_drop_residuals = bool(dataopts.get("fim_drop_residuals", 0))
-        self.random_trim_context_prob = bool(dataopts.get("random_trim_context_prob", 0.0))
+        self.random_trim_context_prob = dataopts.get("random_trim_context_prob", 0.0)
+        self.spm_prob = dataopts.get("spm_prob", 0.5)
        self.debug = bool(dataopts.get("debug", 0))
        self.enc = dataopts.encoding
        if hasattr(self.enc, "set_random_seed"):
@@ -338,7 +339,7 @@ def _fim_format(
        middle_toks: List[int],
        suffix_toks: List[int],
    ):
-        if self.random.random() < 0.5:
+        if self.random.random() < self.spm_prob:
            tokens_context = [self.enc.PREFIX] + prefix_toks + [self.enc.SUFFIX] + suffix_toks
            mask_context = [0] + [1] * len(prefix_toks) + [0] + [1] * len(suffix_toks)
        else:
@@ -390,7 +391,7 @@ def _fim_format(
    ):
        assert self.enc.BOS is not None
        # https://github.com/facebookresearch/codellama/blob/cb51c14ec761370ba2e2bc351374a79265d0465e/llama/generation.py#L380
-        if self.random.random() < 0.5:
+        if self.random.random() < self.spm_prob:
            tokens = (
                [self.enc.BOS, self.enc.PREFIX] + prefix_toks
                + [self.enc.SUFFIX] + suffix_toks
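For background on the spm_prob knob introduced above: it replaces a hard-coded 0.5 coin flip between the two standard fill-in-the-middle orderings. A minimal sketch of the two layouts, using <PRE>/<SUF>/<MID> as placeholder sentinel strings rather than this repo's encoder attributes:

# PSM: prefix, then suffix; the model fills in the middle at the end.
# SPM: the suffix-first variant, as in the CodeLlama reference linked in the code.
def psm(prefix: str, suffix: str) -> str:
    return f"<PRE>{prefix}<SUF>{suffix}<MID>"

def spm(prefix: str, suffix: str) -> str:
    return f"<PRE><SUF>{suffix}<MID>{prefix}"

Note that the DeepSeek finetune config added later in this commit appends spm_prob=0.0 to its ds_opts, i.e., DeepSeek models are trained on a single ordering only.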
1 change: 1 addition & 0 deletions refact_scratchpads/__init__.py
@@ -4,6 +4,7 @@
from refact_scratchpads.scratchpad_hf import ScratchpadPSM
from refact_scratchpads.scratchpad_hf import ScratchpadCodeLlamaPSM
from refact_scratchpads.scratchpad_hf import ScratchpadCodeLlamaSPM
from refact_scratchpads.scratchpad_hf import ScratchpadDeepSeekCoderFIM
from refact_scratchpads.scratchpad_hf import ScratchpadHuggingfaceStarChat
from refact_scratchpads.scratchpad_hf import ScratchpadHuggingfaceWizard
from refact_scratchpads.scratchpad_hf import ScratchpadHuggingfaceLlama2
73 changes: 73 additions & 0 deletions refact_scratchpads/scratchpad_hf.py
@@ -325,6 +325,79 @@ def _prompt_format(self, prefix_tokens, suffix_tokens):
            self._fim_middle,
        ]


class ScratchpadDeepSeekCoderFIM(ScratchpadHuggingfaceBase):

    def __init__(self, sources: Dict[str, str], cursor_file: str, cursor0: int, cursor1: int, **kwargs):
        super().__init__(**kwargs)

        assert cursor0 == cursor1

        self._cursor_file = cursor_file
        self._cursor = cursor0
        self._code = sources[cursor_file]

        self._prefix: Optional[str] = None
        self._suffix: Optional[str] = None
        self._suffix_line0cut: Optional[str] = None
        self._completion = []

        self._tokens_produced = 0
        self._fim_prefix = self._encode_one_token("<|fim▁begin|>")
        self._fim_suffix = self._encode_one_token("<|fim▁hole|>")
        self._fim_middle = self._encode_one_token("<|fim▁end|>")
        self._fim_eot = self._encode_one_token("<|EOT|>")
        self._special_tokens.update({
            self._fim_prefix, self._fim_suffix, self._fim_middle, self._fim_eot,
        })

    def _prompt_format(self, prefix_tokens, suffix_tokens):
        return [
            self._fim_prefix, *prefix_tokens,
            self._fim_suffix, *suffix_tokens,
            self._fim_middle,
        ]

    def prompt(self, T: int):
        self._prefix = self._code[:self._cursor]
        # Why do we need to cut the line to the right of the cursor?
        # Example 1:
        #   function_call(param1, GENERATED_TOKENS<EOF>)
        #   => everything works as expected
        # Example 2:
        #   function_call(param1, GENERATED_TOKENS)\nMORE_TOKENS\nSOME_OTHER_CALL(OTHER_PARAM<EOF>)
        #                                          ^^ but we stop here because we need a single-line completion
        #   => we would end up with two closing parentheses
        # self._suffix = "".join(self._code[self._cursor:].splitlines(keepends=True)[1:])
        self._suffix = self._code[self._cursor:].lstrip(" \t")
        self._suffix_line0cut = "".join(self._code[self._cursor:].splitlines(keepends=True)[1:])
        self._completion.clear()

        prefix_cut, suffix_cut = trim_context_infill(
            self._prefix, self._suffix, EncodingWrapper(self._tokenizer), T - self._max_tokens
        )
        prefix_cut_tokens = self._encode_without_special_tokens(prefix_cut)
        suffix_cut_tokens = self._encode_without_special_tokens(suffix_cut)
        self.debuglog(
            "ScratchpadFIM prompt prefix %d chars -> %d tokens, suffix %d chars -> %d tokens, T=%d max_new_tokens=%d" %
            (len(prefix_cut), len(prefix_cut_tokens), len(suffix_cut), len(suffix_cut_tokens), T, self._max_tokens)
        )
        prompt: List[int] = self._prompt_format(prefix_cut_tokens, suffix_cut_tokens)
        self.debuglog("-" * 40)
        self.debuglog(self._tokenizer.decode(prompt))
        self.debuglog("-" * 40)
        return prompt

    def completion(self, final: bool):
        assert self._prefix is not None
        assert self._suffix is not None
        completion = self._tokenizer.decode(self._completion)
        if self.finish_reason == "eot":
            # Clean stop on the EOT token
            return {self._cursor_file: self._prefix + completion + self._suffix}
        else:
            # "stop-lf", "length", or not stopped yet (empty reason):
            # better to drop the remainder of the first suffix line
            return {self._cursor_file: self._prefix + completion + self._suffix_line0cut}


class ScratchpadChatBase(ScratchpadHuggingfaceBase):

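To make the prompt format concrete: with the special tokens registered in __init__ above, _prompt_format assembles a token sequence that decodes to text of the following shape. This is a hand-built illustration, not output captured from the class:

# Example FIM prompt for DeepSeek Coder, built from a cursor position
# inside "def add(a, b): return |" (| marks the cursor).
prefix = "def add(a, b):\n    return "
suffix = "\n\nprint(add(1, 2))\n"
prompt_text = f"<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>"
# The model generates the middle (e.g. "a + b") and emits <|EOT|>,
# which completion() above treats as the clean "eot" finish reason.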
37 changes: 37 additions & 0 deletions self_hosting_machinery/finetune/configuration/supported_models.py
@@ -41,6 +41,36 @@
"force_enable_checkpointing": False
}

_deepseek_base = {
    "lora_target_modules_mapping": {
        "qkv": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
        "out": ["self_attn.o_proj"],
        "backproj": ["self_attn.o_proj"],
        "mlp": ["mlp.gate_proj", "mlp.up_proj", "mlp.down_proj"],
    },
    "freeze_exceptions_mapping": {
        "wte": ["embed_tokens"],
        "lm_head": ["lm_head"],
        "lora": ["lora"]
    },
    "tokenizer": {
        "eot_idx": 32021,      # `<|EOT|>`
        "padding_idx": 32018,  # `<pad>`
        "fim_prefix": 32016,   # `<|fim▁begin|>`
        "fim_middle": 32017,   # `<|fim▁end|>`
        "fim_suffix": 32015,   # `<|fim▁hole|>`
        "escape": 32013,       # using the `<|begin▁of▁sentence|>` token for now
    },
    "train_ds_pipeline": {
        "ds_opts": f"{_fim_train_ds_pipeline['ds_opts']},spm_prob=0.0",
        "ds_name": _fim_train_ds_pipeline["ds_name"]
    },
    "test_ds_pipeline": _fim_test_ds_pipeline,
    "train_model_modifiers": [
        "flash_sa.apply_flash_mha_to_codellama_model"
    ]
}

config = {
    "Refact/1.6B": {
        "lora_target_modules_mapping": {
@@ -105,5 +135,12 @@
"flash_sa.apply_flash_mha_to_codellama_model"
],
"force_enable_checkpointing": True
},

"deepseek-ai/deepseek-coder-1.3b-base": _deepseek_base,

"deepseek-ai/deepseek-coder-6.7b-base": {
**_deepseek_base,
"force_enable_checkpointing": True
}
}
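For orientation, lora_target_modules_mapping names LLaMA-style submodules (the q/k/v/o attention projections and the MLP projections). A finetune stack could turn such a mapping into LoRA targets roughly as below; this is a sketch using the peft library purely for illustration — the project wires up its own LoRA implementation — and r/lora_alpha are arbitrary example values:

from peft import LoraConfig

mapping = _deepseek_base["lora_target_modules_mapping"]
# Collect the leaf module names: q_proj, k_proj, v_proj, o_proj, gate_proj, ...
target_modules = sorted({name.split(".")[-1]
                         for group in ("qkv", "out", "mlp")
                         for name in mapping[group]})
lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=target_modules)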
1 change: 1 addition & 0 deletions self_hosting_machinery/finetune/modelling/flash_sa.py
@@ -152,6 +152,7 @@ def _forward(
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
+        **kwargs
):
    from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

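The single functional change here — accepting **kwargs in the patched attention forward — plausibly keeps the monkey-patch compatible with newer transformers releases that pass extra arguments to LlamaAttention.forward (for instance, padding_mask around transformers 4.34); that reading is inferred from the diff, not stated in the commit. A patch of this kind is typically applied by rebinding the forward method, sketched below with an illustrative body:

import types

def apply_flash_mha_to_codellama_model(model):
    # Illustrative shape of the modifier referenced in supported_models.py:
    # rebind every attention layer's forward to the flash-attention _forward.
    for layer in model.model.layers:
        layer.self_attn.forward = types.MethodType(_forward, layer.self_attn)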
