diff --git a/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py b/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py index ef0a72880..528ca9100 100644 --- a/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py +++ b/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py @@ -44,7 +44,7 @@ import torch from torch._decomp import get_decompositions import torch_mlir -from torch_mlir import TensorPlaceholder +from torch_mlir.torchscript import TensorPlaceholder from tqdm import tqdm from brevitas.backport.fx._symbolic_trace import wrap @@ -313,7 +313,7 @@ def compile_to_vmfb(inputs, layers, export_context_manager, export_class, is_fir if is_first: ts_g = compile_vicuna_layer( export_context_manager, export_class, layer, inputs[0], inputs[1], inputs[2]) - module = torch_mlir.compile( + module = torch_mlir.torchscript.compile( ts_g, (hidden_states_placeholder, inputs[1], inputs[2]), output_type="torch", backend_legal_ops=["quant.matmul_rhs_group_quant"], @@ -330,7 +330,7 @@ def compile_to_vmfb(inputs, layers, export_context_manager, export_class, is_fir inputs[2], inputs[3], inputs[4]) - module = torch_mlir.compile( + module = torch_mlir.torchscript.compile( ts_g, ( inputs[0],