Skip to content

Commit

Permalink
Change run_iree_module_function to accept HAL device instead of drive…
Browse files Browse the repository at this point in the history
…r name (#893)

It did not sit right that run_iree_module_function would create a HAL
device just to make iree.runtime.FunctionInvoker happy.
  • Loading branch information
sogartar authored Jan 31, 2025
1 parent 7671d57 commit 4eac34e
Show file tree
Hide file tree
Showing 7 changed files with 12 additions and 12 deletions.
4 changes: 2 additions & 2 deletions sharktank/sharktank/utils/iree.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def run_iree_module_function(
module: iree.runtime.VmModule,
vm_context: iree.runtime.VmContext,
args: List[iree.runtime.DeviceArray],
driver: str,
device: iree.runtime.HalDevice,
function_name: str = "main",
trace_path_prefix: Optional[str] = None,
) -> List[iree.runtime.DeviceArray]:
Expand All @@ -154,7 +154,7 @@ def run_iree_module_function(
vm_context=vm_context,
# TODO: rework iree.runtime.FunctionInvoker interface for multiple devices.
# This works, but does not look right.
device=iree.runtime.get_device(driver, cache=False),
device=device,
vm_function=vm_function,
)

Expand Down
2 changes: 1 addition & 1 deletion sharktank/tests/models/clip/clip_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens(
module=iree_module,
vm_context=iree_vm_context,
args=iree_args,
driver="hip",
device=iree_devices[0],
function_name=f"forward_bs{batch_size}",
trace_path_prefix=f"{target_model_path_prefix}_iree_",
)
Expand Down
2 changes: 1 addition & 1 deletion sharktank/tests/models/flux/flux_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def runCompareIreeAgainstTorchEager(
module=iree_module,
vm_context=iree_vm_context,
args=iree_args,
driver="hip",
device=iree_devices[0],
function_name=f"forward_bs{batch_size}",
)
)
Expand Down
4 changes: 2 additions & 2 deletions sharktank/tests/models/llama/sharded_llama_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def _(model, *args, **kwargs) -> torch.Tensor:
function_name="prefill",
module=iree_module,
vm_context=vm_context,
driver=iree_driver,
device=iree_devices[0],
trace_path_prefix=path_prefix if dump_enabled else None,
)
prefill_iree_result = UnreducedTensor(ts=iree_to_torch(*prefill_iree_result))
Expand All @@ -367,7 +367,7 @@ def _(model, *args, **kwargs) -> torch.Tensor:
function_name="decode",
module=iree_module,
vm_context=vm_context,
driver=iree_driver,
device=iree_devices[0],
trace_path_prefix=path_prefix if dump_enabled else None,
)
decode_iree_result = UnreducedTensor(ts=iree_to_torch(*decode_iree_result))
Expand Down
2 changes: 1 addition & 1 deletion sharktank/tests/models/t5/t5_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def runTestV1_1CompareIreeAgainstTorchEager(
module=iree_module,
vm_context=iree_vm_context,
args=iree_args,
driver="hip",
device=iree_devices[0],
function_name=f"forward_bs{batch_size}",
trace_path_prefix=f"{target_model_path_prefix}_iree_",
)
Expand Down
8 changes: 4 additions & 4 deletions sharktank/tests/models/vae/vae_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def testVaeIreeVsHuggingFace(self):
module=iree_module,
vm_context=iree_vm_context,
args=iree_args,
driver="hip",
device=iree_devices[0],
function_name="decode",
)[0].to_host()
# TODO: Verify these numerics are good or if tolerances are too loose
Expand Down Expand Up @@ -193,7 +193,7 @@ def testVaeIreeVsHuggingFace(self):
module=iree_module,
vm_context=iree_vm_context,
args=iree_args,
driver="hip",
device=iree_devices[0],
function_name="decode",
)[0].to_host()
# TODO: Upload IR on passing tests
Expand Down Expand Up @@ -328,7 +328,7 @@ def testVaeIreeVsHuggingFace(self):
module=iree_module,
vm_context=iree_vm_context,
args=iree_args,
driver="hip",
device=iree_devices[0],
function_name="decode",
)[0]
)
Expand All @@ -355,7 +355,7 @@ def testVaeIreeVsHuggingFace(self):
module=iree_module,
vm_context=iree_vm_context,
args=iree_args,
driver="hip",
device=iree_devices[0],
function_name="decode",
)[0]
)
Expand Down
2 changes: 1 addition & 1 deletion sharktank/tests/ops/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def _(model, x):
module=iree_module,
vm_context=iree_vm_context,
args=iree_args,
driver="local-task",
device=iree_devices[0],
function_name=f"forward",
)

Expand Down

0 comments on commit 4eac34e

Please sign in to comment.