Enabling non-hardcoded run of our test infrastructure #145
base: main
Conversation
@ajakovljevicTT please rebase and ping me again to take a look. That bug should be gone now that old infra is gone.
LoadedExecutableInstance::Unwrap(args->executable)
    ->addressable_devices();
int num_addressable_devices =
    LoadedExecutableInstance::Unwrap(args->executable)
        ->image_->get_num_addresible_devices();
Typo: addressable.
Maybe add assert num_addressable_devices == addressable_devices.size().
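A minimal Python sketch of the consistency check the reviewer suggests (the real code is C++; the class and method names below are illustrative stand-ins, not the repo's API):

```python
class ExecutableImage:
    """Hypothetical stand-in for the repo's executable image type."""

    def __init__(self, addressable_devices):
        self.addressable_devices = list(addressable_devices)

    def get_num_addressable_devices(self):
        return len(self.addressable_devices)


def check_device_counts(image):
    # The reviewer's suggestion: the reported device count must
    # match the size of the addressable-devices list.
    assert image.get_num_addressable_devices() == len(image.addressable_devices)


image = ExecutableImage(["tt:0", "tt:1"])
check_device_counts(image)  # passes: both counts are 2
```

Such an assert makes a silent mismatch between the two code paths fail loudly instead of producing wrong device placement later.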
tests/infra/device_runner.py
Outdated
device = device_connector.connect_device(device_type)

with jax.default_device(device):
    return device_workload.execute()

 @staticmethod
-def _put_on_device(device_type: DeviceType, workload: Workload) -> Workload:
+def _put_on_device(
+    device_type: DeviceType, workload: Workload, num_device: int = 0
Maybe rearrange a bit to keep device_type and device_num next to each other.
tests/infra/device_runner.py
Outdated
@@ -64,18 +64,22 @@ def put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]:
     raise NotImplementedError("Support for GPUs not implemented")

 @staticmethod
-def _run_on_device(device_type: DeviceType, workload: Workload) -> Tensor:
+def _run_on_device(
+    device_type: DeviceType, workload: Workload, num_device: int = 0
Same here.
As per issue #9, our test infra had a problem detecting two TT devices, as is the case on the two-chip N300; this was worked around by hardcoding which devices to use. This PR fixes that by instead capping the number of devices the backend uses to 1. Additionally, there was a bug in our old infra that ran the tests on the CPU instead of a TT device; this PR fixes that as well.
Fixes #9.
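The fix described above can be sketched in a few lines: rather than hardcoding which device to use, cap the set of devices the backend exposes. `detect_devices` is a hypothetical stand-in for the real device-discovery call, not the repo's API:

```python
MAX_DEVICES = 1  # the cap described in this PR


def detect_devices():
    # Hypothetical stand-in for device discovery; a two-chip N300
    # reports two TT devices.
    return ["tt:0", "tt:1"]


def usable_devices():
    # Cap the visible device set instead of hardcoding a device id.
    return detect_devices()[:MAX_DEVICES]


print(usable_devices())
```

With the cap, the infra behaves the same on single-chip and multi-chip boards, because tests only ever see the first device.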