Adding Mamba blog
---------

Signed-off-by: Song <[email protected]>
Co-authored-by: Peter Jun Park <[email protected]>
Danny213123 and peterjunpark committed Jun 28, 2024
1 parent 5699f9b commit a7253d0
Showing 7 changed files with 869 additions and 0 deletions.
373 changes: 373 additions & 0 deletions blogs/artificial-intelligence/mamba/LICENSE.txt

Large diffs are not rendered by default.

402 changes: 402 additions & 0 deletions blogs/artificial-intelligence/mamba/README.md

Large diffs are not rendered by default.

(4 files could not be displayed; these are likely binary image assets for the blog.)
@@ -0,0 +1,94 @@
# Copyright (c) 2023, Tri Dao, Albert Gu.
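# Generation benchmark: loads either a Mamba checkpoint or a Hugging Face
# causal LM, generates from a real or random prompt, and reports average
# latency and peak GPU memory.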

import argparse
import time
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

import os
# Disable tokenizers parallelism to avoid fork-related warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

parser = argparse.ArgumentParser(description="Generation benchmarking")
parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m")
parser.add_argument("--prompt", type=str, default=None)
parser.add_argument("--promptlen", type=int, default=100)
parser.add_argument("--genlen", type=int, default=100)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--topk", type=int, default=1)
parser.add_argument("--topp", type=float, default=1.0)
parser.add_argument("--minp", type=float, default=0.0)
parser.add_argument("--repetition-penalty", type=float, default=1.0)
parser.add_argument("--batch", type=int, default=1)
parser.add_argument("--attn_implementation", type=str, default="sdpa")
args = parser.parse_args()

# Benchmark settings: fp16 on GPU, timings averaged over several repeats
repeats = 3
device = "cuda"
dtype = torch.float16
torch.cuda.reset_peak_memory_stats()

print(f"Loading model {args.model_name}")
is_mamba = args.model_name.startswith("state-spaces/mamba")
if is_mamba:
    # Mamba checkpoints ship without a tokenizer; they use the GPT-NeoX-20B tokenizer
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype)
else:
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype, attn_implementation=args.attn_implementation)

print(f"{args.model_name} model configuration attn_implementation:{args.attn_implementation}")
model.eval()
print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

torch.random.manual_seed(0)  # fixed seed so the synthetic prompt is reproducible
if args.prompt is None:
    # No prompt given: benchmark on random token IDs of the requested length
    input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda")
    attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda")
else:
    tokens = tokenizer(args.prompt, return_tensors="pt")
    input_ids = tokens.input_ids.to(device=device)
    attn_mask = tokens.attention_mask.to(device=device)
max_length = input_ids.shape[1] + args.genlen

if is_mamba:
    # Mamba's generate() interface; cg=True captures CUDA graphs for faster decoding
    fn = lambda: model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
        temperature=args.temperature,
        top_k=args.topk,
        top_p=args.topp,
        min_p=args.minp,
        repetition_penalty=args.repetition_penalty,
    )
else:
    # Standard Hugging Face generation path for Transformer baselines
    fn = lambda: model.generate(
        input_ids=input_ids,
        attention_mask=attn_mask,
        max_length=max_length,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=args.temperature,
        top_k=args.topk,
        top_p=args.topp,
        repetition_penalty=args.repetition_penalty,
    )
out = fn()  # warm-up run, excluded from timing
if args.prompt is not None:
    print(tokenizer.batch_decode(out.sequences.tolist()))

# Time full generation calls, synchronizing so all GPU work is included
torch.cuda.synchronize()
start = time.time()
for _ in range(repeats):
    fn()
torch.cuda.synchronize()
memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0 * 1024.0)  # bytes -> GiB
print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}")
print(f"prompt processing + decoding time: {(time.time() - start) / repeats :.2f}s")
print(f"memory used: {memory_used:.0f}GB")
