Custom generation script #21

Open · wants to merge 3 commits into base: main
2 changes: 1 addition & 1 deletion Dockerfile
@@ -1,4 +1,4 @@
-FROM nvcr.io/nvidia/pytorch:23.01-py3
+FROM nvcr.io/nvidia/pytorch:23.03-py3

 ARG USER=1000
 ARG USERNAME=user
2 changes: 2 additions & 0 deletions requirements.txt
@@ -3,6 +3,8 @@ bitsandbytes
 safetensors
 deepspeed==0.7.7
 -e ./transformers
+flash-attn
+einops

 # TODO: Analysis only
 py-markdown-table
20 changes: 20 additions & 0 deletions scripts/run_all_benchmark_breakdown.sh
@@ -0,0 +1,20 @@
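# Positional arguments forwarded to run_benchmark_breakdown.sh:
#   <model_name> <model_path> <batch_size> <max_new_tokens> <token_step> <step_id> <file_prefix> [cycles]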

# Santacoder
./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 1 2040 5 0 v2_
./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 5 0 v2_
./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 256 2040 5 0 v2_

./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 1 2040 11 1 v2_
./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 11 1 v2_
./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 256 2040 11 1 v2_

# Large model
./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 1 8190 11 0 v2_
./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 8 8190 11 0 v2_
./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 32 8190 11 0 v2_
./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 256 8190 11 0 v2_ # OOM?

./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 1 8190 29 1 v2_ 1
./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 8 8190 29 1 v2_ 1
./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 32 8190 29 1 v2_ 1
./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 256 8190 29 1 v2_ 1 # OOM?
76 changes: 76 additions & 0 deletions scripts/run_benchmark_breakdown.sh
@@ -0,0 +1,76 @@

# Santacoder decode (step id 0):
# ./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 5 0
# Santacoder prefill (step id 1; fewer data points because each full-sequence forward pass is slower):
# ./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 11 1
MODEL_NAME=${1:-"santacoder"}
MODEL_PATH=${2:-"bigcode/gpt_bigcode-santacoder"}
BATCH_SIZE=${3:-32}
MAX_NEW_TOKENS=${4:-2040}
# Prime number to see key length padding effect.
TOKEN_STEP=${5:-5}
STEP_ID=${6:-1} # 0: decode, otherwise prefill (a non-numeric default would break the -eq test below)
FILE_PREFIX=${7:-""}
CYCLES=${8:-10}

SAVE_DIR=data/benchmarks/v2
#BATCH_SIZES="1 2 4 8 16 24 32 48 64 96 128 160 224 256"
RUN="python3 src/main.py --max_log_outputs=0 --dtype=float16 --device=cuda --custom_generate --breakdown_latency --ignore_oom"


# Each runtime / attention variant is paired index-by-index with the name array
# below it; the name becomes part of the output file name in run().
RUNTIME=("" "pre_allocate_kv_cache=True" "pre_allocate_kv_cache=True inference_runner=3")
RUNTIME_NAMES=("base" "pre_allocate" "graph")

ATTN=( \
"attention_implementation=0" \
"attention_implementation=1" \
"attention_implementation=1 --pad_generated_tokens=0.5" \
"attention_implementation=2" \
"attention_implementation=0 fused_softmax=False" \
"attention_implementation=0 fused_softmax=True" \
"attention_implementation=3" \
"attention_implementation=4" \
"attention_implementation=5" \
)
ATTN_NAME=( \
"default" \
"flash" \
"flash_unpad_50" \
"torch" \
"no_jit" \
"jit" \
"torchflash" \
"torchmem" \
"torchcpp" \
)


STEP=("--no_prefill" "--no_cache")
STEP_NAME=("decode" "prefill")

# Bare key=value pairs (e.g. predict_last_token=True) are forwarded as config_args, separate from the --flags.
COMMON="--pretrained_model=$MODEL_PATH --tokenizer=$MODEL_PATH --cycles=$CYCLES --max_input_length=1 --max_new_tokens=$MAX_NEW_TOKENS --key_length_step=$TOKEN_STEP --batch_size=$BATCH_SIZE predict_last_token=True"

run () { # run(step, runtime, attn)
    FILE_NAME="$SAVE_DIR"/"$MODEL_NAME"_bs_"$BATCH_SIZE"_tok_"$MAX_NEW_TOKENS"_step_"$TOKEN_STEP"_"${STEP_NAME[$1]}"/"$FILE_PREFIX""${RUNTIME_NAMES[$2]}"_"${ATTN_NAME[$3]}".json
    if [ -f "$FILE_NAME" ]
    then
        echo "Skipping existing $FILE_NAME"
    else
        $RUN $COMMON ${RUNTIME[$2]} ${ATTN[$3]} ${STEP[$1]} --save="$FILE_NAME"
    fi
}

if [ "${STEP_ID}" -eq "0" ]
then
# Decode (default attn only)
for runtime in {0..2}
do
run 0 $runtime 0
done
else
# Prefill (all runtimes are the same)
for attn in {0..2}
do
run 1 0 $attn
done
fi
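For reference, the layout run() produces: the bs=32 Santacoder decode invocation above (step id 0, prefix v2_) writes its base-runtime, default-attention results to data/benchmarks/v2/santacoder_bs_32_tok_2040_step_5_decode/v2_base_default.json. The directory encodes model, batch size, token count, step and phase; the file name is the prefix plus the runtime and attention names.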
59 changes: 46 additions & 13 deletions src/main.py
@@ -11,7 +11,7 @@
 from src.metrics import Metrics
 from src.pipeline import Pipeline, get_pipeline_class
 from src.profile import get_profiler, logger
-from src.utils import configure_logging, get_dummy_batch, log_dict, log_rank_n, parse_config_args
+from src.utils import configure_logging, get_input_batch, log_dict, log_rank_n, parse_config_args


 def get_arg_parser() -> ArgumentParser:
@@ -26,16 +26,25 @@ def get_arg_parser() -> ArgumentParser:
     parser.add_argument("config_args", nargs="*")

     # Runtime
+    parser.add_argument("-c", "--custom_generate", action="store_true")
     parser.add_argument("--pipeline_class", default="HF_Pipeline")
     parser.add_argument("--device", default="cuda", type=torch.device)
     parser.add_argument("--dtype", default="float16", type=lambda x: getattr(torch, x))
     parser.add_argument("--local_rank", type=int)
-    parser.add_argument("--no_fast_init", dest="fast_init", action="store_false")
+    parser.add_argument("--no_fast_init", "--nf", dest="fast_init", action="store_false")
+    parser.add_argument("--no_cache", "--nc", dest="use_cache", action="store_false")
+    parser.add_argument("--no_prefill", "--np", dest="do_prefill", action="store_false")
+    parser.add_argument("--key_length_step", "--ks", default=1, type=int)
+    parser.add_argument("--ignore_oom", "--oom", action="store_true")

     # Input and output
-    parser.add_argument("--batch_size", default=1, type=int)
-    parser.add_argument("--max_input_length", default=-1, type=int)
-    parser.add_argument("--max_new_tokens", default=100, type=int)
+    parser.add_argument("--batch_size", "-b", default=1, type=int)
+    parser.add_argument("--max_input_length", "-i", default=-1, type=int)
+    parser.add_argument("--sample_dir", "-d")
+    parser.add_argument("--input_pad_ratio", "--pad", default=0, type=float)
+    parser.add_argument("--pad_generated_tokens", "--pad_g", default=0, type=float)
+    parser.add_argument("--input_seed", "--seed", default=0, type=int)
+    parser.add_argument("--max_new_tokens", "-g", default=100, type=int)

     # Cleanup
     parser.add_argument("--clear_every_run", action="store_true")
@@ -47,10 +56,11 @@ def get_arg_parser() -> ArgumentParser:

     # Profiling and logging
     parser.add_argument("--max_log_outputs", type=int)
-    parser.add_argument("--profile", action="store_true")
-    parser.add_argument("--profile_cycles", type=int)
-    parser.add_argument("--full_trace", action="store_true")
-    parser.add_argument("--show_op_names", action="store_true")
+    parser.add_argument("--breakdown_latency", "--bl", action="store_true")
+    parser.add_argument("--profile", "-p", action="store_true")
+    parser.add_argument("--profile_cycles", "--pc", type=int)
+    parser.add_argument("--full_trace", "--pt", action="store_true")
+    parser.add_argument("--show_op_names", "--pn", action="store_true")
     parser.add_argument("--save", type=Path)

     return parser
@@ -61,8 +71,6 @@ def main(argv: Optional[List[str]] = None) -> None:
     parser = get_arg_parser()
     args = parser.parse_args(argv)
     config_args = parse_config_args(args.config_args)
-    generate_kwargs = {"max_new_tokens": args.max_new_tokens, "do_sample": False}
-    inputs = get_dummy_batch(args.batch_size, args.max_input_length)
     separate_profile = args.profile and args.profile_cycles is not None
     warmup = args.profile if args.warmup is None else args.warmup
     if separate_profile:
@@ -89,6 +97,14 @@
         fast_init=args.fast_init,
         trust_remote_code=args.trust_remote_code,
     )
+    inputs = get_input_batch(
+        args.batch_size,
+        args.max_input_length,
+        pipeline.tokenizer,
+        args.input_pad_ratio,
+        args.input_seed,
+        args.sample_dir,
+    )

     all_metrics = []

@@ -104,7 +120,7 @@
         profiler = contextlib.nullcontext()

     benchmark_metrics = {
-        **generate_kwargs,
+        "max_new_tokens": args.max_new_tokens,
         "Model parameters": pipeline.get_num_parameters(),
         "Cycles (warmup)": args.skip + warmup,
         "Cycles (benchmark)": args.cycles,
@@ -121,10 +137,27 @@
     t1 = time.perf_counter()
     with profiler as p:
         for step in range(args.skip + warmup + args.cycles):
+            log_rank_n(
+                (
+                    f"*** Running generation step {step} "
+                    f"({'skip' if step<args.skip else 'warmup' if step<args.skip + warmup else 'benchmark'})"
+                ),
+                logger.info,
+            )
             if step == args.skip + warmup:
                 t2 = time.perf_counter()
                 benchmark_metrics[Metrics.RUNTIME_WARMUP] = t2 - t1
-            generated_text, metrics = pipeline(inputs, **generate_kwargs)
+            generated_text, metrics = pipeline(
+                inputs,
+                args.max_new_tokens,
+                custom_generate=args.custom_generate,
+                use_cache=args.use_cache,
+                do_prefill=args.do_prefill,
+                breakdown_latency=args.breakdown_latency,
+                key_length_step=args.key_length_step,
+                ignore_oom=args.ignore_oom,
+                pad_generated_tokens=args.pad_generated_tokens,
+            )
             if args.profile:
                 p.step()

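Putting the new main.py flags together: the command below is roughly what run_benchmark_breakdown.sh assembles for the bs=32 Santacoder decode breakdown. This is a sketch for orientation, not a canonical invocation; the save path mirrors the script's layout, and the trailing predict_last_token=True is parsed as a config_args key=value pair rather than a --flag.

python3 src/main.py --max_log_outputs=0 --dtype=float16 --device=cuda \
    --custom_generate --breakdown_latency --ignore_oom \
    --pretrained_model=bigcode/gpt_bigcode-santacoder --tokenizer=bigcode/gpt_bigcode-santacoder \
    --cycles=10 --batch_size=32 --max_input_length=1 --max_new_tokens=2040 \
    --key_length_step=5 --no_prefill \
    --save=data/benchmarks/v2/santacoder_bs_32_tok_2040_step_5_decode/v2_base_default.json \
    predict_last_token=True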
10 changes: 9 additions & 1 deletion src/metrics.py
@@ -17,14 +17,20 @@ def format_ms(t: float) -> str:
     return f"{1000 * t:.2f} ms"


+def format_ms_dict(t_dict: Dict[str, float]) -> Dict[str, str]:
+    return {key: format_ms(value) for key, value in t_dict.items()}
+
+
 def format_mib(m: float) -> str:
     return f"{m/2**20:.0f} MiB"


 class Metrics:
     LATENCY_E2E = "Latency (end to end)"
     LATENCY_TOKEN = "Latency (tokenization)"
-    LATENCY_MODEL = "Latency (model)"
+    LATENCY_MODEL = "Latency (generate)"
+    LATENCY_GENERATE_START = "Latency (prepare for generation)"
+    LATENCY_GENERATE_BREAKDOWN = "Latency (generate breakdown)"
     LATENCY_DECODE = "Latency (decode)"
     LATENCY_MAX = "Latency (max)"
     LATENCY_MIN = "Latency (min)"
@@ -59,6 +65,8 @@ class Metrics:
         LATENCY_E2E: format_ms,
         LATENCY_TOKEN: format_ms,
         LATENCY_MODEL: format_ms,
+        LATENCY_GENERATE_START: format_ms,
+        LATENCY_GENERATE_BREAKDOWN: format_ms_dict,
         LATENCY_DECODE: format_ms,
         LATENCY_MAX: format_ms,
         LATENCY_MIN: format_ms,
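The new format_ms_dict pairs with LATENCY_GENERATE_BREAKDOWN, whose value is a dict of latencies in seconds keyed by sequence length (as a string). A quick illustration, with numbers invented for the example:

>>> format_ms_dict({"128": 0.0123, "256": 0.0145})
{'128': '12.30 ms', '256': '14.50 ms'}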
81 changes: 81 additions & 0 deletions src/parse_breakdown_results.py
@@ -0,0 +1,81 @@
import json
from argparse import ArgumentParser
from pathlib import Path
from typing import List, Optional


def get_arg_parser() -> ArgumentParser:
    parser = ArgumentParser()
    parser.add_argument("input_dir", type=Path)
    parser.add_argument("--title")
    parser.add_argument("--size", nargs=2, type=float)
    parser.add_argument("--save_dir", "--save", type=Path)
    return parser


def read_data(input_file: Path):
    try:
        with input_file.open("r") as f:
            data = json.load(f)
        # Flatten the saved config and results into a single dict for plotting.
        data = {**data["config"], **data["results"]}
    except (ValueError, OSError) as e:
        raise ValueError(f"Cannot parse file {input_file} ({e})") from e
    data["Setting"] = input_file.stem
    return data


def plot(data, title=None, size=None):
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=size)
    ax = fig.add_subplot()

    # Reorder the tab20 palette so the ten darker shades come before the lighter ones.
    cmap = plt.get_cmap("tab20").colors
    cmap = cmap[::2] + cmap[1::2]

    for i, dat in enumerate(data):
        latency_data = dat["Latency (generate breakdown)"]
        ax.plot(
            [int(k) for k in latency_data.keys()],
            [v * 1000 for v in latency_data.values()],
            label=dat["Setting"],
            linewidth=1,
            color=cmap[i],
        )  # , linestyle=":")#, markersize=1, marker="o")

    ax.set_title(title)
    ax.set_xlabel("Sequence length")
    ax.set_ylabel("Latency (ms)")
    ax.legend()
    return fig


def main(argv: Optional[List[str]] = None) -> None:
    parser = get_arg_parser()
    args = parser.parse_args(argv)
    data = [read_data(input_file) for input_file in args.input_dir.iterdir()]

    if len(data) == 0:
        raise RuntimeError("No data to show.")

    title = args.title
    dirname = args.input_dir.stem
    if title is None:
        try:
            # Benchmark directories are named <model>_bs_<batch>_tok_<tokens>_step_<step>_<phase>
            # (see run_benchmark_breakdown.sh).
            name, _, bs, _, _, _, _, step = dirname.rsplit("_", 7)
            title = f"{name} {step}, bs = {bs}"
        except ValueError:
            title = dirname

    fig = plot(data, title, args.size)
    fig.show()
    if args.save_dir:
        save_path = (args.save_dir / dirname).with_suffix(".jpg")
        fig.savefig(save_path)
        print(f"Figure saved to {save_path}")

    input("Press enter to continue")


if __name__ == "__main__":
    main()
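To plot one of the runs produced above (the figures output directory is illustrative; it must already exist, since savefig does not create it):

python3 src/parse_breakdown_results.py data/benchmarks/v2/santacoder_bs_32_tok_2040_step_5_decode --save_dir figures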