From 49c6d439d0d61fbd66e6cdd8ab55e8da9207824e Mon Sep 17 00:00:00 2001
From: JegernOUTT
Date: Mon, 6 Nov 2023 18:07:36 +1030
Subject: [PATCH] add deepseek inference and finetuning

---
 .../refact_known_models/huggingface.py        | 22 ++++++
 refact_data_pipeline/filters_fim_v2.py        |  7 +-
 refact_scratchpads/__init__.py                |  1 +
 refact_scratchpads/scratchpad_hf.py           | 73 +++++++++++++++++++
 .../configuration/supported_models.py         | 37 ++++++++++
 .../finetune/modelling/flash_sa.py            |  1 +
 6 files changed, 138 insertions(+), 3 deletions(-)

diff --git a/known_models_db/refact_known_models/huggingface.py b/known_models_db/refact_known_models/huggingface.py
index 03e82729..835d1657 100644
--- a/known_models_db/refact_known_models/huggingface.py
+++ b/known_models_db/refact_known_models/huggingface.py
@@ -130,4 +130,26 @@
         "T": 2048,
         "filter_caps": ["wizardlm"],
     },
+    "deepseek-ai/deepseek-coder-1.3b-base": {
+        "backend": "transformers",
+        "model_path": "deepseek-ai/deepseek-coder-1.3b-base",
+        "diff_scratchpad_class": "refact_scratchpads:ScratchpadDeepSeekCoderFIM",
+        "chat_scratchpad_class": None,
+        "model_class_kwargs": {
+            "load_in_4bit": True,
+        },
+        "T": 4096,
+        "filter_caps": ["completion", "finetune"],
+    },
+    "deepseek-ai/deepseek-coder-6.7b-base": {
+        "backend": "transformers",
+        "model_path": "deepseek-ai/deepseek-coder-6.7b-base",
+        "diff_scratchpad_class": "refact_scratchpads:ScratchpadDeepSeekCoderFIM",
+        "chat_scratchpad_class": None,
+        "model_class_kwargs": {
+            "load_in_4bit": True,
+        },
+        "T": 4096,
+        "filter_caps": ["completion", "finetune"],
+    },
 }
diff --git a/refact_data_pipeline/filters_fim_v2.py b/refact_data_pipeline/filters_fim_v2.py
index 9fca4ab9..8cccb02b 100644
--- a/refact_data_pipeline/filters_fim_v2.py
+++ b/refact_data_pipeline/filters_fim_v2.py
@@ -214,7 +214,8 @@ def __init__(
         self.fim_probability = dataopts.get("fim_probability", 0.5)
         self.tkr_stochastic_tokens = dataopts.get("tkr_stochastic_tokens", 3)
         self.fim_drop_residuals = bool(dataopts.get("fim_drop_residuals", 0))
-        self.random_trim_context_prob = bool(dataopts.get("random_trim_context_prob", 0.0))
+        self.random_trim_context_prob = dataopts.get("random_trim_context_prob", 0.0)
+        self.spm_prob = dataopts.get("spm_prob", 0.5)
         self.debug = bool(dataopts.get("debug", 0))
         self.enc = dataopts.encoding
         if hasattr(self.enc, "set_random_seed"):
@@ -338,7 +339,7 @@ def _fim_format(
             middle_toks: List[int],
             suffix_toks: List[int],
     ):
-        if self.random.random() < 0.5:
+        if self.random.random() < self.spm_prob:
             tokens_context = [self.enc.PREFIX] + prefix_toks + [self.enc.SUFFIX] + suffix_toks
             mask_context = [0] + [1] * len(prefix_toks) + [0] + [1] * len(suffix_toks)
         else:
@@ -390,7 +391,7 @@ def _fim_format(
     ):
         assert self.enc.BOS is not None
         # https://github.com/facebookresearch/codellama/blob/cb51c14ec761370ba2e2bc351374a79265d0465e/llama/generation.py#L380
-        if self.random.random() < 0.5:
+        if self.random.random() < self.spm_prob:
             tokens = (
                 [self.enc.BOS, self.enc.PREFIX] + prefix_toks
                 + [self.enc.SUFFIX] + suffix_toks
diff --git a/refact_scratchpads/__init__.py b/refact_scratchpads/__init__.py
index 7ccf8935..2f89d12b 100644
--- a/refact_scratchpads/__init__.py
+++ b/refact_scratchpads/__init__.py
@@ -4,6 +4,7 @@
 from refact_scratchpads.scratchpad_hf import ScratchpadPSM
 from refact_scratchpads.scratchpad_hf import ScratchpadCodeLlamaPSM
 from refact_scratchpads.scratchpad_hf import ScratchpadCodeLlamaSPM
+from refact_scratchpads.scratchpad_hf import ScratchpadDeepSeekCoderFIM
 from refact_scratchpads.scratchpad_hf import ScratchpadHuggingfaceStarChat
 from refact_scratchpads.scratchpad_hf import ScratchpadHuggingfaceWizard
 from refact_scratchpads.scratchpad_hf import ScratchpadHuggingfaceLlama2
diff --git a/refact_scratchpads/scratchpad_hf.py b/refact_scratchpads/scratchpad_hf.py
index b8683dc3..4e859c22 100644
--- a/refact_scratchpads/scratchpad_hf.py
+++ b/refact_scratchpads/scratchpad_hf.py
@@ -325,6 +325,79 @@ def _prompt_format(self, prefix_tokens, suffix_tokens):
             self._fim_middle,
         ]
 
+class ScratchpadDeepSeekCoderFIM(ScratchpadHuggingfaceBase):
+
+    def __init__(self, sources: Dict[str, str], cursor_file: str, cursor0: int, cursor1: int, **kwargs):
+        super().__init__(**kwargs)
+
+        assert cursor0 == cursor1
+
+        self._cursor_file = cursor_file
+        self._cursor = cursor0
+        self._code = sources[cursor_file]
+
+        self._prefix: Optional[str] = None
+        self._suffix: Optional[str] = None
+        self._suffix_line0cut: Optional[str] = None
+        self._completion = []
+
+        self._tokens_produced = 0
+        self._fim_prefix = self._encode_one_token("<|fim▁begin|>")
+        self._fim_suffix = self._encode_one_token("<|fim▁hole|>")
+        self._fim_middle = self._encode_one_token("<|fim▁end|>")
+        self._fim_eot = self._encode_one_token("<|EOT|>")
+        self._special_tokens.update({
+            self._fim_prefix, self._fim_suffix, self._fim_middle, self._fim_eot,
+        })
+
+    def _prompt_format(self, prefix_tokens, suffix_tokens):
+        return [
+            self._fim_prefix, *prefix_tokens,
+            self._fim_suffix, *suffix_tokens,
+            self._fim_middle,
+        ]
+
+    def prompt(self, T: int):
+        self._prefix = self._code[:self._cursor]
+        # Why do we need to cut the line to the right of the cursor?
+        # Example 1:
+        # function_call(param1, GENERATED_TOKENS)
+        # => everything works fine
+        # Example 2:
+        # function_call(param1, GENERATED_TOKENS)\nMORE_TOKENS\nSOME_OTHER_CALL(OTHER_PARAM)
+        #                                        ^^ but we stop here because we need a single-line completion
+        # => we would end up with two closing parentheses
+        # self._suffix = "".join(self._code[self._cursor:].splitlines(keepends=True)[1:])
+        self._suffix = self._code[self._cursor:].lstrip(" \t")
+        self._suffix_line0cut = "".join(self._code[self._cursor:].splitlines(keepends=True)[1:])
+        self._completion.clear()
+
+        prefix_cut, suffix_cut = trim_context_infill(
+            self._prefix, self._suffix, EncodingWrapper(self._tokenizer), T - self._max_tokens
+        )
+        prefix_cut_tokens = self._encode_without_special_tokens(prefix_cut)
+        suffix_cut_tokens = self._encode_without_special_tokens(suffix_cut)
+        self.debuglog(
+            "ScratchpadFIM prompt prefix %d chars -> %d tokens, suffix %d chars -> %d tokens, T=%d max_new_tokens=%d" %
+            (len(prefix_cut), len(prefix_cut_tokens), len(suffix_cut), len(suffix_cut_tokens), T, self._max_tokens)
+        )
+        prompt: List[int] = self._prompt_format(prefix_cut_tokens, suffix_cut_tokens)
+        self.debuglog("-"*40)
+        self.debuglog(self._tokenizer.decode(prompt))
+        self.debuglog("-"*40)
+        return prompt
+
+    def completion(self, final: bool):
+        assert self._prefix is not None
+        assert self._suffix is not None
+        completion = self._tokenizer.decode(self._completion)
+        if self.finish_reason == "eot":
+            # Correct stop
+            return {self._cursor_file: self._prefix + completion + self._suffix}
+        else:
+            # "stop-lf" or "length" or not stopped yet (empty reason): better to drop the remainder of the first suffix line
+            return {self._cursor_file: self._prefix + completion + self._suffix_line0cut}
+
 
 class ScratchpadChatBase(ScratchpadHuggingfaceBase):
 
diff --git a/self_hosting_machinery/finetune/configuration/supported_models.py b/self_hosting_machinery/finetune/configuration/supported_models.py
index ecf1962d..c8496c70 100644
--- a/self_hosting_machinery/finetune/configuration/supported_models.py
+++ b/self_hosting_machinery/finetune/configuration/supported_models.py
@@ -41,6 +41,36 @@
     "force_enable_checkpointing": False
 }
 
+_deepseek_base = {
+    "lora_target_modules_mapping": {
+        "qkv": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
+        "out": ["self_attn.o_proj"],
+        "backproj": ["self_attn.o_proj"],
+        "mlp": ["mlp.gate_proj", "mlp.up_proj", "mlp.down_proj"],
+    },
+    "freeze_exceptions_mapping": {
+        "wte": ["embed_tokens"],
+        "lm_head": ["lm_head"],
+        "lora": ["lora"]
+    },
+    "tokenizer": {
+        "eot_idx": 32021,  # `<|EOT|>`
+        "padding_idx": 32018,  # ``
+        "fim_prefix": 32016,  # `<|fim▁begin|>`
+        "fim_middle": 32017,  # `<|fim▁end|>`
+        "fim_suffix": 32015,  # `<|fim▁hole|>`
+        "escape": 32013,  # using `<|begin▁of▁sentence|>` token for now
+    },
+    "train_ds_pipeline": {
+        "ds_opts": f"{_fim_train_ds_pipeline['ds_opts']},spm_prob=0.0",
+        "ds_name": _fim_train_ds_pipeline["ds_name"]
+    },
+    "test_ds_pipeline": _fim_test_ds_pipeline,
+    "train_model_modifiers": [
+        "flash_sa.apply_flash_mha_to_codellama_model"
+    ]
+}
+
 config = {
     "Refact/1.6B": {
         "lora_target_modules_mapping": {
@@ -105,5 +135,12 @@
             "flash_sa.apply_flash_mha_to_codellama_model"
         ],
         "force_enable_checkpointing": True
+    },
+
+    "deepseek-ai/deepseek-coder-1.3b-base": _deepseek_base,
+
+    "deepseek-ai/deepseek-coder-6.7b-base": {
+        **_deepseek_base,
+        "force_enable_checkpointing": True
     }
 }
diff --git a/self_hosting_machinery/finetune/modelling/flash_sa.py b/self_hosting_machinery/finetune/modelling/flash_sa.py
index b963dcd4..d77c5c91 100644
--- a/self_hosting_machinery/finetune/modelling/flash_sa.py
+++ b/self_hosting_machinery/finetune/modelling/flash_sa.py
@@ -152,6 +152,7 @@ def _forward(
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        **kwargs
 ):
     from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
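
Below is a minimal inference sketch, separate from the patch itself, of the FIM layout that ScratchpadDeepSeekCoderFIM builds: <|fim▁begin|> + prefix + <|fim▁hole|> + suffix + <|fim▁end|>, with generation expected to stop at <|EOT|>. It uses plain transformers outside the Refact serving stack; the special-token strings are copied from the scratchpad above and assumed to be present in the deepseek-coder tokenizer's vocabulary, and fim_prompt is an illustrative helper name, not something added by this diff.

from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL = "deepseek-ai/deepseek-coder-1.3b-base"  # same checkpoint the patch registers


def fim_prompt(prefix: str, suffix: str) -> str:
    # Text before and after the cursor; the model fills the "hole" in between,
    # mirroring ScratchpadDeepSeekCoderFIM._prompt_format.
    return f"<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>"


tok = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(MODEL)

prompt = fim_prompt("def add(a, b):\n    return ", "\n\nprint(add(1, 2))\n")
inputs = tok(prompt, return_tensors="pt")
out = model.generate(
    **inputs,
    max_new_tokens=32,
    # stop at <|EOT|>, the same token the scratchpad registers as _fim_eot
    eos_token_id=tok.convert_tokens_to_ids("<|EOT|>"),
)
completion = tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(completion)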