From 8e2cc943eb888b6133e2c0eb1c4a62f79c71f980 Mon Sep 17 00:00:00 2001 From: marswen Date: Fri, 29 Mar 2024 11:03:13 +0800 Subject: [PATCH 1/2] support llama models --- flexgen/flex_llama.py | 447 +++++++++++++++++++++++++++++++++++++ flexgen/llama_config.py | 110 +++++++++ flexgen/pytorch_backend.py | 307 +++++++++++++++++++++++++ 3 files changed, 864 insertions(+) create mode 100644 flexgen/flex_llama.py create mode 100644 flexgen/llama_config.py diff --git a/flexgen/flex_llama.py b/flexgen/flex_llama.py new file mode 100644 index 00000000..00827d47 --- /dev/null +++ b/flexgen/flex_llama.py @@ -0,0 +1,447 @@ +""" +Usage: +python3 -m flexgen.flex_llama --model meta-llama/Llama-2-7b-chat-hf --gpu-batch-size 32 --percent 100 0 100 0 100 0 +""" +import os +import torch +import argparse +from typing import Union +from transformers import AutoTokenizer +from flexgen.compression import CompressionConfig +from flexgen.llama_config import LlamaConfig, get_llama_config, download_llama_weights +from flexgen.pytorch_backend import LlamaTorchDevice, TorchDisk, TorchMixedDevice, fix_recursive_import +from flexgen.flex_opt import (Policy, init_weight_list, InputEmbed, OutputEmbed, SelfAttention, MLP, + TransformerLayer, OptLM, get_filename, get_test_inputs) +from flexgen.timer import timers +from flexgen.utils import (ExecutionEnv, GB, ValueHolder, + array_1d, array_2d, str2bool, project_decode_latency, write_benchmark_log) + +fix_recursive_import() + +DUMMY_WEIGHT = "_DUMMY_" # Use dummy weights for benchmark purposes + + +class LlamaInputEmbed(InputEmbed): + def __init__(self, config, env, policy): + super().__init__(config, env, policy) + + def init_weight(self, weight_home, path): + v, h, dtype = (self.config.vocab_size, self.config.input_dim, + self.config.dtype) + path = os.path.join(path, "") + weight_specs = [ + # w_token + ((v, h), dtype, path + "embed_tokens.weight"), + ] + weights = init_weight_list(weight_specs, self.policy, self.env) + + weight_home.store(weights) + + def load_weight(self, weight_home, weight_read_buf, k): + w_token, = weight_home.val + if k == 0: + dst = self.weight_load_dst + weight_read_buf.store((w_token.smart_copy(dst),)) + + def forward(self, hidden, cache_read_buf, weight_read_buf, attention_mask, + cache_write_buf, i, k): + # Compute input embedding + donate = [False] * 3 + h, donate[0] = hidden.val, True + mask, donate[1] = attention_mask.val.smart_copy(self.compute) + + if k == self.policy.num_gpu_batches - 1: + # Clear the weight_read_buf if it is the last gpu batch + (w_token, donate[2]), = weight_read_buf.pop() + else: + (w_token, _), = weight_read_buf.val + + h = self.compute.llama_input_embed(h, mask, + w_token, self.config.pad_token_id, donate) + hidden.val = h + + +class LlamaOutputEmbed(OutputEmbed): + def __init__(self, config, env, policy): + super().__init__(config, env, policy) + + def init_weight(self, weight_home, path): + v, h, dtype = (self.config.vocab_size, self.config.input_dim, + self.config.dtype) + path = os.path.join(path, "") + weight_specs = [ + # w_ln + ((h,), dtype, path + "norm.weight"), + # w_token + ((v, h), dtype, path + "lm_head.weight"), + ] + weights = init_weight_list(weight_specs, self.policy, self.env) + + weight_home.store(weights) + + def load_weight(self, weight_home, weight_read_buf, k): + w_ln, w_token = weight_home.val + if k == 0: + dst1 = self.weight_load_dst + dst2 = self.compute + weight_read_buf.store((w_ln.smart_copy(dst2), w_token.smart_copy(dst1))) + + def forward(self, hidden, cache_read_buf, weight_read_buf, 
attention_mask, + cache_write_buf, i, k): + donate = [False] * 3 + h, donate[0] = hidden.val, True + + if k == self.policy.num_gpu_batches - 1: + # Clear the weight_read_buf if it is the last gpu batch + (w_ln, donate[1]), (w_token, donate[2]) = weight_read_buf.pop() + else: + (w_ln, _), (w_token, _) = weight_read_buf.val + + h = self.compute.llama_output_embed(h, w_ln, w_token, self.config.rms_norm_eps, donate, + self.task.do_sample, self.task.temperature) + hidden.val = h + + +class LlamaSelfAttention(SelfAttention): + def __init__(self, config, env, policy, layer_id): + super().__init__(config, env, policy, layer_id) + + def init_weight(self, weight_home, path): + h, n_head, n_kv_head, dtype = (self.config.input_dim, self.config.n_head, self.config.num_key_value_heads, self.config.dtype) + head_dim = h // n_head + path = os.path.join(os.path.join(path, f"layers.{self.layer_id}.")) + weight_specs = [ + # w_ln + ((h,), dtype, path + "input_layernorm.weight"), + # w_q + ((h, n_head*head_dim), dtype, path + "self_attn.q_proj.weight"), + # w_k + ((n_kv_head*head_dim, h), dtype, path + "self_attn.k_proj.weight"), + # w_v + ((n_kv_head*head_dim, h), dtype, path + "self_attn.v_proj.weight"), + # w_re + ((head_dim//2,), dtype, path + "self_attn.rotary_emb.inv_freq"), + # w_o + ((n_head*head_dim, h), dtype, path + "self_attn.o_proj.weight"), + ] + weights = init_weight_list(weight_specs, self.policy, self.env) + weight_home.store(weights) + + def load_weight(self, weight_home, weight_read_buf, k): + w_ln, w_q, w_k, w_v, w_re, w_o = weight_home.val + if k == 0: + dst1 = self.weight_load_dst + dst2 = self.compute + weight_read_buf.store(( + w_ln.smart_copy(dst2), + w_q.smart_copy(dst1), + w_k.smart_copy(dst1), + w_v.smart_copy(dst1), + w_re.smart_copy(dst1), + w_o.smart_copy(dst1))) + + def forward(self, hidden, cache_read_buf, weight_read_buf, attention_mask, + cache_write_buf, i, k): + n_head = self.config.n_head + n_kv_head = self.config.num_key_value_heads + + donate = [False] * 10 + h, donate[0] = hidden.val, True + + if k == self.policy.num_gpu_batches - 1: + # Clear the weight_read_buf if it is the last gpu batch + ((w_ln, donate[2]), (w_q, donate[3]), (w_k, donate[4]), (w_v, donate[5]), + (w_re, donate[6]), (w_o, donate[7])) = weight_read_buf.pop() + else: + ((w_ln, _), (w_q, _), (w_k, _), (w_v, _), + (w_re, _), (w_o, _)) = weight_read_buf.val + + if i == 0: # prefill + mask, donate[1] = attention_mask.val.smart_copy(self.compute) + position_ids = torch.cumsum(mask.data, dim=1).int() * mask.data + 1 + h, new_k_cache, new_v_cache = self.compute.llama_mha(h, position_ids, mask, w_ln, + w_q, w_k, w_v, w_re, w_o, n_head, n_kv_head, donate, self.config.rms_norm_eps, + self.policy.compress_cache, self.policy.comp_cache_config) + cache_write_buf.store((new_k_cache, new_v_cache)) + else: # decoding + mask, donate[1] = attention_mask.val.smart_copy(self.attention_compute) + (k_cache, donate[8]), (v_cache, donate[9]) = cache_read_buf.pop() + position_ids = torch.cumsum(mask.data, dim=1).int() * mask.data + 1 + position_ids = position_ids[:, -h.shape[1]].unsqueeze(1) + h, new_k_cache, new_v_cache = self.compute.llama_mha_gen(h, position_ids, mask, w_ln, + w_q, w_k, w_v, w_re, w_o, self.config.rms_norm_eps, n_head, n_kv_head, + k_cache, v_cache, donate, self.policy.attn_sparsity, + self.policy.compress_cache, self.policy.comp_cache_config) + cache_write_buf.store((new_k_cache, new_v_cache)) + + hidden.val = h + + +class LlamaMLP(MLP): + def __init__(self, config, env, policy, layer_id): + 
super().__init__(config, env, policy, layer_id) + + def init_weight(self, weight_home, path): + h, intermediate, dtype = (self.config.input_dim, self.config.intermediate_size, self.config.dtype) + path = os.path.join(os.path.join(path, f"layers.{self.layer_id}.")) + weight_specs = [ + # w_ln + ((h,), dtype, path + "post_attention_layernorm.weight"), + # w_g + ((intermediate, h), dtype, path + "mlp.gate_proj.weight"), + # w_u + ((intermediate, h), dtype, path + "mlp.up_proj.weight"), + # w_d + ((h, intermediate), dtype, path + "mlp.down_proj.weight"), + ] + weights = init_weight_list(weight_specs, self.policy, self.env) + weight_home.store(weights) + + def load_weight(self, weight_home, weight_read_buf, k): + w_ln, w_g, w_u, w_d = weight_home.val + if k == 0: + dst1 = self.weight_load_dst + dst2 = self.compute + weight_read_buf.store(( + w_ln.smart_copy(dst2), + w_g.smart_copy(dst1), + w_u.smart_copy(dst1), + w_d.smart_copy(dst1))) + + def forward(self, hidden, cache_read_buf, weight_read_buf, attention_mask, + cache_write_buf, i, k): + donate = [False] * 5 + h, donate[0] = hidden.val, True + + if k == self.policy.num_gpu_batches - 1: + # Clear the weight_read_buf if it is the last gpu batch + ((w_ln, donate[1]), (w_g, donate[2]), (w_u, donate[3]), + (w_d, donate[4])) = weight_read_buf.pop() + else: + ((w_ln, _), (w_g, _), (w_u, _), (w_d, _)) = weight_read_buf.val + + h = self.compute.llama_mlp(h, w_ln, w_g, w_u, w_d, self.config.rms_norm_eps, donate) + hidden.val = h + + +class LlamaTransformerLayer(TransformerLayer): + def __init__(self, config, env, policy, i): + self.attention = LlamaSelfAttention(config, env, policy, i) + self.mlp = LlamaMLP(config, env, policy, i) + self.policy = policy + self.compute = self.attention.compute + + +class LlamaLM(OptLM): + def __init__(self, + config: Union[str, LlamaConfig], + env: ExecutionEnv, + path: str, + policy: Policy): + if isinstance(config, str): + config = get_llama_config(config) + self.config = config + self.env = env + self.path = path + self.policy = policy + self.num_gpu_batches = policy.num_gpu_batches + + layers = [] + layers.append(LlamaInputEmbed(self.config, self.env, self.policy)) + for i in range(self.config.num_hidden_layers): + if policy.sep_layer: + layers.append(LlamaSelfAttention(self.config, self.env, self.policy, i)) + layers.append(LlamaMLP(self.config, self.env, self.policy, i)) + else: + layers.append(LlamaTransformerLayer(self.config, self.env, self.policy, i)) + layers.append(LlamaOutputEmbed(self.config, self.env, self.policy)) + self.layers = layers + self.num_layers = len(layers) + + if self.policy.act_gpu_percent == 100: + self.act_home = self.env.gpu + elif self.policy.act_cpu_percent == 100: + self.act_home = self.env.cpu + elif self.policy.act_disk_percent == 100: + self.act_home = self.env.disk + else: + raise NotImplementedError() + + # CUDA streams + self.load_weight_stream = torch.cuda.Stream() + self.load_cache_stream = torch.cuda.Stream() + self.store_cache_stream = torch.cuda.Stream() + + # Intermediate tensors + # The following buffers store values used + # for the i-th token, j-th layer, k-th gpu batch. 
+ num_layers, num_gpu_batches = self.num_layers, self.policy.num_gpu_batches + + # cache[j][k] + self.cache_home = array_2d(num_layers, num_gpu_batches, ValueHolder) + self.cache_read_buf = array_2d(num_layers, num_gpu_batches, ValueHolder) + self.cache_write_buf = array_2d(num_layers, num_gpu_batches, ValueHolder) + # weight[j] + self.weight_read_buf = array_1d(num_layers, ValueHolder) + # attention_mask[k] + self.attention_mask = array_1d(num_gpu_batches, ValueHolder) + + self.task = None + self.init_all_weights() + + def init_weight(self, j): + expanded_path = os.path.abspath(os.path.expanduser( + os.path.join(self.path, f"{self.config.name}-np"))) + check_path = os.path.join(expanded_path, "embed_tokens.weight") + if not os.path.exists(check_path) and DUMMY_WEIGHT not in check_path: + download_llama_weights(self.config.name, self.path, self.config.hf_token) + + self.layers[j].init_weight(self.weight_home[j], expanded_path) + + +def run_flexgen(args): + print(f": args.model: {args.model}") + tokenizer = AutoTokenizer.from_pretrained(args.model, token=args.hf_token, padding_side="left") + tokenizer.pad_token_id = tokenizer.eos_token_id + num_prompts = args.num_gpu_batches * args.gpu_batch_size + prompt_len, gen_len, cut_gen_len = args.prompt_len, args.gen_len, args.cut_gen_len + + # Task and policy + warmup_inputs = get_test_inputs(32, num_prompts, tokenizer) + inputs = get_test_inputs(prompt_len, num_prompts, tokenizer) + + gpu = LlamaTorchDevice("cuda:0") + cpu = LlamaTorchDevice("cpu") + disk = TorchDisk(args.offload_dir) + env = ExecutionEnv(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk])) + + policy = Policy(args.gpu_batch_size, args.num_gpu_batches, + args.percent[0], args.percent[1], + args.percent[2], args.percent[3], + args.percent[4], args.percent[5], + args.overlap, args.sep_layer, args.pin_weight, + args.cpu_cache_compute, args.attn_sparsity, + args.compress_weight, + CompressionConfig(num_bits=4, group_size=64, + group_dim=0, symmetric=False), + args.compress_cache, + CompressionConfig(num_bits=4, group_size=64, + group_dim=2, symmetric=False)) + assert not (args.compress_cache and args.attn_sparsity < 1.0), "Not implemented" + + llama_config = get_llama_config(args.model, hf_token=args.hf_token, pad_token_id=tokenizer.eos_token_id) + cache_size = llama_config.cache_bytes(num_prompts, prompt_len + gen_len) + hidden_size = llama_config.hidden_bytes(num_prompts, prompt_len + gen_len) + print(f"model size: {llama_config.model_bytes()/GB:.3f} GB, " + f"cache size: {cache_size/GB:.3f} GB, " + f"hidden size (prefill): {hidden_size/GB:.3f} GB") + + print("init weight...") + model = LlamaLM(llama_config, env, args.path, policy) + + try: + print("warmup - generate") + output_ids = model.generate( + warmup_inputs, max_new_tokens=1, verbose=args.verbose) + + print("benchmark - generate") + timers("generate").reset() + output_ids = model.generate( + inputs, max_new_tokens=args.gen_len, + debug_mode=args.debug_mode, cut_gen_len=cut_gen_len, verbose=args.verbose) + costs = timers("generate").costs + finally: + env.close_copy_threads() + + # Log output + prefill_latency = costs[0] + prefill_throughput = num_prompts * prompt_len / prefill_latency + if cut_gen_len: # project latency of cut_gen_len to gen_len + decode_latency = project_decode_latency(costs, prompt_len, gen_len) + else: + decode_latency = sum(costs[1:]) + decode_throughput = num_prompts * (gen_len - 1) / max(decode_latency, 1e-10) + num_generated_tokens = num_prompts * gen_len + total_latency = 
prefill_latency + decode_latency + total_throughput = num_generated_tokens / total_latency + _, gpu_peak_mem = gpu.mem_stats() + _, cpu_peak_mem = cpu.mem_stats() + + if DUMMY_WEIGHT not in args.path: + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + show_str = "Outputs:\n" + 70 * '-' + "\n" + for i in [0, len(outputs)-1]: + show_str += f"{i}: {outputs[i]}\n" + show_str += "-" * 70 + "\n" + if args.verbose >= 2: + print(show_str) + + gpu.print_stats() + cpu.print_stats() + projected = bool(args.debug_mode or cut_gen_len) + + if args.log_file == "auto": + filename = get_filename(args) + ".log" + else: + filename = args.log_file + + log_str = write_benchmark_log(filename, + llama_config.model_bytes(), cache_size, hidden_size, + gpu_peak_mem, projected, prefill_latency, prefill_throughput, + decode_latency, decode_throughput, total_latency, total_throughput) + if args.verbose >= 1: + print(log_str) + + +def add_parser_arguments(parser): + parser.add_argument("--model", type=str, default="meta-llama/Llama-2-7b-chat-hf", + help="The model name.") + parser.add_argument("--hf-token", type=str, + help="The huggingface token for accessing gated repo.") + parser.add_argument("--path", type=str, default="~/llama_weights", + help="The path to the model weights. If there are no cached weights, " + "FlexGen will automatically download them from HuggingFace.") + parser.add_argument("--offload-dir", type=str, default="~/flexgen_offload_dir", + help="The directory to offload tensors. ") + parser.add_argument("--prompt-len", type=int, default=512) + parser.add_argument("--gen-len", type=int, default=32) + parser.add_argument("--cut-gen-len", type=int, + help="Cut generation length for fast debugging.") + parser.add_argument("--debug-mode", type=str, + choices=["fewer_batch", "breakdown"]) + parser.add_argument("--gpu-batch-size", type=int, default=4) + parser.add_argument("--num-gpu-batches", type=int, default=1) + parser.add_argument("--percent", nargs="+", type=int, + default=[100, 0, 100, 0, 100, 0], + help="Six numbers. They are " + "the percentage of weight on GPU, " + "the percentage of weight on CPU, " + "the percentage of attention cache on GPU, " + "the percentage of attention cache on CPU, " + "the percentage of activations on GPU, " + "the percentage of activations on CPU") + parser.add_argument("--sep-layer", type=str2bool, nargs='?', + const=True, default=True) + parser.add_argument("--pin-weight", type=str2bool, nargs="?", + const=True, default=True) + parser.add_argument("--cpu-cache-compute", action="store_true") + parser.add_argument("--attn-sparsity", type=float, default=1.0) + parser.add_argument("--compress-weight", action="store_true", + help="Whether to compress weight.") + parser.add_argument("--compress-cache", action="store_true", + help="Whether to compress cache.") + parser.add_argument("--log-file", type=str, default="auto") + parser.add_argument("--no-log", action="store_true") + parser.add_argument("--verbose", type=int, default=2) + parser.add_argument("--overlap", type=str2bool, nargs='?', + const=True, default=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + add_parser_arguments(parser) + args = parser.parse_args() + + assert len(args.percent) == 6 + + run_flexgen(args) diff --git a/flexgen/llama_config.py b/flexgen/llama_config.py new file mode 100644 index 00000000..7e76b79e --- /dev/null +++ b/flexgen/llama_config.py @@ -0,0 +1,110 @@ +""" +The Llama model configurations and weight downloading utilities. 
+ +adopted from opt_config.py +""" + +import dataclasses +import glob +import os +import numpy as np +from tqdm import tqdm + + +@dataclasses.dataclass(frozen=True) +class LlamaConfig: + name: str = "Llama-2-7b-hf" + hf_token: str = '' + hidden_act: str = "silu" + input_dim: int = 4096 + initializer_range: float = 0.02 + intermediate_size: int = 11008 + max_position_embeddings: int = 4096 + n_head: int = 32 + num_hidden_layers: int = 32 + num_key_value_heads: int = 32 + rms_norm_eps: float = 1e-05 + dtype: type = np.float16 + pad_token_id: int = 2 + vocab_size: int = 32000 + + def model_bytes(self): + h = self.input_dim + intermediate = self.intermediate_size + n_head = self.n_head + head_dim = h // n_head + return 2 * (self.vocab_size * h + + self.num_hidden_layers * ( + # self-attention + 3 * h * h + h * h + head_dim // 2 + + # mlp + 3 * h * intermediate + + # layer norm + 2 * h) + + # head + h + self.vocab_size * h) + + def cache_bytes(self, batch_size, seq_len): + return 2 * batch_size * seq_len * self.num_hidden_layers * self.input_dim * 2 + + def hidden_bytes(self, batch_size, seq_len): + return batch_size * seq_len * self.input_dim * 2 + + +def get_llama_config(name, **kwargs): + if "/" in name: + name = name.split("/")[1] + + if "-chat" in name: + arch_name = name.replace("-chat", "") + else: + arch_name = name + + if arch_name == "Llama-2-7b-hf": + config = LlamaConfig(name=name, hf_token=kwargs.get('hf_token'), + input_dim=4096, intermediate_size=11008, n_head=32, + num_hidden_layers=32, num_key_value_heads=32 + ) + elif arch_name == "Llama-2-13b-hf": + config = LlamaConfig(name=name, hf_token=kwargs.get('hf_token'), + input_dim=5120, intermediate_size=13824, n_head=40, + num_hidden_layers=40, num_key_value_heads=40 + ) + elif arch_name == "Llama-2-70b-hf": + config = LlamaConfig(name=name, hf_token=kwargs.get('hf_token'), + input_dim=8192, intermediate_size=28672, n_head=64, + num_hidden_layers=80, num_key_value_heads=8 + ) + else: + raise ValueError(f"Invalid model name: {name}") + + return dataclasses.replace(config, **kwargs) + + +def download_llama_weights(model_name, path, hf_token): + from huggingface_hub import snapshot_download + import torch + + print(f"Load the pre-trained pytorch weights of {model_name} from huggingface. " + f"The downloading and cpu loading can take dozens of minutes. 
" + f"If it seems to get stuck, you can monitor the progress by " + f"checking the memory usage of this process.") + + hf_model_name = "meta-llama/" + model_name + + folder = snapshot_download(hf_model_name, allow_patterns="*.bin", token=hf_token) + bin_files = glob.glob(os.path.join(folder, "*.bin")) + + if "/" in model_name: + model_name = model_name.split("/")[1] + path = os.path.join(path, f"{model_name}-np") + path = os.path.abspath(os.path.expanduser(path)) + os.makedirs(path, exist_ok=True) + + for bin_file in tqdm(bin_files, desc="Convert format"): + state = torch.load(bin_file, map_location='cuda:0') + for name, param in tqdm(state.items(), leave=False): + name = name.replace("model.", "") + param_path = os.path.join(path, name) + with open(param_path, "wb") as f: + np.save(f, param.cpu().detach().numpy()) diff --git a/flexgen/pytorch_backend.py b/flexgen/pytorch_backend.py index 7f341849..5e8c38ea 100644 --- a/flexgen/pytorch_backend.py +++ b/flexgen/pytorch_backend.py @@ -904,3 +904,310 @@ def copy_worker_func(queue, cuda_id): dst_data.copy_(src_data) queue.task_done() + + +def rms_norm(input, weight, eps) -> torch.Tensor: + input_dtype = input.dtype + hidden_states = input.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + eps) + return weight * hidden_states.to(input_dtype) + + +def rotary_embedding(x, inv_freq, seq_len): + t = torch.arange(seq_len, device=x.device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq.to(x.device)) + emb = torch.cat((freqs, freqs), dim=-1) + return ( + emb.cos().to(x.dtype)[:seq_len].to(dtype=x.dtype), + emb.sin().to(x.dtype)[:seq_len].to(dtype=x.dtype), + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=2): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + seqlen, num_key_value_heads, head_dim) to (batch, seqlen, num_attention_heads, head_dim) + """ + batch, slen, num_key_value_heads, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, :, None, :].expand(batch, slen, num_key_value_heads, n_rep, head_dim) + return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep, head_dim) + + +class LlamaTorchDevice(TorchDevice): + + def llama_input_embed(self, inputs, attention_mask, w_token, pad_token_id, donate): + # decompress weights + if w_token.device.device_type == DeviceType.COMPRESSED: + w_token = w_token.device.decompress(w_token) + + token_ids = inputs.data + if donate[0]: inputs.delete() + if donate[1]: attention_mask.delete() + + # token embedding + token_embed = F.embedding(token_ids, w_token.data, pad_token_id) + + return TorchTensor.create_from_torch(token_embed, self) + + def llama_output_embed(self, inputs, w_ln, w_token, eps, donate, do_sample, temperature): + # decompress weights + if w_token.device.device_type == DeviceType.COMPRESSED: + w_token = w_token.device.decompress(w_token) + + hidden = rms_norm(inputs.data, weight=w_ln.data, eps=eps) + if donate[0]: inputs.delete() + + # output embedding + logits = F.linear(hidden, w_token.data) + last_token_logits = logits[:,-1,:] + + if do_sample and not temperature < 1e-5: + probs = torch.softmax(last_token_logits / temperature, dim=-1) + ids = torch.multinomial(probs, num_samples=1) + else: + ids = last_token_logits.argmax(dim=1, keepdim=True) + return TorchTensor.create_from_torch(ids, self) + + def llama_mha(self, inputs, position_ids, attention_mask, w_ln, w_q, w_k, w_v, + w_re, w_out, n_head, n_kv_head, donate, eps, compress_cache, comp_config): + """Multi-head attention (prefill phase).""" + # decompress weights + if w_q.device.device_type == DeviceType.COMPRESSED: + w_q = w_q.device.decompress(w_q) + w_k = w_k.device.decompress(w_k) + w_v = w_v.device.decompress(w_v) + w_re = w_re.device.decompress(w_re) + w_out = w_out.device.decompress(w_out) + + b, s, h = inputs.shape + head_dim = h // n_head + scaling = head_dim ** -0.5 + + hidden = rms_norm(inputs.data, weight=w_ln.data, eps=eps) + + # shape: (b, s, h) + q = F.linear(hidden, w_q.data) * scaling + k = F.linear(hidden, w_k.data) + v = F.linear(hidden, w_v.data) + # shape: (b, s, n_head, head_dim) + q = q.view(b, s, n_head, head_dim) + k = k.view(b, s, n_kv_head, head_dim) + v = v.view(b, s, n_kv_head, head_dim) + + kv_seq_len = k.shape[-3] + cos, sin = rotary_embedding(v, w_re.data, seq_len=kv_seq_len) + q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids) + + n_kv_groups = n_head // n_kv_head + k = repeat_kv(k, n_kv_groups) + v = repeat_kv(v, n_kv_groups) + + # shape: (b * n_head, s, head_dim) + q = q.permute(0, 2, 1, 3).reshape(b * n_head, s, head_dim) + # shape: (b * n_head, head_dim, s) + k = k.permute(0, 2, 3, 1).reshape(b * n_head, head_dim, s) + # shape: (b * n_head, s, head_dim) + v = v.permute(0, 2, 1, 3).reshape(b * n_head, s, head_dim) + + # shape: (b * n_head, s, s) + attn_weights = torch.bmm(q, k) + + # shape: (b, 1, s, s) + 
idx = torch.arange(s, device=self.dev) + causal_mask = (idx <= idx.view(s, 1)).view(1, 1, s, s) + mask = attention_mask.data.view(b, 1, 1, s) & causal_mask + + # shape: (b, n_head, s, s) + attn_weights = attn_weights.view(b, n_head, s, s) + attn_weights = torch.where(mask, attn_weights, -1e4) + attn_weights = attn_weights.view(b * n_head, s, s) + attn_weights = F.softmax(attn_weights, dim=2) + # shape: (b, n_head, s, head_dim) + value = torch.bmm(attn_weights, v).view(b, n_head, s, head_dim) + # shape: (b, s, h) + value = value.transpose(1, 2).reshape(b, s, h) + value = F.linear(value, w_out.data) + + value.add_(inputs.data) + + if donate[0]: inputs.delete() + if donate[1]: attention_mask.delete() + + # (s, b * n_head, head_dim) + k = k.permute(2, 0, 1) + v = v.permute(1, 0, 2) + + if compress_cache: + k = self.compressed_device.compress(k, comp_config) + v = self.compressed_device.compress(v, comp_config) + else: + k = TorchTensor.create_from_torch(k, self) + v = TorchTensor.create_from_torch(v, self) + + return TorchTensor.create_from_torch(value, self), k, v + + def llama_mha_gen(self, inputs, position_ids, attention_mask, w_ln, w_q, w_k, w_v, + w_re, w_out, eps, n_head, n_kv_head, k_cache, v_cache, donate, + attn_sparsity, compress_cache, comp_config): + """Multi-head attention (decoding phase).""" + # decompress weights + if w_q.device.device_type == DeviceType.COMPRESSED: + w_q = w_q.device.decompress(w_q) + w_k = w_k.device.decompress(w_k) + w_v = w_v.device.decompress(w_v) + w_re = w_re.device.decompress(w_re) + w_out = w_out.device.decompress(w_out) + + b, tgt_s, h = inputs.shape + src_s = attention_mask.shape[1] + head_dim = h // n_head + scaling = head_dim ** -0.5 + + hidden = rms_norm(inputs.data, weight=w_ln.data, eps=eps) + + # shape: (b, 1, h) + q = F.linear(hidden, w_q.data) * scaling + k = F.linear(hidden, w_k.data) + v = F.linear(hidden, w_v.data) + # shape: (b, 1, n_head, head_dim) + q = q.view(b, tgt_s, n_head, head_dim) + k = k.view(b, tgt_s, n_kv_head, head_dim) + v = v.view(b, tgt_s, n_kv_head, head_dim) + + cos, sin = rotary_embedding(v, w_re.data, seq_len=position_ids.max().item() + 1) + q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids) + + n_kv_groups = n_head // n_kv_head + k = repeat_kv(k, n_kv_groups) + v = repeat_kv(v, n_kv_groups) + + # shape: (b * n_head, 1, head_dim) + q = q.permute(0, 2, 1, 3).reshape(b * n_head, tgt_s, head_dim) + # shape: (1, b * n_head, head_dim) + k_new = k.permute(1, 0, 2, 3).reshape(tgt_s, b * n_head, head_dim) + # shape: (1, b * n_head, head_dim) + v_new = v.permute(1, 0, 2, 3).reshape(tgt_s, b * n_head, head_dim) + + if isinstance(k_cache, TorchTensor): + if attn_sparsity >= 1.0: # Dense attention + if compress_cache: + # shape: (s, b * n_head, head_dim) + k = k_cache.device.decompress(k_cache)[:src_s] + v = v_cache.device.decompress(v_cache)[:src_s] + else: + # shape: (s, b * n_head, head_dim) + k = k_cache.data[:src_s] + v = v_cache.data[:src_s] + k[src_s - 1:src_s] = k_new + v[src_s - 1:src_s] = v_new + + # shape: (b * n_head, head_dim, s) + k = k.permute(1, 2, 0).reshape(b * n_head, head_dim, src_s) + # shape: (b * n_head, s, head_dim) + v = v.permute(1, 0, 2).reshape(b * n_head, src_s, head_dim) + + if k.is_cuda: + value = self._attention_value(q, k, v, attention_mask.data, + b, src_s, tgt_s, n_head, head_dim) + else: + q = q.float().cpu() + k, v = k.float(), v.float() + value = self._attention_value(q, k, v, attention_mask.data, + b, src_s, tgt_s, n_head, head_dim).cuda().half() + else: # Sparse attention + # shape: 
(s, b * n_head, head_dim) + k = k_cache.data[:src_s] + k[src_s - 1:src_s] = k_new + # shape: (b * n_head, head_dim, s) + k = k.permute(1, 2, 0).reshape(b * n_head, head_dim, src_s) + + if k.is_cuda: + value = self._sparse_attention_value(q, k, v_new, v_cache, + attention_mask.data, b, src_s, tgt_s, n_head, head_dim, + attn_sparsity) + else: + q = q.float().cpu() + value = self._sparse_attention_value(q, k, v_new, v_cache, + attention_mask.data, b, src_s, tgt_s, n_head, head_dim, + attn_sparsity).cuda().half() + else: # Mixed device attention + assert attn_sparsity >= 1.0 + value = self._mixed_device_attention(q, k_cache, v_cache, + k_new, v_new, attention_mask.data, b, src_s, tgt_s, + n_head, head_dim) + + # shape: (b, 1, h) + value = value.transpose(1, 2).view(b, tgt_s, h) + value = F.linear(value, w_out.data) + + value.add_(inputs.data) + + if donate[0]: inputs.delete() + if donate[1]: attention_mask.delete() + + if compress_cache: + if comp_config.group_dim == 0: + s_ = src_s // comp_config.group_size * comp_config.group_size + k_new = k[:, :, s_:].permute(2, 0, 1) + v_new = v[:, s_:, :].permute(1, 0, 2) + k_new = self.compressed_device.compress(k_new, comp_config) + v_new = self.compressed_device.compress(v_new, comp_config) + else: + k_new = TorchTensor.create_from_torch(k_new, self) + v_new = TorchTensor.create_from_torch(v_new, self) + + return TorchTensor.create_from_torch(value, self), k_new, v_new + + def llama_mlp(self, inputs, w_ln, w_g, w_u, w_d, eps, donate): + # decompress weights + if w_ln.device.device_type == DeviceType.COMPRESSED: + w_g = w_g.device.decompress(w_g) + w_u = w_g.device.decompress(w_u) + w_d = w_g.device.decompress(w_d) + + out = rms_norm(inputs.data, weight=w_ln.data, eps=eps) + gate_out = F.linear(out, w_g.data) + F.silu(gate_out, inplace=True) + up_out = F.linear(out, w_u.data) + out = F.linear(gate_out * up_out, w_d.data) + out.add_(inputs.data) + if donate[0]: inputs.delete() + return TorchTensor.create_from_torch(out, self) From 8b89e0252348603f7d1b88fe68ae730f7bba1531 Mon Sep 17 00:00:00 2001 From: marswen Date: Fri, 29 Mar 2024 11:07:04 +0800 Subject: [PATCH 2/2] support qwen models --- flexgen/flex_qwen.py | 321 +++++++++++++++++++++++++++++++++++++ flexgen/pytorch_backend.py | 193 ++++++++++++++++++++++ flexgen/qwen_config.py | 127 +++++++++++++++ 3 files changed, 641 insertions(+) create mode 100644 flexgen/flex_qwen.py create mode 100644 flexgen/qwen_config.py diff --git a/flexgen/flex_qwen.py b/flexgen/flex_qwen.py new file mode 100644 index 00000000..6f3a6a99 --- /dev/null +++ b/flexgen/flex_qwen.py @@ -0,0 +1,321 @@ +""" +Usage: +python3 -m flexgen.flex_qwen --model Qwen/Qwen1.5-0.5B-Chat --gpu-batch-size 32 --percent 100 0 100 0 100 0 +""" +import os +import torch +import argparse +from typing import Union +from transformers import AutoTokenizer +from flexgen.compression import CompressionConfig +from flexgen.qwen_config import QwenConfig, get_qwen_config, download_qwen_weights +from flexgen.flex_llama import LlamaInputEmbed, LlamaOutputEmbed, LlamaMLP +from flexgen.pytorch_backend import QwenTorchDevice, TorchDisk, TorchMixedDevice, fix_recursive_import +from flexgen.flex_opt import (Policy, init_weight_list, SelfAttention, TransformerLayer, + OptLM, get_filename, get_test_inputs) +from flexgen.timer import timers +from flexgen.utils import (ExecutionEnv, GB, ValueHolder, + array_1d, array_2d, str2bool, project_decode_latency, write_benchmark_log) + +fix_recursive_import() + +DUMMY_WEIGHT = "_DUMMY_" # Use dummy weights for benchmark 
purposes + + +class QwenSelfAttention(SelfAttention): + def __init__(self, config, env, policy, layer_id): + super().__init__(config, env, policy, layer_id) + + def init_weight(self, weight_home, path): + h, n_head, n_kv_head, dtype = (self.config.input_dim, self.config.n_head, self.config.num_key_value_heads, self.config.dtype) + head_dim = h // n_head + path = os.path.join(os.path.join(path, f"layers.{self.layer_id}.")) + weight_specs = [ + # w_ln + ((h,), dtype, path + "input_layernorm.weight"), + # w_q + ((h, n_head*head_dim), dtype, path + "self_attn.q_proj.weight"), + # b_q + ((n_head*head_dim,), dtype, path + "self_attn.q_proj.bias"), + # w_k + ((n_kv_head*head_dim, h), dtype, path + "self_attn.k_proj.weight"), + # b_k + ((h,), dtype, path + "self_attn.k_proj.bias"), + # w_v + ((n_kv_head*head_dim, h), dtype, path + "self_attn.v_proj.weight"), + # b_v + ((h,), dtype, path + "self_attn.v_proj.bias"), + # w_o + ((n_head*head_dim, h), dtype, path + "self_attn.o_proj.weight"), + ] + weights = init_weight_list(weight_specs, self.policy, self.env) + weight_home.store(weights) + + def load_weight(self, weight_home, weight_read_buf, k): + w_ln, w_q, b_q, w_k, b_k, w_v, b_v, w_o = weight_home.val + if k == 0: + dst1 = self.weight_load_dst + dst2 = self.compute + weight_read_buf.store(( + w_ln.smart_copy(dst2), + w_q.smart_copy(dst1), b_q.smart_copy(dst2), + w_k.smart_copy(dst1), b_k.smart_copy(dst2), + w_v.smart_copy(dst1), b_v.smart_copy(dst2), + w_o.smart_copy(dst1))) + + def forward(self, hidden, cache_read_buf, weight_read_buf, attention_mask, + cache_write_buf, i, k): + n_head = self.config.n_head + n_kv_head = self.config.num_key_value_heads + + donate = [False] * 12 + h, donate[0] = hidden.val, True + + if k == self.policy.num_gpu_batches - 1: + # Clear the weight_read_buf if it is the last gpu batch + ((w_ln, donate[2]), (w_q, donate[3]), (b_q, donate[4]), (w_k, donate[5]), (b_k, donate[6]), + (w_v, donate[7]), (b_v, donate[8]), (w_o, donate[9])) = weight_read_buf.pop() + else: + ((w_ln, _), (w_q, _), (b_q, _), (w_k, _), (b_k, _), (w_v, _), (b_v, _), + (w_o, _)) = weight_read_buf.val + + if i == 0: # prefill + mask, donate[1] = attention_mask.val.smart_copy(self.compute) + position_ids = torch.cumsum(mask.data, dim=1).int() * mask.data + 1 + h, new_k_cache, new_v_cache = self.compute.qwen_mha(h, position_ids, mask, w_ln, + w_q, b_q, w_k, b_k, w_v, b_v, w_o, n_head, n_kv_head, donate, self.config.rms_norm_eps, self.config.rope_theta, + self.policy.compress_cache, self.policy.comp_cache_config) + cache_write_buf.store((new_k_cache, new_v_cache)) + else: # decoding + mask, donate[1] = attention_mask.val.smart_copy(self.attention_compute) + (k_cache, donate[10]), (v_cache, donate[11]) = cache_read_buf.pop() + position_ids = torch.cumsum(mask.data, dim=1).int() * mask.data + 1 + position_ids = position_ids[:, -h.shape[1]].unsqueeze(1) + h, new_k_cache, new_v_cache = self.compute.qwen_mha_gen(h, position_ids, mask, w_ln, + w_q, b_q, w_k, b_k, w_v, b_v, w_o, self.config.rms_norm_eps, self.config.rope_theta, n_head, n_kv_head, + k_cache, v_cache, donate, self.policy.attn_sparsity, + self.policy.compress_cache, self.policy.comp_cache_config) + cache_write_buf.store((new_k_cache, new_v_cache)) + + hidden.val = h + + +class QwenTransformerLayer(TransformerLayer): + def __init__(self, config, env, policy, i): + self.attention = QwenSelfAttention(config, env, policy, i) + self.mlp = LlamaMLP(config, env, policy, i) + self.policy = policy + self.compute = self.attention.compute + + +class 
QwenLM(OptLM): + def __init__(self, + config: Union[str, QwenConfig], + env: ExecutionEnv, + path: str, + policy: Policy): + if isinstance(config, str): + config = get_qwen_config(config) + self.config = config + self.env = env + self.path = path + self.policy = policy + self.num_gpu_batches = policy.num_gpu_batches + + layers = [] + layers.append(LlamaInputEmbed(self.config, self.env, self.policy)) + for i in range(self.config.num_hidden_layers): + if policy.sep_layer: + layers.append(QwenSelfAttention(self.config, self.env, self.policy, i)) + layers.append(LlamaMLP(self.config, self.env, self.policy, i)) + else: + layers.append(QwenTransformerLayer(self.config, self.env, self.policy, i)) + layers.append(LlamaOutputEmbed(self.config, self.env, self.policy)) + self.layers = layers + self.num_layers = len(layers) + + if self.policy.act_gpu_percent == 100: + self.act_home = self.env.gpu + elif self.policy.act_cpu_percent == 100: + self.act_home = self.env.cpu + elif self.policy.act_disk_percent == 100: + self.act_home = self.env.disk + else: + raise NotImplementedError() + + # CUDA streams + self.load_weight_stream = torch.cuda.Stream() + self.load_cache_stream = torch.cuda.Stream() + self.store_cache_stream = torch.cuda.Stream() + + # Intermediate tensors + # The following buffers store values used + # for the i-th token, j-th layer, k-th gpu batch. + num_layers, num_gpu_batches = self.num_layers, self.policy.num_gpu_batches + + # cache[j][k] + self.cache_home = array_2d(num_layers, num_gpu_batches, ValueHolder) + self.cache_read_buf = array_2d(num_layers, num_gpu_batches, ValueHolder) + self.cache_write_buf = array_2d(num_layers, num_gpu_batches, ValueHolder) + # weight[j] + self.weight_read_buf = array_1d(num_layers, ValueHolder) + # attention_mask[k] + self.attention_mask = array_1d(num_gpu_batches, ValueHolder) + + self.task = None + self.init_all_weights() + + def init_weight(self, j): + expanded_path = os.path.abspath(os.path.expanduser( + os.path.join(self.path, f"{self.config.name}-np"))) + check_path = os.path.join(expanded_path, "embed_tokens.weight") + if not os.path.exists(check_path) and DUMMY_WEIGHT not in check_path: + download_qwen_weights(self.config.name, self.path) + + self.layers[j].init_weight(self.weight_home[j], expanded_path) + + +def run_flexgen(args): + print(f": args.model: {args.model}") + tokenizer = AutoTokenizer.from_pretrained(args.model, padding_side="left") + tokenizer.pad_token_id = tokenizer.eos_token_id + num_prompts = args.num_gpu_batches * args.gpu_batch_size + prompt_len, gen_len, cut_gen_len = args.prompt_len, args.gen_len, args.cut_gen_len + + # Task and policy + warmup_inputs = get_test_inputs(32, num_prompts, tokenizer) + inputs = get_test_inputs(prompt_len, num_prompts, tokenizer) + + gpu = QwenTorchDevice("cuda:0") + cpu = QwenTorchDevice("cpu") + disk = TorchDisk(args.offload_dir) + env = ExecutionEnv(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk])) + + policy = Policy(args.gpu_batch_size, args.num_gpu_batches, + args.percent[0], args.percent[1], + args.percent[2], args.percent[3], + args.percent[4], args.percent[5], + args.overlap, args.sep_layer, args.pin_weight, + args.cpu_cache_compute, args.attn_sparsity, + args.compress_weight, + CompressionConfig(num_bits=4, group_size=64, + group_dim=0, symmetric=False), + args.compress_cache, + CompressionConfig(num_bits=4, group_size=64, + group_dim=2, symmetric=False)) + assert not (args.compress_cache and args.attn_sparsity < 1.0), "Not implemented" + + qwen_config = 
get_qwen_config(args.model, pad_token_id=tokenizer.eos_token_id) + cache_size = qwen_config.cache_bytes(num_prompts, prompt_len + gen_len) + hidden_size = qwen_config.hidden_bytes(num_prompts, prompt_len + gen_len) + print(f"model size: {qwen_config.model_bytes()/GB:.3f} GB, " + f"cache size: {cache_size/GB:.3f} GB, " + f"hidden size (prefill): {hidden_size/GB:.3f} GB") + + print("init weight...") + model = QwenLM(qwen_config, env, args.path, policy) + + try: + print("warmup - generate") + output_ids = model.generate( + warmup_inputs, max_new_tokens=1, verbose=args.verbose) + + print("benchmark - generate") + timers("generate").reset() + output_ids = model.generate( + inputs, max_new_tokens=args.gen_len, + debug_mode=args.debug_mode, cut_gen_len=cut_gen_len, verbose=args.verbose) + costs = timers("generate").costs + finally: + env.close_copy_threads() + + # Log output + prefill_latency = costs[0] + prefill_throughput = num_prompts * prompt_len / prefill_latency + if cut_gen_len: # project latency of cut_gen_len to gen_len + decode_latency = project_decode_latency(costs, prompt_len, gen_len) + else: + decode_latency = sum(costs[1:]) + decode_throughput = num_prompts * (gen_len - 1) / max(decode_latency, 1e-10) + num_generated_tokens = num_prompts * gen_len + total_latency = prefill_latency + decode_latency + total_throughput = num_generated_tokens / total_latency + _, gpu_peak_mem = gpu.mem_stats() + _, cpu_peak_mem = cpu.mem_stats() + + if DUMMY_WEIGHT not in args.path: + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + show_str = "Outputs:\n" + 70 * '-' + "\n" + for i in [0, len(outputs)-1]: + show_str += f"{i}: {outputs[i]}\n" + show_str += "-" * 70 + "\n" + if args.verbose >= 2: + print(show_str) + + gpu.print_stats() + cpu.print_stats() + projected = bool(args.debug_mode or cut_gen_len) + + if args.log_file == "auto": + filename = get_filename(args) + ".log" + else: + filename = args.log_file + + log_str = write_benchmark_log(filename, + qwen_config.model_bytes(), cache_size, hidden_size, + gpu_peak_mem, projected, prefill_latency, prefill_throughput, + decode_latency, decode_throughput, total_latency, total_throughput) + if args.verbose >= 1: + print(log_str) + + +def add_parser_arguments(parser): + parser.add_argument("--model", type=str, default="Qwen/Qwen1.5-7B-Chat", + help="The model name.") + parser.add_argument("--path", type=str, default="~/qwen_weights", + help="The path to the model weights. If there are no cached weights, " + "FlexGen will automatically download them from HuggingFace.") + parser.add_argument("--offload-dir", type=str, default="~/flexgen_offload_dir", + help="The directory to offload tensors. ") + parser.add_argument("--prompt-len", type=int, default=512) + parser.add_argument("--gen-len", type=int, default=32) + parser.add_argument("--cut-gen-len", type=int, + help="Cut generation length for fast debugging.") + parser.add_argument("--debug-mode", type=str, + choices=["fewer_batch", "breakdown"]) + parser.add_argument("--gpu-batch-size", type=int, default=4) + parser.add_argument("--num-gpu-batches", type=int, default=1) + parser.add_argument("--percent", nargs="+", type=int, + default=[100, 0, 100, 0, 100, 0], + help="Six numbers. 
They are " + "the percentage of weight on GPU, " + "the percentage of weight on CPU, " + "the percentage of attention cache on GPU, " + "the percentage of attention cache on CPU, " + "the percentage of activations on GPU, " + "the percentage of activations on CPU") + parser.add_argument("--sep-layer", type=str2bool, nargs='?', + const=True, default=True) + parser.add_argument("--pin-weight", type=str2bool, nargs="?", + const=True, default=True) + parser.add_argument("--cpu-cache-compute", action="store_true") + parser.add_argument("--attn-sparsity", type=float, default=1.0) + parser.add_argument("--compress-weight", action="store_true", + help="Whether to compress weight.") + parser.add_argument("--compress-cache", action="store_true", + help="Whether to compress cache.") + parser.add_argument("--log-file", type=str, default="auto") + parser.add_argument("--no-log", action="store_true") + parser.add_argument("--verbose", type=int, default=2) + parser.add_argument("--overlap", type=str2bool, nargs='?', + const=True, default=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + add_parser_arguments(parser) + args = parser.parse_args() + + assert len(args.percent) == 6 + + run_flexgen(args) diff --git a/flexgen/pytorch_backend.py b/flexgen/pytorch_backend.py index 5e8c38ea..af50f278 100644 --- a/flexgen/pytorch_backend.py +++ b/flexgen/pytorch_backend.py @@ -1211,3 +1211,196 @@ def llama_mlp(self, inputs, w_ln, w_g, w_u, w_d, eps, donate): out.add_(inputs.data) if donate[0]: inputs.delete() return TorchTensor.create_from_torch(out, self) + + +class QwenTorchDevice(LlamaTorchDevice): + + def qwen_mha(self, inputs, position_ids, attention_mask, w_ln, w_q, b_q, w_k, b_k, w_v, b_v, + w_out, n_head, n_kv_head, donate, eps, rope_theta, compress_cache, comp_config): + """Multi-head attention (prefill phase).""" + # decompress weights + if w_q.device.device_type == DeviceType.COMPRESSED: + w_q = w_q.device.decompress(w_q) + w_k = w_k.device.decompress(w_k) + w_v = w_v.device.decompress(w_v) + w_out = w_out.device.decompress(w_out) + + b, s, h = inputs.shape + head_dim = h // n_head + scaling = head_dim ** -0.5 + + hidden = rms_norm(inputs.data, weight=w_ln.data, eps=eps) + + # shape: (b, s, h) + q = F.linear(hidden, w_q.data, b_q.data) * scaling + k = F.linear(hidden, w_k.data, b_k.data) + v = F.linear(hidden, w_v.data, b_v.data) + # shape: (b, s, n_head, head_dim) + q = q.view(b, s, n_head, head_dim) + k = k.view(b, s, n_kv_head, head_dim) + v = v.view(b, s, n_kv_head, head_dim) + + kv_seq_len = k.shape[-3] + inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim)) + cos, sin = rotary_embedding(v, inv_freq, seq_len=kv_seq_len) + q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids) + + n_kv_groups = n_head // n_kv_head + k = repeat_kv(k, n_kv_groups) + v = repeat_kv(v, n_kv_groups) + + # shape: (b * n_head, s, head_dim) + q = q.permute(0, 2, 1, 3).reshape(b * n_head, s, head_dim) + # shape: (b * n_head, head_dim, s) + k = k.permute(0, 2, 3, 1).reshape(b * n_head, head_dim, s) + # shape: (b * n_head, s, head_dim) + v = v.permute(0, 2, 1, 3).reshape(b * n_head, s, head_dim) + + # shape: (b * n_head, s, s) + attn_weights = torch.bmm(q, k) + + # shape: (b, 1, s, s) + idx = torch.arange(s, device=self.dev) + causal_mask = (idx <= idx.view(s, 1)).view(1, 1, s, s) + mask = attention_mask.data.view(b, 1, 1, s) & causal_mask + + # shape: (b, n_head, s, s) + attn_weights = attn_weights.view(b, n_head, s, s) + attn_weights = torch.where(mask, attn_weights, 
-1e4) + attn_weights = attn_weights.view(b * n_head, s, s) + attn_weights = F.softmax(attn_weights, dim=2) + # shape: (b, n_head, s, head_dim) + value = torch.bmm(attn_weights, v).view(b, n_head, s, head_dim) + # shape: (b, s, h) + value = value.transpose(1, 2).reshape(b, s, h) + value = F.linear(value, w_out.data) + + value.add_(inputs.data) + + if donate[0]: inputs.delete() + if donate[1]: attention_mask.delete() + + # (s, b * n_head, head_dim) + k = k.permute(2, 0, 1) + v = v.permute(1, 0, 2) + + if compress_cache: + k = self.compressed_device.compress(k, comp_config) + v = self.compressed_device.compress(v, comp_config) + else: + k = TorchTensor.create_from_torch(k, self) + v = TorchTensor.create_from_torch(v, self) + + return TorchTensor.create_from_torch(value, self), k, v + + def qwen_mha_gen(self, inputs, position_ids, attention_mask, w_ln, w_q, b_q, w_k, b_k, w_v, b_v, + w_out, eps, rope_theta, n_head, n_kv_head, k_cache, v_cache, donate, + attn_sparsity, compress_cache, comp_config): + """Multi-head attention (decoding phase).""" + # decompress weights + if w_q.device.device_type == DeviceType.COMPRESSED: + w_q = w_q.device.decompress(w_q) + w_k = w_k.device.decompress(w_k) + w_v = w_v.device.decompress(w_v) + w_out = w_out.device.decompress(w_out) + + b, tgt_s, h = inputs.shape + src_s = attention_mask.shape[1] + head_dim = h // n_head + scaling = head_dim ** -0.5 + + hidden = rms_norm(inputs.data, weight=w_ln.data, eps=eps) + + # shape: (b, 1, h) + q = F.linear(hidden, w_q.data, b_q.data) * scaling + k = F.linear(hidden, w_k.data, b_k.data) + v = F.linear(hidden, w_v.data, b_v.data) + # shape: (b, 1, n_head, head_dim) + q = q.view(b, tgt_s, n_head, head_dim) + k = k.view(b, tgt_s, n_kv_head, head_dim) + v = v.view(b, tgt_s, n_kv_head, head_dim) + + inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim)) + cos, sin = rotary_embedding(v, inv_freq, seq_len=position_ids.max().item() + 1) + q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids) + + n_kv_groups = n_head // n_kv_head + k = repeat_kv(k, n_kv_groups) + v = repeat_kv(v, n_kv_groups) + + # shape: (b * n_head, 1, head_dim) + q = q.permute(0, 2, 1, 3).reshape(b * n_head, tgt_s, head_dim) + # shape: (1, b * n_head, head_dim) + k_new = k.permute(1, 0, 2, 3).reshape(tgt_s, b * n_head, head_dim) + # shape: (1, b * n_head, head_dim) + v_new = v.permute(1, 0, 2, 3).reshape(tgt_s, b * n_head, head_dim) + + if isinstance(k_cache, TorchTensor): + if attn_sparsity >= 1.0: # Dense attention + if compress_cache: + # shape: (s, b * n_head, head_dim) + k = k_cache.device.decompress(k_cache)[:src_s] + v = v_cache.device.decompress(v_cache)[:src_s] + else: + # shape: (s, b * n_head, head_dim) + k = k_cache.data[:src_s] + v = v_cache.data[:src_s] + k[src_s - 1:src_s] = k_new + v[src_s - 1:src_s] = v_new + + # shape: (b * n_head, head_dim, s) + k = k.permute(1, 2, 0).reshape(b * n_head, head_dim, src_s) + # shape: (b * n_head, s, head_dim) + v = v.permute(1, 0, 2).reshape(b * n_head, src_s, head_dim) + + if k.is_cuda: + value = self._attention_value(q, k, v, attention_mask.data, + b, src_s, tgt_s, n_head, head_dim) + else: + q = q.float().cpu() + k, v = k.float(), v.float() + value = self._attention_value(q, k, v, attention_mask.data, + b, src_s, tgt_s, n_head, head_dim).cuda().half() + else: # Sparse attention + # shape: (s, b * n_head, head_dim) + k = k_cache.data[:src_s] + k[src_s - 1:src_s] = k_new + # shape: (b * n_head, head_dim, s) + k = k.permute(1, 2, 0).reshape(b * n_head, head_dim, src_s) + + if 
k.is_cuda: + value = self._sparse_attention_value(q, k, v_new, v_cache, + attention_mask.data, b, src_s, tgt_s, n_head, head_dim, + attn_sparsity) + else: + q = q.float().cpu() + value = self._sparse_attention_value(q, k, v_new, v_cache, + attention_mask.data, b, src_s, tgt_s, n_head, head_dim, + attn_sparsity).cuda().half() + else: # Mixed device attention + assert attn_sparsity >= 1.0 + value = self._mixed_device_attention(q, k_cache, v_cache, + k_new, v_new, attention_mask.data, b, src_s, tgt_s, + n_head, head_dim) + + # shape: (b, 1, h) + value = value.transpose(1, 2).view(b, tgt_s, h) + value = F.linear(value, w_out.data) + + value.add_(inputs.data) + + if donate[0]: inputs.delete() + if donate[1]: attention_mask.delete() + + if compress_cache: + if comp_config.group_dim == 0: + s_ = src_s // comp_config.group_size * comp_config.group_size + k_new = k[:, :, s_:].permute(2, 0, 1) + v_new = v[:, s_:, :].permute(1, 0, 2) + k_new = self.compressed_device.compress(k_new, comp_config) + v_new = self.compressed_device.compress(v_new, comp_config) + else: + k_new = TorchTensor.create_from_torch(k_new, self) + v_new = TorchTensor.create_from_torch(v_new, self) + + return TorchTensor.create_from_torch(value, self), k_new, v_new diff --git a/flexgen/qwen_config.py b/flexgen/qwen_config.py new file mode 100644 index 00000000..0a592d94 --- /dev/null +++ b/flexgen/qwen_config.py @@ -0,0 +1,127 @@ +""" +The Qwen model configurations and weight downloading utilities. + +adopted from opt_config.py +""" + +import dataclasses +import glob +import os +import numpy as np +from tqdm import tqdm + + +@dataclasses.dataclass(frozen=True) +class QwenConfig: + name: str = "Qwen1.5-7B" + hidden_act: str = "silu" + input_dim: int = 4096 + initializer_range: float = 0.02 + intermediate_size: int = 11008 + max_position_embeddings: int = 4096 + n_head: int = 32 + num_hidden_layers: int = 32 + num_key_value_heads: int = 32 + rms_norm_eps: float = 1e-06 + rope_theta: float = 1000000.0 + dtype: type = np.float16 + pad_token_id: int = 151643 + vocab_size: int = 151936 + + def model_bytes(self): + h = self.input_dim + intermediate = self.intermediate_size + n_head = self.n_head + head_dim = h // n_head + return 2 * (self.vocab_size * h + + self.num_hidden_layers * ( + # self-attention + h * (3 * h + 1) + h * h + + # mlp + 3 * h * intermediate + + # layer norm + 2 * h) + + # head + h + self.vocab_size * h) + + def cache_bytes(self, batch_size, seq_len): + return 2 * batch_size * seq_len * self.num_hidden_layers * self.input_dim * 2 + + def hidden_bytes(self, batch_size, seq_len): + return batch_size * seq_len * self.input_dim * 2 + + +def get_qwen_config(name, **kwargs): + if "/" in name: + name = name.split("/")[1] + + if "-Chat" in name: + arch_name = name.replace("-Chat", "") + else: + arch_name = name + + if arch_name == "Qwen1.5-0.5B": + config = QwenConfig(name=name, + input_dim=1024, intermediate_size=2816, n_head=16, rms_norm_eps=1e-6, rope_theta=1000000.0, + num_hidden_layers=24, num_key_value_heads=16, vocab_size=151936 + ) + elif arch_name == "Qwen1.5-1.8B": + config = QwenConfig(name=name, + input_dim=2048, intermediate_size=5504, n_head=16, rms_norm_eps=1e-6, rope_theta=1000000.0, + num_hidden_layers=24, num_key_value_heads=16, vocab_size=151936 + ) + elif arch_name == "Qwen1.5-4B": + config = QwenConfig(name=name, + input_dim=2560, intermediate_size=6912, n_head=20, rms_norm_eps=1e-6, rope_theta=5000000.0, + num_hidden_layers=40, num_key_value_heads=20, vocab_size=151936 + ) + elif arch_name == "Qwen1.5-7B": 
+ config = QwenConfig(name=name, + input_dim=4096, intermediate_size=11008, n_head=32, rms_norm_eps=1e-6, rope_theta=1000000.0, + num_hidden_layers=32, num_key_value_heads=32, vocab_size=151936 + ) + elif arch_name == "Qwen1.5-14B": + config = QwenConfig(name=name, + input_dim=5120, intermediate_size=13696, n_head=40, rms_norm_eps=1e-6, rope_theta=1000000.0, + num_hidden_layers=40, num_key_value_heads=40, vocab_size=152064 + ) + elif arch_name == "Qwen1.5-72B": + config = QwenConfig(name=name, + input_dim=8192, intermediate_size=24576, n_head=64, rms_norm_eps=1e-5, rope_theta=1000000.0, + num_hidden_layers=80, num_key_value_heads=64, vocab_size=152064 + ) + else: + raise ValueError(f"Invalid model name: {name}") + + return dataclasses.replace(config, **kwargs) + + +def download_qwen_weights(model_name, path): + import torch + from huggingface_hub import snapshot_download + from safetensors import safe_open + + print(f"Load the pre-trained pytorch weights of {model_name} from huggingface. " + f"The downloading and cpu loading can take dozens of minutes. " + f"If it seems to get stuck, you can monitor the progress by " + f"checking the memory usage of this process.") + + hf_model_name = "Qwen/" + model_name + + folder = snapshot_download(hf_model_name, allow_patterns="*.safetensors") + safetensors_files = glob.glob(os.path.join(folder, "*.safetensors")) + + if "/" in model_name: + model_name = model_name.split("/")[1] + path = os.path.join(path, f"{model_name}-np") + path = os.path.abspath(os.path.expanduser(path)) + os.makedirs(path, exist_ok=True) + + for safetensors_file in tqdm(safetensors_files, desc="Convert format"): + with safe_open(safetensors_file, framework='pt') as stf: + for name in tqdm(stf.keys(), leave=False): + param = stf.get_tensor(name) + name = name.replace("model.", "") + param_path = os.path.join(path, name) + with open(param_path, "wb") as f: + np.save(f, param.to(torch.float16).cpu().detach().numpy())
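
The rotary-position-embedding helpers added to pytorch_backend.py (rotary_embedding, apply_rotary_pos_emb with its unsqueeze_dim=2 default, and repeat_kv) are shared by the Llama and Qwen attention paths in this patch. The short sketch below is not part of the patch; it is only a shape sanity check written under the assumption that the patch is applied so these helpers can be imported from flexgen.pytorch_backend, and it uses a plain arange for position_ids instead of the mask-derived ids built in the forward passes.

import torch
from flexgen.pytorch_backend import (rotary_embedding, apply_rotary_pos_emb,
                                     repeat_kv)

b, s, n_head, n_kv_head, head_dim = 2, 5, 8, 2, 16

# Same layouts as llama_mha after the view() calls:
# q: (b, s, n_head, head_dim); k, v: (b, s, n_kv_head, head_dim).
q = torch.randn(b, s, n_head, head_dim, dtype=torch.float16)
k = torch.randn(b, s, n_kv_head, head_dim, dtype=torch.float16)
v = torch.randn(b, s, n_kv_head, head_dim, dtype=torch.float16)

# inv_freq matches the (head_dim // 2,) "self_attn.rotary_emb.inv_freq" weight spec.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))

# Illustrative positions 0..s-1 (the forward passes derive them from the attention mask).
position_ids = torch.arange(s).unsqueeze(0).expand(b, s)

cos, sin = rotary_embedding(v, inv_freq, seq_len=s)   # each: (s, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

# Grouped-query attention: repeat each kv head n_head // n_kv_head times.
k_full = repeat_kv(k_rot, n_head // n_kv_head)
v_full = repeat_kv(v, n_head // n_kv_head)

assert q_rot.shape == (b, s, n_head, head_dim)
assert k_full.shape == v_full.shape == (b, s, n_head, head_dim)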