-
Notifications
You must be signed in to change notification settings - Fork 51
/
Copy pathrun_streaming.py
148 lines (127 loc) · 4.91 KB
/
run_streaming.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import warnings
warnings.filterwarnings("ignore")
import torch
import argparse
import json
import os
import time
import re
import sys
from tqdm import tqdm
from streaming_llm.utils import load, download_url, load_jsonl
from transformers.models.llama.modeling_llama import LlamaAttention
from utils_real_drop.modify_llama import H2OLlamaAttention_streaming, H2OLlamaForCausalLM_streaming
def _decode_words(tokenizer, token_ids):
    """Decode ``token_ids`` to text, strip it, and split it on single spaces."""
    return (
        tokenizer.decode(
            token_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
            spaces_between_special_tokens=False,
        )
        .strip()
        .split(" ")
    )


@torch.no_grad()
def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
    """Greedily decode up to ``max_gen_len`` tokens, streaming words to stdout.

    The prompt ``input_ids`` is run through ``model`` together with the
    existing ``past_key_values``. After that, one token at a time is fed back
    in, always taking the arg-max of the last-position logits. Generation stops
    early once the tokenizer's EOS id is produced, including when it is the
    very first token. Decoded words are printed as they become stable.

    Args:
        model: callable as ``model(input_ids=..., past_key_values=...,
            use_cache=True)``; its output must expose ``logits`` and
            ``past_key_values``.
        tokenizer: object providing ``decode`` and ``eos_token_id``.
        input_ids: prompt token ids. ``.item()`` below requires a batch size
            of 1 (assumed shape (1, seq_len) — TODO confirm with callers).
        past_key_values: cache carried over from earlier turns, or ``None``.
        max_gen_len: maximum number of tokens to generate (EOS included).

    Returns:
        The ``past_key_values`` returned by the last forward pass.
    """
    eos_token_id = tokenizer.eos_token_id
    outputs = model(
        input_ids=input_ids,
        past_key_values=past_key_values,
        use_cache=True,
    )
    past_key_values = outputs.past_key_values
    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    generated_ids = [pred_token_idx.item()]
    pos = 0
    for _ in range(max_gen_len - 1):
        # Check before stepping so an EOS on the first token also stops decoding.
        # Comparing plain ints also copes with ``eos_token_id`` being None.
        if generated_ids[-1] == eos_token_id:
            break
        outputs = model(
            input_ids=pred_token_idx,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values
        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        generated_ids.append(pred_token_idx.item())
        generated_text = _decode_words(tokenizer, generated_ids)
        # The last word may still be growing, so only print the words before it.
        now = len(generated_text) - 1
        if now > pos:
            print(" ".join(generated_text[pos:now]), end=" ", flush=True)
            pos = now
    # Decode again here so the text is defined even when the loop never ran
    # (max_gen_len <= 1, or the first token was EOS).
    generated_text = _decode_words(tokenizer, generated_ids)
    print(" ".join(generated_text[pos:]), flush=True)
    return past_key_values
@torch.no_grad()
def streaming_inference(model, tokenizer, prompts, kv_cache=None, max_gen_len=1000):
    """Answer ``prompts`` one after another, carrying the KV cache across turns.

    Each prompt is wrapped in a ``USER: ... ASSISTANT:`` template, echoed to
    stdout, tokenized onto ``model.device`` and answered by ``greedy_generate``.
    When ``kv_cache`` is given, its ``evict_for_space`` is called before every
    turn to make room for the prompt plus ``max_gen_len`` new tokens.

    Args:
        model: causal LM; only ``model.device`` is read here directly.
        tokenizer: callable returning an object with ``input_ids``.
        prompts: iterable of user-turn strings.
        kv_cache: optional object exposing ``evict_for_space(cache, n)``.
        max_gen_len: per-turn generation budget forwarded to ``greedy_generate``.

    Returns:
        None; output is printed as it is produced.
    """
    cache = None
    for user_turn in prompts:
        templated = "USER: " + user_turn + "\n\nASSISTANT: "
        print("\n" + templated, end="")
        token_ids = tokenizer(templated, return_tensors="pt").input_ids.to(model.device)
        if kv_cache is not None:
            budget = token_ids.shape[1] + max_gen_len
            cache = kv_cache.evict_for_space(cache, budget)
        cache = greedy_generate(model, tokenizer, token_ids, cache, max_gen_len=max_gen_len)
@torch.no_grad()
def streaming_inference_heavy_hitter(model, tokenizer, prompts, kv_cache=None, max_gen_len=1000):
    """Answer ``prompts`` sequentially with per-layer H2O cache eviction.

    Works like the plain streaming loop. The difference is that when
    ``kv_cache`` is not None, every ``H2OLlamaAttention_streaming`` module's own
    ``kv_cache.evict_for_space`` trims that layer's entry of the carried cache
    before each turn. The aim is to leave room for the prompt plus
    ``max_gen_len`` new tokens.

    Args:
        model: H2O-patched causal LM; ``named_modules`` and ``device`` are read.
        tokenizer: callable returning an object with ``input_ids``.
        prompts: iterable of user-turn strings.
        kv_cache: only checked for ``None``; it enables the eviction step.
        max_gen_len: per-turn generation budget forwarded to ``greedy_generate``.

    Returns:
        None; output is printed as it is produced.
    """
    past_key_values = None
    for prompt in prompts:
        prompt = "USER: " + prompt + "\n\nASSISTANT: "
        print("\n" + prompt, end="")
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        input_ids = input_ids.to(model.device)
        seq_len = input_ids.shape[1]
        # Before the first turn there is no cache yet, so nothing to evict.
        if kv_cache is not None and past_key_values is not None:
            space_needed = seq_len + max_gen_len
            # Work on a list copy: the cache returned by the model may be a
            # tuple (legacy format — TODO confirm), which rejects item assignment.
            layers = list(past_key_values)
            for name, m in model.named_modules():
                if isinstance(m, H2OLlamaAttention_streaming):
                    # Assumes module names like "model.layers.<idx>.self_attn" — TODO confirm.
                    layer_idx = int(name.split(".")[2])
                    layers[layer_idx] = m.kv_cache.evict_for_space(layers[layer_idx], space_needed)
            past_key_values = tuple(layers)
        past_key_values = greedy_generate(
            model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len
        )
def main(args):
    """Run the MT-Bench multi-turn prompts through the model in streaming mode.

    Loads the model and tokenizer with ``load``. If ``mt_bench.jsonl`` is
    missing from ``args.data_root``, it downloads the FastChat
    ``question.jsonl`` there and renames it. Every sample's ``"turns"`` are then
    flattened into one prompt list. That list goes to the H2O heavy-hitter loop
    or the plain streaming loop, depending on ``args.enable_streaming_with_H2O``.
    """
    model, tokenizer = load(args.model_name_or_path, args.enable_streaming_with_H2O, args)

    test_filepath = os.path.join(args.data_root, "mt_bench.jsonl")
    print(f"Loading data from {test_filepath} ...")
    if not os.path.exists(test_filepath):
        download_url(
            "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl",
            args.data_root,
        )
        os.rename(os.path.join(args.data_root, "question.jsonl"), test_filepath)

    prompts = [turn for sample in load_jsonl(test_filepath) for turn in sample["turns"]]

    # Both loops receive kv_cache=None, so each skips its "kv_cache is not None" eviction branch.
    run_inference = (
        streaming_inference_heavy_hitter if args.enable_streaming_with_H2O else streaming_inference
    )
    run_inference(model, tokenizer, prompts, None)
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="lmsys/vicuna-13b-v1.3")
    parser.add_argument("--data_root", type=str, default="data/")
    parser.add_argument("--enable_streaming_with_H2O", action="store_true")
    # Integer cache-size options; presumably read by ``load``, which receives ``args``.
    for flag, default in (
        ("--start_size", 4),
        ("--heavy_hitter_size", 4),
        ("--recent_size", 2000),
    ):
        parser.add_argument(flag, type=int, default=default)
    main(parser.parse_args())