From a98a332658700f1e4a5c576d96c1c98088bab8b8 Mon Sep 17 00:00:00 2001
From: dan
Date: Sat, 1 Feb 2025 18:33:28 -0800
Subject: [PATCH] not mergeable as-is

---
 sharktank/sharktank/kernels/batch_matmul_transpose_b.py | 9 ++++-----
 sharktank/sharktank/layers/linear.py                     | 4 ++--
 sharktank/sharktank/ops/qlinear_impls.py                 | 6 +++---
 3 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py
index 21f9e9ed4..a55d6654b 100644
--- a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py
+++ b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py
@@ -8,7 +8,7 @@
 
 import torch
 
-from iree.compiler.ir import IntegerType
+from iree.compiler.ir import IntegerType, FloatType
 
 __all__ = [
     "batch_matmul_transpose_b",
@@ -59,9 +59,7 @@ def select(self, ksel: KernelSelection):
             lambda: f"batch_matmul_transpose_b: Batch dims must match ({lhs_desc.t.shape} vs {rhs_desc.t.shape})",
         )
         # Shape batch, m, n
-        c_desc = ksel.return_new_tensor(
-            [lhs_batch, lhs_m, rhs_n], dtype=lhs_desc.t.dtype
-        )
+        c_desc = ksel.return_new_tensor([lhs_batch, lhs_m, rhs_n], dtype=torch.float32)
         specialize_all_known_dims(lhs_desc)
         specialize_all_known_dims(rhs_desc)
         specialize_all_known_dims(c_desc)
@@ -77,8 +75,9 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
         result_desc = ksel.result_descs[0]
 
         # Generate specialization signature and types.
-        a_asm_type, a_ident, accum_type = unpack_tensor_type(lhs.type)
+        a_asm_type, a_ident, _ = unpack_tensor_type(lhs.type)
         b_asm_type, b_ident, _ = unpack_tensor_type(rhs.type)
+        accum_type = FloatType.parse("f32")
         spec_sig = f"L{a_ident}_R{b_ident}"
         template_file = "batch_matmul_transpose_b.mlir"
         target_function_name = f"sharktank_batch_matmul_transpose_b_{spec_sig}"
diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py
index a1f1366ab..dae126767 100644
--- a/sharktank/sharktank/layers/linear.py
+++ b/sharktank/sharktank/layers/linear.py
@@ -85,8 +85,8 @@ def forward(self, x):
         # We can truncate to fp16 in iree, so we do a cast here
         # to account for this in the IR. This is may not be the right
         # level to do this, but for now its here.
-        if not isinstance(y, QuantizedTensor):
-            if y.dtype == torch.float8_e4m3fnuz:
+        if not isinstance(y, QuantizedTensor) and isinstance(x, QuantizedTensor):
+            if x.unpack().qs.dtype == torch.float8_e4m3fnuz:
                 y = ops.to(y, torch.bfloat16)
             return y
         if qdq_output is not None:
diff --git a/sharktank/sharktank/ops/qlinear_impls.py b/sharktank/sharktank/ops/qlinear_impls.py
index f4f7ac0ca..df6d74b15 100644
--- a/sharktank/sharktank/ops/qlinear_impls.py
+++ b/sharktank/sharktank/ops/qlinear_impls.py
@@ -93,6 +93,7 @@ def qlinear_tensor_scaled(
 
     # Fall back to automatic fusion based on integer, high precision matmul.
     y_qs = _invoke_mmt_kernel(x_qs, weight_qs, accum_dtype=accum_dtype)
+    return y_qs
 
     # Offset correction. By applying the offset correction in post, it is
     # set up to fuse with its consumer, which is already doing additional
@@ -187,9 +188,8 @@ def _invoke_mmt_kernel(lhs, rhs, *, accum_dtype):
         rhs_size = [lhs.shape[0]] + list(rhs.shape)
         rhs = rhs.unsqueeze(0).expand(rhs_size)
         rhs_rank = len(rhs.shape)
-    y_qs = kernels.batch_matmul_transpose_b(
-        lhs.to(accum_dtype), rhs.to(accum_dtype)
-    )
+    y_qs = kernels.batch_matmul_transpose_b(lhs, rhs)
+    return y_qs
     # Squeeze the batch dimension to maintain shape parity with other
     # layers.
     if len(y_qs.shape) > 2: