Skip to content

Commit

Permalink
Added some fixes addressing the MambaCache initialization, specifical…
Browse files Browse the repository at this point in the history
…ly when trying to handle conv_states with mark_static_address. The fixes were recommended by gpt so needs more research, yet the models are passing as of now.
  • Loading branch information
ddilbazTT committed Jan 13, 2025
1 parent 15fb194 commit 09ebe93
Showing 1 changed file with 49 additions and 6 deletions.
55 changes: 49 additions & 6 deletions tests/models/mamba/test_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,56 @@
import pytest
from tests.utils import ModelTester
import torch
import types
from transformers.models.mamba.modeling_mamba import MambaCache


def new_cache_init(self, config, max_batch_size, max_length, device, dtype):
self.max_batch_size = max_batch_size
self.max_length = max_length
self.dtype = dtype
self.device = device

batch_shape = (config.num_hidden_layers, max_batch_size)

conv_states = torch.zeros(
*batch_shape,
config.hidden_size,
config.conv_kernel - 1,
device=device,
dtype=dtype,
)
self.register_buffer("conv_states", conv_states, persistent=False)

ssm_state_shape = (
config.num_hidden_layers,
max_batch_size,
config.intermediate_size,
config.state_size,
)
ssm_states = torch.zeros(ssm_state_shape, device=device, dtype=dtype)
self.register_buffer("ssm_states", ssm_states, persistent=False)


# Replace the cache initialization
MambaCache.__init__ = new_cache_init


class ThisTester(ModelTester):
def _load_model(self):
model = MambaForCausalLM.from_pretrained(
self.model_name, torch_dtype=torch.bfloat16
)

# Correctly override generate method
original_generate = model.generate

def generate_without_cache(self, **kwargs):
kwargs["use_cache"] = False
return original_generate(**kwargs)

model.generate = types.MethodType(generate_without_cache, model)

self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name, torch_dtype=torch.bfloat16
)
Expand All @@ -22,8 +65,12 @@ def _load_model(self):
def _load_inputs(self):
prompt = "Hey how are you doing?"
input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
generation_config = GenerationConfig(max_new_tokens=10)
arguments = {"input_ids": input_ids, "generation_config": generation_config}
generation_config = GenerationConfig(max_new_tokens=10, use_cache=False)
arguments = {
"input_ids": input_ids,
"generation_config": generation_config,
"use_cache": False,
}
return arguments

def set_model_eval(self, model):
Expand All @@ -43,13 +90,9 @@ def set_model_eval(self, model):
"state-spaces/mamba-370m-hf",
],
)
@pytest.mark.xfail(
reason="Fails due to 'Attempt to trace forbidden callable', but we can still generate a graph"
)
def test_mamba(record_property, mode, model_name):
record_property("model_name", model_name)
record_property("mode", mode)

tester = ThisTester(model_name, mode)
results = tester.test_model()
if mode == "eval":
Expand Down

0 comments on commit 09ebe93

Please sign in to comment.