Skip to content

Commit

Permalink
more logging
Browse files Browse the repository at this point in the history
  • Loading branch information
JegernOUTT committed Oct 16, 2023
1 parent 4380565 commit 15fa903
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
3 changes: 3 additions & 0 deletions self_hosting_machinery/finetune/modelling/flash_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,10 @@ def _forward(
return attn_output, None

if torch.cuda.get_device_capability() < (8, 0):
model.force_low_gpu_mem_mode = True
logging.warning("Flash attention is not supported on gpus with cuda capability < 8")
return

logging.warning("Applying flash attention to the model")
for block in model.transformer.h:
block.attn.forward = _forward.__get__(block.attn, type(block.attn))
2 changes: 2 additions & 0 deletions self_hosting_machinery/finetune/modelling/triton_flash_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,8 +617,10 @@ def _forward(
return attn_output, None

if th.cuda.get_device_capability() < (8, 0):
model.force_low_gpu_mem_mode = True
logging.warning("Triton flash attention is not supported on gpus with cuda capability < 8")
return

logging.warning("Applying triton flash attention to the model")
for block in model.transformer.h:
block.attn.forward = _forward.__get__(block.attn, type(block.attn))
13 changes: 9 additions & 4 deletions self_hosting_machinery/finetune/scripts/script_aux/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import torch
from torchinfo import summary
from transformers import AutoTokenizer, AutoModelForCausalLM
from refact_models.lora import LoraMixin

from refact_models.lora import LoraMixin
from self_hosting_machinery.finetune.configuration import supported_models
from self_hosting_machinery.finetune.modelling.loss import masked_loss
from self_hosting_machinery.finetune.modelling.model_handling import map_model_specific_params
from self_hosting_machinery.finetune.modelling.utils import map_model_specific_params
from self_hosting_machinery.finetune.utils import traces
from self_hosting_machinery.finetune.utils.timer import Timer

Expand Down Expand Up @@ -263,13 +263,18 @@ def _apply_model_modifiers(
path, modifier_name = modifier.rsplit('.', maxsplit=1)
mod_path = importlib.import_module(f"self_hosting_machinery.finetune.modelling.{path}")
mod = getattr(mod_path, modifier_name)
mod(model)
try:
mod(model)
except Exception as e:
logging.error(f"Applying model modifier {mod_path} wasn't successful: {e}")

def _set_low_gpu_mode(
self,
low_gpu_mode: bool
):
self.low_gpu_mem_mode = low_gpu_mode
force_low_gpu_mem_mode = hasattr(self.model, "force_low_gpu_mem_mode") and self.model.force_low_gpu_mem_mode
self.low_gpu_mem_mode = low_gpu_mode or force_low_gpu_mem_mode
logging.warning("Setting low_gpu_mem_mode={self.low_gpu_mem_mode} for the model")

if self.low_gpu_mem_mode:
self.model.gradient_checkpointing_enable()
Expand Down

0 comments on commit 15fa903

Please sign in to comment.