Skip to content

Commit

Permalink
Changed test api
Browse files Browse the repository at this point in the history
  • Loading branch information
ajakovljevicTT committed Jan 10, 2025
1 parent 2e5c6d5 commit ccd0cf2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 12 deletions.
13 changes: 9 additions & 4 deletions tests/infra/device_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ def is_initialized(self) -> bool:

return False

def connect_tt_device(self) -> jax.Device:
def connect_tt_device(self, num_device: int = 0) -> jax.Device:
"""Returns TTDevice handle."""
return self.connect_device(DeviceType.TT)
return self.connect_device(DeviceType.TT, num_device)

def connect_cpu(self) -> jax.Device:
"""Returns CPUDevice handle."""
Expand All @@ -81,9 +81,14 @@ def connect_gpu(self) -> jax.Device:
"""Returns GPUDevice handle."""
return self.connect_device(DeviceType.GPU)

def connect_device(self, device_type: DeviceType) -> jax.Device:
def _number_of_devices(self, device_type: DeviceType) -> int:
"""Returns the number of devices of specifed type."""
return len(jax.devices(device_type.value))

def connect_device(self, device_type: DeviceType, num_device: int = 0) -> jax.Device:
"""Returns handle for device identified by `device_type`."""
return jax.devices(device_type.value)[0]
assert (num_device < self._number_of_devices(device_type))
return jax.devices(device_type.value)[num_device]

def _supported_devices(self) -> Sequence[DeviceType]:
"""Returns list of supported device types."""
Expand Down
16 changes: 8 additions & 8 deletions tests/infra/device_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ class DeviceRunner:
"""

@staticmethod
def run_on_tt_device(workload: Workload) -> Tensor:
def run_on_tt_device(workload: Workload, num_device: int = 0) -> Tensor:
"""Runs `workload` on TT device."""
return DeviceRunner._run_on_device(DeviceType.TT, workload)
return DeviceRunner._run_on_device(DeviceType.TT, workload, num_device)

@staticmethod
def run_on_cpu(workload: Workload) -> Tensor:
Expand All @@ -34,9 +34,9 @@ def run_on_gpu(workload: Workload) -> Tensor:
raise NotImplementedError("Support for GPUs not implemented")

@staticmethod
def put_on_tt_device(workload: Workload) -> Workload:
def put_on_tt_device(workload: Workload, num_device: int = 0) -> Workload:
"""Puts `workload` on TT device."""
return DeviceRunner._put_on_device(DeviceType.TT, workload)
return DeviceRunner._put_on_device(DeviceType.TT, workload, num_device)

@staticmethod
def put_on_cpu(workload: Workload) -> Workload:
Expand Down Expand Up @@ -64,18 +64,18 @@ def put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]:
raise NotImplementedError("Support for GPUs not implemented")

@staticmethod
def _run_on_device(device_type: DeviceType, workload: Workload) -> Tensor:
def _run_on_device(device_type: DeviceType, workload: Workload, num_device: int = 0) -> Tensor:
"""Runs `workload` on device identified by `device_type`."""
device_workload = DeviceRunner._put_on_device(device_type, workload)
device_workload = DeviceRunner._put_on_device(device_type, workload, num_device)
device = device_connector.connect_device(device_type)

with jax.default_device(device):
return device_workload.execute()

@staticmethod
def _put_on_device(device_type: DeviceType, workload: Workload) -> Workload:
def _put_on_device(device_type: DeviceType, workload: Workload, num_device: int = 0) -> Workload:
"""Puts `workload` on device and returns it."""
device = device_connector.connect_device(device_type)
device = device_connector.connect_device(device_type, num_device)
return DeviceRunner._safely_put_workload_on_device(workload, device)

@staticmethod
Expand Down

0 comments on commit ccd0cf2

Please sign in to comment.