diff --git a/inc/common/pjrt_implementation/device_description.h b/inc/common/pjrt_implementation/device_description.h
index 3139a53..e103410 100644
--- a/inc/common/pjrt_implementation/device_description.h
+++ b/inc/common/pjrt_implementation/device_description.h
@@ -20,7 +20,8 @@ namespace tt::pjrt {
 class DeviceDescription {
 public:
-  DeviceDescription(int32_t client_id) : client_id_(client_id) {};
+  DeviceDescription(int32_t client_id)
+      : client_id_(client_id), device_id_(static_device_id++) {};
   ~DeviceDescription();
 
   operator PJRT_DeviceDescription *() {
     return reinterpret_cast<PJRT_DeviceDescription *>(this);
@@ -41,15 +42,16 @@
     return user_string_;
   }
 
-  // TODO
-  int64_t device_id() { return 0; }
+  int64_t device_id() { return device_id_; }
 
   int client_id() { return client_id_; }
 
   int process_index() { return 0; }
 
 private:
+  static int static_device_id;
   int client_id_;
+  int device_id_;
 
   // TODO We should understand better how these are used.
   // See https://github.com/tenstorrent/tt-xla/issues/125
diff --git a/inc/common/pjrt_implementation/executable_image.h b/inc/common/pjrt_implementation/executable_image.h
index 319ca6e..abf21b6 100644
--- a/inc/common/pjrt_implementation/executable_image.h
+++ b/inc/common/pjrt_implementation/executable_image.h
@@ -23,9 +23,11 @@ class ExecutableImage {
 public:
   ExecutableImage(std::shared_ptr<void> binary, std::string code,
-                  size_t arg_count, size_t result_count)
+                  size_t arg_count, size_t result_count,
+                  size_t num_addressable_devices)
       : ref_count(1), binary(std::move(binary)), code(code),
-        arg_count(arg_count), result_count(result_count) {}
+        arg_count(arg_count), result_count(result_count),
+        num_addressable_devices(num_addressable_devices) {}
 
   operator PJRT_Executable *() {
     return reinterpret_cast<PJRT_Executable *>(this);
   }
@@ -49,6 +51,10 @@ class ExecutableImage {
 
   const std::string &get_code() const { return code; }
 
+  const size_t get_num_addressable_devices() const {
+    return num_addressable_devices;
+  }
+
 private:
   // The reference count. Must be disposed when reaching zero.
   std::atomic<int> ref_count;
@@ -61,6 +67,7 @@ class ExecutableImage {
 
   size_t arg_count;
   size_t result_count;
+  size_t num_addressable_devices;
 };
 
 } // namespace tt::pjrt
diff --git a/src/common/module_builder.h b/src/common/module_builder.h
index 9ea1f57..605dbc2 100644
--- a/src/common/module_builder.h
+++ b/src/common/module_builder.h
@@ -35,6 +35,10 @@ class ModuleBuilder {
 
   bool isOutputScalar(size_t index) const;
 
+  // This needs to return the number of addressable devices from the StableHLO
+  // code. Currently hardcoded to one, as we only support one-chip execution.
+  size_t getNumAddressableDevices() const { return 1; }
+
 private:
   // Creates VHLO module from the input program code.
   mlir::OwningOpRef<mlir::ModuleOp>
diff --git a/src/common/pjrt_implementation/client_instance.cc b/src/common/pjrt_implementation/client_instance.cc
index 36bb29d..d7a2331 100644
--- a/src/common/pjrt_implementation/client_instance.cc
+++ b/src/common/pjrt_implementation/client_instance.cc
@@ -164,9 +164,7 @@ void ClientInstance::BindApi(PJRT_Api *api) {
 tt_pjrt_status ClientInstance::PopulateDevices() {
   DLOG_F(LOG_DEBUG, "ClientInstance::PopulateDevices");
   auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc();
-  int device_info_count_ =
-      1; // TODO: revert to chip_ids.size(); once
-         // https://github.com/tenstorrent/tt-xla/issues/9 is fixed
+  int device_info_count_ = chip_ids.size();
 
   devices_.resize(device_info_count_);
   for (size_t i = 0; i < device_info_count_; ++i) {
@@ -198,7 +196,8 @@ PJRT_Error *ClientInstance::Compile(const PJRT_Program *program,
       new ExecutableImage(module_builder_->getBinary(),
                           std::string(program->code, program->code_size),
                           module_builder_->getNumInputs(),
-                          module_builder_->getNumOutputs()),
+                          module_builder_->getNumOutputs(),
+                          module_builder_->getNumAddressableDevices()),
       addressable_devices_);
   *out_executable = executable.release();
 
diff --git a/src/common/pjrt_implementation/device_description.cc b/src/common/pjrt_implementation/device_description.cc
index a2a6296..c7931f4 100644
--- a/src/common/pjrt_implementation/device_description.cc
+++ b/src/common/pjrt_implementation/device_description.cc
@@ -16,6 +16,8 @@ namespace tt::pjrt {
 
 DeviceDescription::~DeviceDescription() = default;
 
+int DeviceDescription::static_device_id = 0;
+
 void DeviceDescription::BindApi(PJRT_Api *api) {
   DLOG_F(LOG_DEBUG, "DeviceDescription::BindApi");
   api->PJRT_DeviceDescription_Id =
diff --git a/src/common/pjrt_implementation/loaded_executable_instance.cc b/src/common/pjrt_implementation/loaded_executable_instance.cc
index e8a6e72..ae4483a 100644
--- a/src/common/pjrt_implementation/loaded_executable_instance.cc
+++ b/src/common/pjrt_implementation/loaded_executable_instance.cc
@@ -10,6 +10,8 @@
 
 #include "common/pjrt_implementation/loaded_executable_instance.h"
 
+#include <unordered_set>
+
 #include "common/pjrt_implementation/buffer_instance.h"
 #include "common/pjrt_implementation/client_instance.h"
 #include "common/pjrt_implementation/error_instance.h"
@@ -32,12 +34,15 @@ void LoadedExecutableInstance::BindApi(PJRT_Api *api) {
     DLOG_F(
         LOG_DEBUG,
         "LoadedExecutableInstance::PJRT_LoadedExecutable_AddressableDevices");
-    const std::vector<DeviceInstance *> &devices =
+    const std::vector<DeviceInstance *> &addressable_devices =
         LoadedExecutableInstance::Unwrap(args->executable)
             ->addressable_devices();
+    int num_addressable_devices =
+        LoadedExecutableInstance::Unwrap(args->executable)
+            ->image_->get_num_addressable_devices();
     args->addressable_devices = const_cast<PJRT_Device **>(
-        reinterpret_cast<PJRT_Device *const *>(devices.data()));
-    args->num_addressable_devices = devices.size();
+        reinterpret_cast<PJRT_Device *const *>(addressable_devices.data()));
+    args->num_addressable_devices = num_addressable_devices;
     return nullptr;
   };
   api->PJRT_LoadedExecutable_Delete =
@@ -77,23 +82,34 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) {
   DLOG_F(LOG_DEBUG, "LoadedExecutableInstance::Execute");
 
   auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc();
-  int dev_0 = chip_ids[0];
-  tt::runtime::Device device = tt::runtime::openDevice({dev_0});
+  // Sanity check, as we only support execution on one chip currently.
+  assert(args->num_devices == 1);
+
   int dev_index = 0;
   tt::runtime::Binary binary(image_->get_binary());
 
   std::vector<tt::runtime::Tensor> rt_inputs;
   rt_inputs.reserve(args->num_args);
 
+  std::unordered_set<int> device_ids;
+
   for (size_t i = 0; i < args->num_args; ++i) {
     BufferInstance *buffer =
        BufferInstance::Unwrap(args->argument_lists[dev_index][i]);
     rt_inputs.emplace_back(buffer->tensor());
+    int64_t buffer_device_id =
+        buffer->device().device_description()->device_id();
+    device_ids.insert(chip_ids[buffer_device_id]);
     DLOG_F(INFO, "Runtime input id: %d", buffer->unique_id());
   }
 
+  assert(device_ids.size() == 1);
+
+  std::vector<int> device_ids_vector(device_ids.begin(), device_ids.end());
+
+  tt::runtime::Device device = tt::runtime::openDevice(device_ids_vector);
+
   std::vector<tt::runtime::Tensor> rt_outputs =
       tt::runtime::submit(device, binary, 0, rt_inputs);
 
   std::vector<tt::runtime::TensorDesc> output_specs =
diff --git a/tests/infra/device_connector.py b/tests/infra/device_connector.py
index 3575094..6384da0 100644
--- a/tests/infra/device_connector.py
+++ b/tests/infra/device_connector.py
@@ -69,9 +69,9 @@ def is_initialized(self) -> bool:
 
         return False
 
-    def connect_tt_device(self) -> jax.Device:
+    def connect_tt_device(self, device_num: int = 0) -> jax.Device:
         """Returns TTDevice handle."""
-        return self.connect_device(DeviceType.TT)
+        return self.connect_device(DeviceType.TT, device_num)
 
     def connect_cpu(self) -> jax.Device:
         """Returns CPUDevice handle."""
@@ -81,9 +81,22 @@ def connect_gpu(self) -> jax.Device:
         """Returns GPUDevice handle."""
         return self.connect_device(DeviceType.GPU)
 
-    def connect_device(self, device_type: DeviceType) -> jax.Device:
-        """Returns handle for device identified by `device_type`."""
-        return jax.devices(device_type.value)[0]
+    def connect_device(
+        self, device_type: DeviceType, device_num: int = 0
+    ) -> jax.Device:
+        """
+        Returns handle for device identified by `device_type`.
+
+        If there are multiple available devices of `device_type`, `device_num` makes it
+        possible to choose between them. By default, returns first available device.
+        """
+        assert device_num < self._number_of_devices(device_type)
+
+        return jax.devices(device_type.value)[device_num]
+
+    def _number_of_devices(self, device_type: DeviceType) -> int:
+        """Returns the number of available devices of specified type."""
+        return len(jax.devices(device_type.value))
 
     def _supported_devices(self) -> Sequence[DeviceType]:
         """Returns list of supported device types."""
diff --git a/tests/infra/device_runner.py b/tests/infra/device_runner.py
index 76349ec..28b6134 100644
--- a/tests/infra/device_runner.py
+++ b/tests/infra/device_runner.py
@@ -19,9 +19,9 @@ class DeviceRunner:
     """
 
     @staticmethod
-    def run_on_tt_device(workload: Workload) -> Tensor:
+    def run_on_tt_device(workload: Workload, device_num: int = 0) -> Tensor:
         """Runs `workload` on TT device."""
-        return DeviceRunner._run_on_device(DeviceType.TT, workload)
+        return DeviceRunner._run_on_device(DeviceType.TT, workload, device_num)
 
     @staticmethod
     def run_on_cpu(workload: Workload) -> Tensor:
@@ -34,9 +34,9 @@ def run_on_gpu(workload: Workload) -> Tensor:
         raise NotImplementedError("Support for GPUs not implemented")
 
     @staticmethod
-    def put_on_tt_device(workload: Workload) -> Workload:
+    def put_on_tt_device(workload: Workload, device_num: int = 0) -> Workload:
         """Puts `workload` on TT device."""
-        return DeviceRunner._put_on_device(DeviceType.TT, workload)
+        return DeviceRunner._put_on_device(DeviceType.TT, workload, device_num)
 
     @staticmethod
     def put_on_cpu(workload: Workload) -> Workload:
@@ -64,18 +64,22 @@ def put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]:
         raise NotImplementedError("Support for GPUs not implemented")
 
     @staticmethod
-    def _run_on_device(device_type: DeviceType, workload: Workload) -> Tensor:
+    def _run_on_device(
+        device_type: DeviceType, workload: Workload, device_num: int = 0
+    ) -> Tensor:
         """Runs `workload` on device identified by `device_type`."""
-        device_workload = DeviceRunner._put_on_device(device_type, workload)
-        device = device_connector.connect_device(device_type)
+        device_workload = DeviceRunner._put_on_device(device_type, workload, device_num)
+        device = device_connector.connect_device(device_type, device_num)
 
         with jax.default_device(device):
             return device_workload.execute()
 
     @staticmethod
-    def _put_on_device(device_type: DeviceType, workload: Workload) -> Workload:
+    def _put_on_device(
+        device_type: DeviceType, workload: Workload, device_num: int = 0
+    ) -> Workload:
         """Puts `workload` on device and returns it."""
-        device = device_connector.connect_device(device_type)
+        device = device_connector.connect_device(device_type, device_num)
         return DeviceRunner._safely_put_workload_on_device(workload, device)
 
     @staticmethod
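A minimal sketch of how the new `device_num` parameter in the test infra could be exercised once more than one TT chip is enumerated. This is not part of the diff above; the import path, the `device_connector` singleton name, and the presence of a second TT device are assumptions.

# Hypothetical usage sketch (not part of the change above). Assumes the tt-xla
# PJRT plugin is registered with JAX, that at least two TT devices are visible,
# and that tests/infra exposes a module-level `device_connector` singleton.
import jax
import jax.numpy as jnp

from infra.device_connector import device_connector  # assumed import path


def test_add_on_second_tt_chip():
    # connect_tt_device(device_num=1) selects the second enumerated TT device;
    # it asserts internally that device_num is within the available range.
    device = device_connector.connect_tt_device(device_num=1)

    with jax.default_device(device):
        x = jnp.ones((32, 32))
        y = jnp.ones((32, 32))
        assert jnp.allclose(x + y, 2.0)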