Commit

not mergeable as-is
dan-garvey committed Feb 2, 2025
1 parent 8300bc8 commit a98a332
Showing 3 changed files with 9 additions and 10 deletions.

sharktank/sharktank/kernels/batch_matmul_transpose_b.py (4 additions, 5 deletions)
@@ -8,7 +8,7 @@
 
 import torch
 
-from iree.compiler.ir import IntegerType
+from iree.compiler.ir import IntegerType, FloatType
 
 __all__ = [
     "batch_matmul_transpose_b",
@@ -59,9 +59,7 @@ def select(self, ksel: KernelSelection):
             lambda: f"batch_matmul_transpose_b: Batch dims must match ({lhs_desc.t.shape} vs {rhs_desc.t.shape})",
         )
         # Shape batch, m, n
-        c_desc = ksel.return_new_tensor(
-            [lhs_batch, lhs_m, rhs_n], dtype=lhs_desc.t.dtype
-        )
+        c_desc = ksel.return_new_tensor([lhs_batch, lhs_m, rhs_n], dtype=torch.float32)
         specialize_all_known_dims(lhs_desc)
         specialize_all_known_dims(rhs_desc)
         specialize_all_known_dims(c_desc)
@@ -77,8 +75,9 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
         result_desc = ksel.result_descs[0]
 
         # Generate specialization signature and types.
-        a_asm_type, a_ident, accum_type = unpack_tensor_type(lhs.type)
+        a_asm_type, a_ident, _ = unpack_tensor_type(lhs.type)
         b_asm_type, b_ident, _ = unpack_tensor_type(rhs.type)
+        accum_type = FloatType.parse("f32")
         spec_sig = f"L{a_ident}_R{b_ident}"
         template_file = "batch_matmul_transpose_b.mlir"
         target_function_name = f"sharktank_batch_matmul_transpose_b_{spec_sig}"
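
Net effect in this file: select() now always declares a torch.float32 result tensor, and generate() hard-codes the MLIR accumulator type to f32 via the newly imported FloatType instead of deriving both from the LHS element type. Below is a minimal pure-torch sketch of that contract, not the real IREE-backed kernel; the reference function name and the shapes are hypothetical.

import torch

def batch_matmul_transpose_b_reference(lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
    # Contract sketch: lhs [B, M, K] x rhs [B, N, K] -> [B, M, N], accumulated
    # and returned in float32 regardless of the input dtypes, mirroring the
    # dtype=torch.float32 result and f32 accumulator introduced above.
    return torch.matmul(lhs.to(torch.float32), rhs.to(torch.float32).transpose(-1, -2))

lhs = torch.randn(4, 16, 32, dtype=torch.float16)
rhs = torch.randn(4, 8, 32, dtype=torch.float16)
out = batch_matmul_transpose_b_reference(lhs, rhs)
assert out.shape == (4, 16, 8) and out.dtype == torch.float32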

sharktank/sharktank/layers/linear.py (2 additions, 2 deletions)
@@ -85,8 +85,8 @@ def forward(self, x):
         # We can truncate to fp16 in iree, so we do a cast here
         # to account for this in the IR. This is may not be the right
         # level to do this, but for now its here.
-        if not isinstance(y, QuantizedTensor):
-            if y.dtype == torch.float8_e4m3fnuz:
+        if not isinstance(y, QuantizedTensor) and isinstance(x, QuantizedTensor):
+            if x.unpack().qs.dtype == torch.float8_e4m3fnuz:
                 y = ops.to(y, torch.bfloat16)
             return y
         if qdq_output is not None:
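
The guard now keys off the input rather than the output: y is cast to bfloat16 only when x is a QuantizedTensor whose unpacked quantized storage (qs) is float8_e4m3fnuz, rather than checking y.dtype, which presumably stops matching once the kernel above returns float32. A standalone sketch of the new condition follows, using a stand-in class that only mimics the unpack().qs surface of QuantizedTensor; it needs a torch build that provides float8_e4m3fnuz.

import torch

class FakeQuantizedTensor:
    # Stand-in exposing only what the new check touches: unpack().qs.dtype.
    def __init__(self, qs: torch.Tensor):
        self.qs = qs

    def unpack(self):
        return self

def maybe_truncate(y: torch.Tensor, x) -> torch.Tensor:
    # Mirrors the rewritten condition in forward(): cast only when the input
    # was quantized and its quantized storage dtype is f8.
    if not isinstance(y, FakeQuantizedTensor) and isinstance(x, FakeQuantizedTensor):
        if x.unpack().qs.dtype == torch.float8_e4m3fnuz:
            return y.to(torch.bfloat16)
    return y

x = FakeQuantizedTensor(torch.zeros(4, dtype=torch.float8_e4m3fnuz))
y = torch.randn(4, dtype=torch.float32)
assert maybe_truncate(y, x).dtype == torch.bfloat16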

sharktank/sharktank/ops/qlinear_impls.py (3 additions, 3 deletions)
@@ -93,6 +93,7 @@ def qlinear_tensor_scaled(
 
     # Fall back to automatic fusion based on integer, high precision matmul.
     y_qs = _invoke_mmt_kernel(x_qs, weight_qs, accum_dtype=accum_dtype)
+    return y_qs
 
     # Offset correction. By applying the offset correction in post, it is
     # set up to fuse with its consumer, which is already doing additional
@@ -187,9 +188,8 @@ def _invoke_mmt_kernel(lhs, rhs, *, accum_dtype):
             rhs_size = [lhs.shape[0]] + list(rhs.shape)
             rhs = rhs.unsqueeze(0).expand(rhs_size)
             rhs_rank = len(rhs.shape)
-        y_qs = kernels.batch_matmul_transpose_b(
-            lhs.to(accum_dtype), rhs.to(accum_dtype)
-        )
+        y_qs = kernels.batch_matmul_transpose_b(lhs, rhs)
+        return y_qs
         # Squeeze the batch dimension to maintain shape parity with other
         # layers.
         if len(y_qs.shape) > 2:
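
Two related changes land in this file: qlinear_tensor_scaled now returns the kernel output immediately, so the offset-correction block below it becomes dead code (likely one reason the commit is labelled not mergeable as-is), and _invoke_mmt_kernel stops widening its operands to accum_dtype before calling the kernel, since the kernel itself now accumulates in f32 (see the first file). A before/after sketch of the invocation, with a hypothetical bmm_transpose_b callable standing in for kernels.batch_matmul_transpose_b:

import torch

def invoke_before(lhs: torch.Tensor, rhs: torch.Tensor, accum_dtype: torch.dtype, bmm_transpose_b):
    # Old behavior: the caller widened both operands to the accumulator dtype.
    return bmm_transpose_b(lhs.to(accum_dtype), rhs.to(accum_dtype))

def invoke_after(lhs: torch.Tensor, rhs: torch.Tensor, bmm_transpose_b):
    # New behavior: the raw (possibly f8) operands go straight to the kernel,
    # which now declares a float32 result and an f32 accumulator on its own.
    return bmm_transpose_b(lhs, rhs)

# Pure-torch stand-in to show the two paths agree when inputs are already f32:
bmm = lambda a, b: torch.matmul(a, b.transpose(-1, -2))
lhs, rhs = torch.randn(2, 8, 16), torch.randn(2, 4, 16)
assert torch.allclose(invoke_before(lhs, rhs, torch.float32, bmm), invoke_after(lhs, rhs, bmm))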
