
Commit

add flash attention to codellama
add inference fixes for codellama
JegernOUTT authored and olegklimov committed Nov 2, 2023
1 parent 687a524 commit 122833a
Showing 7 changed files with 117 additions and 29 deletions.
6 changes: 2 additions & 4 deletions known_models_db/refact_known_models/huggingface.py
@@ -112,11 +112,9 @@
     "codellama/7b": {
         "backend": "transformers",
         "model_path": "TheBloke/CodeLlama-7B-fp16",
-        "diff_scratchpad_class": "refact_scratchpads:ScratchpadCodeLlama",
+        "diff_scratchpad_class": "refact_scratchpads:ScratchpadCodeLlamaSPM",
         "chat_scratchpad_class": None,
-        "model_class_kwargs": {
-            "load_in_8bit": True,
-        },
+        "model_class_kwargs": {},
         "required_memory_mb": 14000,
         "T": 2048,
         "filter_caps": ["completion", "finetune"],
10 changes: 5 additions & 5 deletions refact_data_pipeline/filters_fim_v2.py
@@ -346,11 +346,11 @@ def _fim_format(
         mask_context = [0] + [1] * len(suffix_toks) + [0] + [1] * len(prefix_toks)

         middle_mask = [1] * len(middle_toks)
-        if self.debug:
-            print(f'splitter: {splitter}, middle_size: {len(middle)}, middle: {middle}')
-            print(termcolor.colored(self.enc.decode(prefix_toks), "red"), end='')
-            print(termcolor.colored(self.enc.decode(middle_toks), "green"), end='')
-            print(termcolor.colored(self.enc.decode(suffix_toks), "red"))
+        # if self.debug:
+        #     print(f'splitter: {splitter}, middle_size: {len(middle)}, middle: {middle}')
+        #     print(termcolor.colored(self.enc.decode(prefix_toks), "red"), end='')
+        #     print(termcolor.colored(self.enc.decode(middle_toks), "green"), end='')
+        #     print(termcolor.colored(self.enc.decode(suffix_toks), "red"))

         tokens = tokens_context + [self.enc.INFILL] + middle_toks + [self.enc.EOT]
         mask = mask_context + [0] + middle_mask + [1]
3 changes: 2 additions & 1 deletion refact_scratchpads/__init__.py
@@ -2,7 +2,8 @@
 from refact_scratchpads.scratchpad_hf import ScratchpadHuggingfaceCompletion
 from refact_scratchpads.scratchpad_hf import ScratchpadSPM
 from refact_scratchpads.scratchpad_hf import ScratchpadPSM
-from refact_scratchpads.scratchpad_hf import ScratchpadCodeLlama
+from refact_scratchpads.scratchpad_hf import ScratchpadCodeLlamaPSM
+from refact_scratchpads.scratchpad_hf import ScratchpadCodeLlamaSPM
 from refact_scratchpads.scratchpad_hf import ScratchpadHuggingfaceStarChat
 from refact_scratchpads.scratchpad_hf import ScratchpadHuggingfaceWizard
 from refact_scratchpads.scratchpad_hf import ScratchpadHuggingfaceLlama2
72 changes: 57 additions & 15 deletions refact_scratchpads/scratchpad_hf.py
@@ -235,7 +235,7 @@ def _prompt_format(self, prefix_tokens, suffix_tokens):
         ]


-class ScratchpadCodeLlama(ScratchpadHuggingfaceBase):
+class ScratchpadCodeLlamaBase(ScratchpadHuggingfaceBase):

     def __init__(self, sources: Dict[str, str], cursor_file: str, cursor0: int, cursor1: int, **kwargs):
         super().__init__(**kwargs)
@@ -248,40 +248,82 @@ def __init__(self, sources: Dict[str, str], cursor_file: str, cursor0: int, curs

         self._prefix: Optional[str] = None
         self._suffix: Optional[str] = None
+        self._suffix_line0cut: Optional[str] = None
         self._completion = []

         self._tokens_produced = 0
+        self._bos = self._encode_one_token("<s>")
         self._fim_prefix = self._encode_one_token("<PRE>")
         self._fim_suffix = self._encode_one_token("<SUF>")
         self._fim_middle = self._encode_one_token("<MID>")
         self._fim_eot = self._encode_one_token("<EOT>")
         self._special_tokens.update({
-            self._fim_prefix, self._fim_suffix, self._fim_middle, self._fim_eot,
+            self._bos, self._fim_prefix, self._fim_suffix, self._fim_middle, self._fim_eot,
         })

+    def _prompt_format(self, prefix_tokens, suffix_tokens):
+        raise NotImplementedError()
+
     def prompt(self, T: int):
         self._prefix = self._code[:self._cursor]
-        self._suffix = "".join(self._code[self._cursor:].splitlines(keepends=True)[1:])
+        # Why do we need to cut the line to the right of the cursor?
+        # Example 1:
+        #   function_call(param1, GENERATED_TOKENS<EOF>)
+        #   => everything works right
+        # Example 2:
+        #   function_call(param1, GENERATED_TOKENS)\nMORE_TOKENS\nSOME_OTHER_CALL(OTHER_PARAM<EOF>)
+        #   ^^ but we stop here because we need a single-line completion
+        #   => we end up with two closing parentheses
+        # self._suffix = "".join(self._code[self._cursor:].splitlines(keepends=True)[1:])
+        self._suffix = self._code[self._cursor:].lstrip(" \t")
+        self._suffix_line0cut = "".join(self._code[self._cursor:].splitlines(keepends=True)[1:])
         self._completion.clear()

         prefix_cut, suffix_cut = trim_context_infill(
-            self._prefix, self._suffix, EncodingWrapper(self._tokenizer), T - self._max_tokens)
-        prompt: List[int] = [
-            self._eos_token,
-            self._fim_prefix,
-            *self._tokenizer.encode(prefix_cut),
-            self._fim_suffix,
-            *self._tokenizer.encode(suffix_cut),
-            self._fim_middle,
-        ]
+            self._prefix, self._suffix, EncodingWrapper(self._tokenizer), T - self._max_tokens
+        )
+        prefix_cut_tokens = self._encode_without_special_tokens(prefix_cut)
+        suffix_cut_tokens = self._encode_without_special_tokens(suffix_cut)
+        self.debuglog(
+            "ScratchpadFIM prompt prefix %d chars -> %d tokens, suffix %d chars -> %d tokens, T=%d max_new_tokens=%d" %
+            (len(prefix_cut), len(prefix_cut_tokens), len(suffix_cut), len(suffix_cut_tokens), T, self._max_tokens)
+        )
+        prompt: List[int] = self._prompt_format(prefix_cut_tokens, suffix_cut_tokens)
+        self.debuglog("-"*40)
+        self.debuglog(self._tokenizer.decode(prompt))
+        self.debuglog("-"*40)
         return prompt

     def completion(self, final: bool):
         assert self._prefix is not None
         assert self._suffix is not None
-        return {
-            self._cursor_file: self._prefix + self._tokenizer.decode(self._completion) + self._suffix,
-        }
+        completion = self._tokenizer.decode(self._completion)
+        if self.finish_reason == "eot":
+            # Correct stop
+            return {self._cursor_file: self._prefix + completion + self._suffix}
+        else:
+            # "stop-lf" or "length" or not stopped yet (empty reason): better to drop the remainder of the first suffix line
+            return {self._cursor_file: self._prefix + completion + self._suffix_line0cut}


+class ScratchpadCodeLlamaSPM(ScratchpadCodeLlamaBase):
+
+    def _prompt_format(self, prefix_tokens, suffix_tokens):
+        return [
+            self._bos, self._fim_prefix, self._fim_suffix,
+            *suffix_tokens,
+            self._fim_middle, *prefix_tokens,
+        ]
+
+
+class ScratchpadCodeLlamaPSM(ScratchpadCodeLlamaBase):
+
+    def _prompt_format(self, prefix_tokens, suffix_tokens):
+        return [
+            self._bos, self._fim_prefix, *prefix_tokens,
+            self._fim_suffix, *suffix_tokens,
+            self._fim_middle,
+        ]
+
+
 class ScratchpadChatBase(ScratchpadHuggingfaceBase):
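The two new subclasses differ only in how they order CodeLlama's infill control tokens (<PRE>, <SUF>, <MID>) around the trimmed prefix and suffix. A rough sketch of the resulting prompt layouts; the prefix/suffix strings below are invented for the example, and the real prompt is a list of token ids from the CodeLlama tokenizer rather than a formatted string:

# Illustrative sketch only: shows the token order the two scratchpads produce.
prefix = "def add(a, b):\n    return "
suffix = "\n\nprint(add(1, 2))\n"

# ScratchpadCodeLlamaSPM ("suffix-prefix-middle"): the model continues right after the prefix
spm_prompt = f"<s><PRE><SUF>{suffix}<MID>{prefix}"

# ScratchpadCodeLlamaPSM ("prefix-suffix-middle"): the model continues after <MID>
psm_prompt = f"<s><PRE>{prefix}<SUF>{suffix}<MID>"

In both cases generation stops when the model emits <EOT>; completion() then splices the generated middle back between the prefix and the suffix (or the suffix with its first line cut, when the stop was not a clean <EOT>).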
@@ -101,7 +101,9 @@
             **_fim_test_ds_pipeline,
             "ds_name": "CodeLLamaFIMDataset"
         },
-        "train_model_modifiers": [],
+        "train_model_modifiers": [
+            "flash_sa.apply_flash_mha_to_codellama_model"
+        ],
         "force_enable_checkpointing": True
     }
 }
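The new "train_model_modifiers" entry is a dotted string naming the flash-attention patch added in flash_sa.py below. A plausible sketch of how such an entry could be resolved and applied before training; this is an assumption about the loader, and the actual wiring in the repository may differ:

# Hypothetical resolution sketch, not code from this commit. Assumes modifier
# strings are relative to the finetune "modelling" package.
import importlib

def apply_train_model_modifiers(model, modifier_names):
    for dotted in modifier_names:  # e.g. "flash_sa.apply_flash_mha_to_codellama_model"
        module_name, func_name = dotted.rsplit(".", 1)
        module = importlib.import_module(
            "self_hosting_machinery.finetune.modelling." + module_name)
        getattr(module, func_name)(model)  # each modifier patches the model in place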
46 changes: 46 additions & 0 deletions self_hosting_machinery/finetune/modelling/flash_sa.py
@@ -136,3 +136,49 @@ def _forward(
     logging.warning("Applying flash attention to the model")
     for block in model.transformer.h:
         block.attn.forward = _forward.__get__(block.attn, type(block.attn))
+
+
+def apply_flash_mha_to_codellama_model(model):
+    if not _prerequisites_are_ok(model):
+        return
+
+    from flash_attn import flash_attn_func
+
+    def _forward(
+            self,
+            hidden_states: torch.Tensor,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_value: Optional[Tuple[torch.Tensor]] = None,
+            output_attentions: bool = False,
+            use_cache: bool = False,
+    ):
+        from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+
+        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
+        q = einops.rearrange(q, "b t (h d) -> b h t d", h=self.num_heads)
+        k = einops.rearrange(k, "b t (h d) -> b h t d", h=self.num_key_value_heads)
+        v = einops.rearrange(v, "b t (h d) -> b t h d", h=self.num_key_value_heads)
+
+        cos, sin = self.rotary_emb(q, seq_len=k.shape[-2])
+        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
+
+        q = einops.rearrange(q, "b h t d -> b t h d")
+        k = einops.rearrange(k, "b h t d -> b t h d")
+
+        attn_output = flash_attn_func(
+            q, k, v, softmax_scale=self.head_dim ** -0.5, causal=True
+        )
+
+        attn_output = einops.rearrange(attn_output, "b t h d -> b t (h d)")
+        attn_output = self.o_proj(attn_output)
+        return attn_output, None, None
+
+    if torch.cuda.get_device_capability() < (8, 0):
+        model.force_low_gpu_mem_mode = True
+        logging.warning("Flash attention is not supported on GPUs with CUDA compute capability < 8.0")
+        return
+
+    logging.warning("Applying flash attention to the model")
+    for layer in model.base_model.layers:
+        layer.self_attn.forward = _forward.__get__(layer.self_attn, type(layer.self_attn))
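For context, a minimal usage sketch of the new modifier (an illustration, not code from this commit): it patches each decoder layer's self_attn.forward in place, so it only needs a loaded CodeLlama model in fp16/bf16 on an Ampere-or-newer GPU. Note that the patched forward ignores attention_mask and past_key_value, i.e. it assumes purely causal, non-cached attention as used during finetuning; the dtype and device choices below are assumptions.

# Hypothetical usage sketch; dtype/device settings are assumptions.
import torch
from transformers import AutoModelForCausalLM
from self_hosting_machinery.finetune.modelling.flash_sa import apply_flash_mha_to_codellama_model

model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/CodeLlama-7B-fp16",
    torch_dtype=torch.bfloat16,  # flash-attn kernels require fp16/bf16 activations
).cuda()
apply_flash_mha_to_codellama_model(model)  # replaces layer.self_attn.forward on every decoder layer
# On GPUs below compute capability 8.0 the function leaves the model untouched
# and sets force_low_gpu_mem_mode instead.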
5 changes: 2 additions & 3 deletions self_hosting_machinery/finetune/scripts/aux/model.py
@@ -86,9 +86,8 @@ def __init__(
         )
         self.use_deepspeed = True

-        if debug:
-            logging.info("1 gpumem_p0 %0.2fG" % (torch.cuda.max_memory_allocated() / 1e9))
-            summary(self.model, depth=4, col_names=['num_params', 'params_percent', 'trainable'])
+        logging.info("Allocated memory: %0.2fG" % (torch.cuda.max_memory_allocated() / 1e9))
+        summary(self.model, depth=4, col_names=['num_params', 'params_percent', 'trainable'])

         self.loss_fn = partial(
             masked_loss,
