Skip to content

Commit

Permalink
Test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ajakovljevicTT committed Jan 3, 2025
1 parent 65e1df8 commit 2fae952
Show file tree
Hide file tree
Showing 7 changed files with 7 additions and 11 deletions.
1 change: 1 addition & 0 deletions src/common/module_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ bool ModuleBuilder::isScalarType(mlir::Type type) {
return false;
}

// Currently hardcoded to one, as we only support one-chip execution.
size_t ModuleBuilder::getNumberOfAddressibleDevices() const
{
return 1;
Expand Down
2 changes: 0 additions & 2 deletions src/common/pjrt_implementation/client_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,6 @@ tt_pjrt_status ClientInstance::PopulateDevices() {
DLOG_F(LOG_DEBUG, "ClientInstance::PopulateDevices");
auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc();
int device_info_count_ = chip_ids.size();
//1; // TODO: revert to chip_ids.size(); once
// https://github.com/tenstorrent/tt-xla/issues/9 is fixed

devices_.resize(device_info_count_);
for (size_t i = 0; i < device_info_count_; ++i) {
Expand Down
3 changes: 0 additions & 3 deletions src/common/pjrt_implementation/loaded_executable_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include "common/pjrt_implementation/loaded_executable_instance.h"

#include <unordered_set>
#include <iostream>

#include "common/pjrt_implementation/buffer_instance.h"
#include "common/pjrt_implementation/client_instance.h"
Expand Down Expand Up @@ -102,8 +101,6 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) {

std::vector<int> device_ids_vector(device_ids.begin(), device_ids.end());

std::cerr << "device_ids_vector=" << device_ids_vector[0] << std::endl;

tt::runtime::Device device = tt::runtime::openDevice(device_ids_vector);

std::vector<tt::runtime::Tensor> rt_outputs =
Expand Down
1 change: 1 addition & 0 deletions tests/TTIR/test_basic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def module_convert(a):
["input_shapes", "required_atol"],
[([(3, 3), (3, 3)], 0.01), ([(3, 3, 3), (3, 3, 3)], 35e-2)],
)
@pytest.mark.skip("Current issue with div op.")
def test_div_op(input_shapes, required_atol):
def module_div(a, b):
return a / b
Expand Down
6 changes: 2 additions & 4 deletions tests/infra/device_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,9 @@ def _run_on_device(device_type: DeviceType, workload: Workload) -> Tensor:
"""Runs `workload` on device identified by `device_type`."""
device_workload = DeviceRunner._put_on_device(device_type, workload)
device = device_connector.connect_device(device_type)

with jax.default_device(device):
ret = device_workload.execute()
# if not (isinstance(workload.args[0][0], int)):
# print("value=", device_workload.args[0][0].device)
return ret
return device_workload.execute()

@staticmethod
def _put_on_device(device_type: DeviceType, workload: Workload) -> Workload:
Expand Down
4 changes: 2 additions & 2 deletions tests/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def verify_module(
required_atol=1e-2,
dtype=jnp.float32,
):
tt_device = jax.devices('tt')[1]
print("tt_device=", tt_device)
tt_device = jax.devices('tt')[0]

cpu_inputs = [
random_input_tensor(input_shapes[i], key + i, dtype=dtype)
for i in range(len(input_shapes))
Expand Down
1 change: 1 addition & 0 deletions tests/jax/ops/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def add(x: jax.Array, y: jax.Array) -> jax.Array:
["x_shape", "y_shape"],
[
[(32, 32), (32, 32)],
[(64, 64), (64, 64)],
],
)
def test_add(x_shape: tuple, y_shape: tuple):
Expand Down

0 comments on commit 2fae952

Please sign in to comment.