Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[sharktank][llama] enable quark parity test on mi300x #886

Merged
merged 6 commits into from
Feb 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/ci-sharktank.yml
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,14 @@ jobs:
--with-flux-data \
--with-t5-data \
--with-vae-data \
--with-quark-data \
sharktank/tests/models/clip/clip_test.py \
sharktank/tests/models/t5/t5_test.py \
sharktank/tests/models/flux/flux_test.py \
sharktank/tests/models/vae/vae_test.py \
sharktank/tests/models/llama/quark_parity_test.py \
--durations=0 \
--timeout=600
--timeout=800


test_integration:
Expand Down
9 changes: 9 additions & 0 deletions sharktank/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,15 @@ def pytest_addoption(parser):
),
)

parser.addoption(
"--with-quark-data",
action="store_true",
default=False,
help=(
"Enable tests that use quark data such as models not part of the source code."
),
)

# TODO: Remove all hardcoded paths in CI tests
parser.addoption(
"--llama3-8b-tokenizer-path",
Expand Down
21 changes: 11 additions & 10 deletions sharktank/sharktank/examples/paged_llm_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,16 +329,17 @@ def main():
intermediates_saver.save_file(
args.save_intermediates_path + "_prefill.safetensors"
)
counter = 0
while not batch.done:
batch.decode()
if args.save_intermediates_path:
intermediates_saver.save_file(
args.save_intermediates_path + f"_step_{counter}.safetensors"
)
print(f":: Result tokens: {batch.results}")
batch.print_current_results()
counter += 1
if not args.skip_decode:
counter = 0
while not batch.done:
batch.decode()
if args.save_intermediates_path:
intermediates_saver.save_file(
args.save_intermediates_path + f"_step_{counter}.safetensors"
)
print(f":: Result tokens: {batch.results}")
batch.print_current_results()
counter += 1


if __name__ == "__main__":
Expand Down
64 changes: 51 additions & 13 deletions sharktank/tests/models/llama/quark_parity_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,31 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os

from safetensors import safe_open
import torch
import unittest
import pytest
from pathlib import Path
import subprocess

with_quark_data = pytest.mark.skipif("not config.getoption('with_quark_data')")


@pytest.mark.skip(reason="need to generate values to compare against")
class QuarkParityTest(unittest.TestCase):
def setUp(self):
super().setUp()
self.path_prefix = Path("/shark-dev/quark_test")

@with_quark_data
def test_compare_against_quark(self):
def both(key, index=None):
o = ours[key]
t = theirs[key]
if index is None:
return o, t
else:
return o[index], t[index]
sharktank_dir = str(
Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent.parent
)
our_path = self.path_prefix / "ours_prefill.safetensors"
if os.path.exists(our_path):
os.remove(our_path)

mapping = dict()
for i in range(32):
Expand All @@ -41,18 +48,49 @@ def both(key, index=None):
mapping[a] = b
mapping[a + "_input_0"] = b + "_input_0"

command = [
"python",
"-m",
"sharktank.examples.paged_llm_v1",
"The capitol of Texas is",
f"--irpa-file={self.path_prefix}/fp8_bf16_weight.irpa",
f"--tokenizer-config-json=/data/llama3.1/8b/tokenizer.json",
"--fake-quant",
"--attention-kernel=torch",
"--activation-dtype=bfloat16",
f"--save_intermediates_path={self.path_prefix}/ours",
"--use-hf",
"--attention-dtype=bfloat16",
"--skip-decode",
"--block-seq-stride=16",
]
command = subprocess.list2cmdline(command)
proc = subprocess.run(
command, shell=True, capture_output=True, cwd=sharktank_dir
)

ours = dict()
with safe_open("../ours_newest_prefill.safetensors", "pytorch") as st:
with safe_open(our_path, "pytorch") as st:
for key in st.keys():
ours[key] = st.get_tensor(key)

theirs = dict()
with safe_open("../theirs2.safetensors", "pytorch") as st:
golden = dict()
golden_path = self.path_prefix / "golden.safetensors"
with safe_open(golden_path, "pytorch") as st:
for key in st.keys():
if key in mapping:
theirs[mapping[key]] = st.get_tensor(key)
golden[mapping[key]] = st.get_tensor(key)

test_layers = [v for k, v in mapping.items()]

def both(key, index=None):
o = ours[key]
t = golden[key]
if index is None:
return o, t
else:
return o[index], t[index]

for lyr in test_layers:
name = lyr
if name in ours.keys() and name != "freqs":
Expand Down
Loading