From 2e5c6d5a768f7dafac9eec07a4c1417ac4f2ade5 Mon Sep 17 00:00:00 2001 From: ajakovljevicTT Date: Fri, 10 Jan 2025 10:08:12 +0000 Subject: [PATCH 1/3] Fixed hardcoding of x2 chips --- .../pjrt_implementation/device_description.h | 7 ++++-- .../pjrt_implementation/executable_image.h | 11 +++++++-- src/common/module_builder.cc | 3 +++ src/common/module_builder.h | 4 ++++ .../pjrt_implementation/client_instance.cc | 7 +++--- .../pjrt_implementation/device_description.cc | 2 ++ .../loaded_executable_instance.cc | 23 +++++++++++++++---- 7 files changed, 44 insertions(+), 13 deletions(-) diff --git a/inc/common/pjrt_implementation/device_description.h b/inc/common/pjrt_implementation/device_description.h index 3139a53..d2a4c38 100644 --- a/inc/common/pjrt_implementation/device_description.h +++ b/inc/common/pjrt_implementation/device_description.h @@ -20,7 +20,8 @@ namespace tt::pjrt { class DeviceDescription { public: - DeviceDescription(int32_t client_id) : client_id_(client_id) {}; + DeviceDescription(int32_t client_id) + : client_id_(client_id), device_id_(static_device_id++) {}; ~DeviceDescription(); operator PJRT_DeviceDescription *() { return reinterpret_cast(this); } @@ -42,14 +43,16 @@ class DeviceDescription { } // TODO - int64_t device_id() { return 0; } + int64_t device_id() { return device_id_; } int client_id() { return client_id_; } int process_index() { return 0; } private: + static int static_device_id; int client_id_; + int device_id_; // TODO We should understand better how these are used. 
// See https://github.com/tenstorrent/tt-xla/issues/125 diff --git a/inc/common/pjrt_implementation/executable_image.h b/inc/common/pjrt_implementation/executable_image.h index 319ca6e..b18cd6f 100644 --- a/inc/common/pjrt_implementation/executable_image.h +++ b/inc/common/pjrt_implementation/executable_image.h @@ -23,9 +23,11 @@ class ExecutableImage { public: ExecutableImage(std::shared_ptr binary, std::string code, - size_t arg_count, size_t result_count) + size_t arg_count, size_t result_count, + size_t num_addressable_devices) : ref_count(1), binary(std::move(binary)), code(code), - arg_count(arg_count), result_count(result_count) {} + arg_count(arg_count), result_count(result_count), + num_addressable_devices(num_addressable_devices) {} operator PJRT_Executable *() { return reinterpret_cast(this); } @@ -49,6 +51,10 @@ class ExecutableImage { const std::string &get_code() const { return code; } + const size_t get_num_addresible_devices() const { + return num_addressable_devices; + } + private: // The reference count. Must be disposed when reaching zero. std::atomic ref_count; @@ -61,6 +67,7 @@ class ExecutableImage { size_t arg_count; size_t result_count; + size_t num_addressable_devices; }; } // namespace tt::pjrt diff --git a/src/common/module_builder.cc b/src/common/module_builder.cc index f39276b..df3c0c5 100644 --- a/src/common/module_builder.cc +++ b/src/common/module_builder.cc @@ -174,6 +174,9 @@ bool ModuleBuilder::isScalarType(mlir::Type type) { return false; } +// Currently hardcoded to one, as we only support one-chip execution. 
+size_t ModuleBuilder::getNumberOfAddressibleDevices() const { return 1; } + void ModuleBuilder::convertFromSHLOToTTIR( mlir::OwningOpRef &mlir_module) { // Implicit nesting required to call the stablehlo.composite --> func.call diff --git a/src/common/module_builder.h b/src/common/module_builder.h index 9ea1f57..c37fc44 100644 --- a/src/common/module_builder.h +++ b/src/common/module_builder.h @@ -35,6 +35,10 @@ class ModuleBuilder { bool isOutputScalar(size_t index) const; + // This needs to return the number of addressable devices from the StableHLO + // code. + size_t getNumberOfAddressibleDevices() const; + private: // Creates VHLO module from the input program code. mlir::OwningOpRef diff --git a/src/common/pjrt_implementation/client_instance.cc b/src/common/pjrt_implementation/client_instance.cc index 36bb29d..1bc2dee 100644 --- a/src/common/pjrt_implementation/client_instance.cc +++ b/src/common/pjrt_implementation/client_instance.cc @@ -164,9 +164,7 @@ void ClientInstance::BindApi(PJRT_Api *api) { tt_pjrt_status ClientInstance::PopulateDevices() { DLOG_F(LOG_DEBUG, "ClientInstance::PopulateDevices"); auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc(); - int device_info_count_ = - 1; // TODO: revert to chip_ids.size(); once - // https://github.com/tenstorrent/tt-xla/issues/9 is fixed + int device_info_count_ = chip_ids.size(); devices_.resize(device_info_count_); for (size_t i = 0; i < device_info_count_; ++i) { @@ -198,7 +196,8 @@ PJRT_Error *ClientInstance::Compile(const PJRT_Program *program, new ExecutableImage(module_builder_->getBinary(), std::string(program->code, program->code_size), module_builder_->getNumInputs(), - module_builder_->getNumOutputs()), + module_builder_->getNumOutputs(), + module_builder_->getNumberOfAddressibleDevices()), addressable_devices_); *out_executable = executable.release(); diff --git a/src/common/pjrt_implementation/device_description.cc b/src/common/pjrt_implementation/device_description.cc index 
a2a6296..c7931f4 100644 --- a/src/common/pjrt_implementation/device_description.cc +++ b/src/common/pjrt_implementation/device_description.cc @@ -16,6 +16,8 @@ namespace tt::pjrt { DeviceDescription::~DeviceDescription() = default; +int DeviceDescription::static_device_id = 0; + void DeviceDescription::BindApi(PJRT_Api *api) { DLOG_F(LOG_DEBUG, "DeviceDescription::BindApi"); api->PJRT_DeviceDescription_Id = diff --git a/src/common/pjrt_implementation/loaded_executable_instance.cc b/src/common/pjrt_implementation/loaded_executable_instance.cc index e8a6e72..d981622 100644 --- a/src/common/pjrt_implementation/loaded_executable_instance.cc +++ b/src/common/pjrt_implementation/loaded_executable_instance.cc @@ -10,6 +10,8 @@ #include "common/pjrt_implementation/loaded_executable_instance.h" +#include + #include "common/pjrt_implementation/buffer_instance.h" #include "common/pjrt_implementation/client_instance.h" #include "common/pjrt_implementation/error_instance.h" @@ -32,12 +34,15 @@ void LoadedExecutableInstance::BindApi(PJRT_Api *api) { DLOG_F( LOG_DEBUG, "LoadedExecutableInstance::PJRT_LoadedExecutable_AddressableDevices"); - const std::vector &devices = + auto &addressable_devices = LoadedExecutableInstance::Unwrap(args->executable) ->addressable_devices(); + int num_addressable_devices = + LoadedExecutableInstance::Unwrap(args->executable) + ->image_->get_num_addresible_devices(); args->addressable_devices = const_cast( - reinterpret_cast(devices.data())); - args->num_addressable_devices = devices.size(); + reinterpret_cast(addressable_devices.data())); + args->num_addressable_devices = num_addressable_devices; return nullptr; }; api->PJRT_LoadedExecutable_Delete = @@ -77,8 +82,6 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) { DLOG_F(LOG_DEBUG, "LoadedExecutableInstance::Execute"); auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc(); - int dev_0 = chip_ids[0]; - tt::runtime::Device device = 
tt::runtime::openDevice({dev_0}); assert(args->num_devices == 1); int dev_index = 0; @@ -87,13 +90,23 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) { std::vector rt_inputs; rt_inputs.reserve(args->num_args); + std::unordered_set device_ids; + for (size_t i = 0; i < args->num_args; ++i) { BufferInstance *buffer = BufferInstance::Unwrap(args->argument_lists[dev_index][i]); rt_inputs.emplace_back(buffer->tensor()); + device_ids.insert( + chip_ids[buffer->device().device_description()->device_id()]); DLOG_F(INFO, "Runtime input id: %d", buffer->unique_id()); } + assert(device_ids.size() == 1); + + std::vector device_ids_vector(device_ids.begin(), device_ids.end()); + + tt::runtime::Device device = tt::runtime::openDevice(device_ids_vector); + std::vector rt_outputs = tt::runtime::submit(device, binary, 0, rt_inputs); std::vector output_specs = From 418f9d4e3a3150d868e924ab37f455dce72d5372 Mon Sep 17 00:00:00 2001 From: ajakovljevicTT Date: Fri, 10 Jan 2025 12:25:48 +0000 Subject: [PATCH 2/3] Changed test api --- tests/infra/device_connector.py | 15 +++++++++++---- tests/infra/device_runner.py | 20 ++++++++++++-------- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/tests/infra/device_connector.py b/tests/infra/device_connector.py index 3575094..dcdafe9 100644 --- a/tests/infra/device_connector.py +++ b/tests/infra/device_connector.py @@ -69,9 +69,9 @@ def is_initialized(self) -> bool: return False - def connect_tt_device(self) -> jax.Device: + def connect_tt_device(self, num_device: int = 0) -> jax.Device: """Returns TTDevice handle.""" - return self.connect_device(DeviceType.TT) + return self.connect_device(DeviceType.TT, num_device) def connect_cpu(self) -> jax.Device: """Returns CPUDevice handle.""" @@ -81,9 +81,16 @@ def connect_gpu(self) -> jax.Device: """Returns GPUDevice handle.""" return self.connect_device(DeviceType.GPU) - def connect_device(self, device_type: DeviceType) -> jax.Device: + def 
_number_of_devices(self, device_type: DeviceType) -> int: + """Returns the number of devices of specifed type.""" + return len(jax.devices(device_type.value)) + + def connect_device( + self, device_type: DeviceType, num_device: int = 0 + ) -> jax.Device: """Returns handle for device identified by `device_type`.""" - return jax.devices(device_type.value)[0] + assert num_device < self._number_of_devices(device_type) + return jax.devices(device_type.value)[num_device] def _supported_devices(self) -> Sequence[DeviceType]: """Returns list of supported device types.""" diff --git a/tests/infra/device_runner.py b/tests/infra/device_runner.py index 76349ec..efc07d2 100644 --- a/tests/infra/device_runner.py +++ b/tests/infra/device_runner.py @@ -19,9 +19,9 @@ class DeviceRunner: """ @staticmethod - def run_on_tt_device(workload: Workload) -> Tensor: + def run_on_tt_device(workload: Workload, num_device: int = 0) -> Tensor: """Runs `workload` on TT device.""" - return DeviceRunner._run_on_device(DeviceType.TT, workload) + return DeviceRunner._run_on_device(DeviceType.TT, workload, num_device) @staticmethod def run_on_cpu(workload: Workload) -> Tensor: @@ -34,9 +34,9 @@ def run_on_gpu(workload: Workload) -> Tensor: raise NotImplementedError("Support for GPUs not implemented") @staticmethod - def put_on_tt_device(workload: Workload) -> Workload: + def put_on_tt_device(workload: Workload, num_device: int = 0) -> Workload: """Puts `workload` on TT device.""" - return DeviceRunner._put_on_device(DeviceType.TT, workload) + return DeviceRunner._put_on_device(DeviceType.TT, workload, num_device) @staticmethod def put_on_cpu(workload: Workload) -> Workload: @@ -64,18 +64,22 @@ def put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]: raise NotImplementedError("Support for GPUs not implemented") @staticmethod - def _run_on_device(device_type: DeviceType, workload: Workload) -> Tensor: + def _run_on_device( + device_type: DeviceType, workload: Workload, num_device: int = 0 + ) -> 
Tensor: """Runs `workload` on device identified by `device_type`.""" - device_workload = DeviceRunner._put_on_device(device_type, workload) + device_workload = DeviceRunner._put_on_device(device_type, workload, num_device) device = device_connector.connect_device(device_type) with jax.default_device(device): return device_workload.execute() @staticmethod - def _put_on_device(device_type: DeviceType, workload: Workload) -> Workload: + def _put_on_device( + device_type: DeviceType, workload: Workload, num_device: int = 0 + ) -> Workload: """Puts `workload` on device and returns it.""" - device = device_connector.connect_device(device_type) + device = device_connector.connect_device(device_type, num_device) return DeviceRunner._safely_put_workload_on_device(workload, device) @staticmethod From 680a3eab1b130910e943ff023879a1d27464c0d9 Mon Sep 17 00:00:00 2001 From: ajakovljevicTT Date: Fri, 10 Jan 2025 14:24:56 +0000 Subject: [PATCH 3/3] Addressed comments --- .../pjrt_implementation/device_description.h | 1 - .../pjrt_implementation/executable_image.h | 2 +- src/common/module_builder.cc | 3 --- src/common/module_builder.h | 4 +-- .../pjrt_implementation/client_instance.cc | 2 +- .../loaded_executable_instance.cc | 11 +++++--- tests/infra/device_connector.py | 26 ++++++++++++------- tests/infra/device_runner.py | 18 ++++++------- 8 files changed, 36 insertions(+), 31 deletions(-) diff --git a/inc/common/pjrt_implementation/device_description.h b/inc/common/pjrt_implementation/device_description.h index d2a4c38..e103410 100644 --- a/inc/common/pjrt_implementation/device_description.h +++ b/inc/common/pjrt_implementation/device_description.h @@ -42,7 +42,6 @@ class DeviceDescription { return user_string_; } - // TODO int64_t device_id() { return device_id_; } int client_id() { return client_id_; } diff --git a/inc/common/pjrt_implementation/executable_image.h b/inc/common/pjrt_implementation/executable_image.h index b18cd6f..abf21b6 100644 --- 
a/inc/common/pjrt_implementation/executable_image.h +++ b/inc/common/pjrt_implementation/executable_image.h @@ -51,7 +51,7 @@ class ExecutableImage { const std::string &get_code() const { return code; } - const size_t get_num_addresible_devices() const { + const size_t get_num_addressable_devices() const { return num_addressable_devices; } diff --git a/src/common/module_builder.cc b/src/common/module_builder.cc index df3c0c5..f39276b 100644 --- a/src/common/module_builder.cc +++ b/src/common/module_builder.cc @@ -174,9 +174,6 @@ bool ModuleBuilder::isScalarType(mlir::Type type) { return false; } -// Currently hardcoded to one, as we only support one-chip execution. -size_t ModuleBuilder::getNumberOfAddressibleDevices() const { return 1; } - void ModuleBuilder::convertFromSHLOToTTIR( mlir::OwningOpRef &mlir_module) { // Implicit nesting required to call the stablehlo.composite --> func.call diff --git a/src/common/module_builder.h b/src/common/module_builder.h index c37fc44..605dbc2 100644 --- a/src/common/module_builder.h +++ b/src/common/module_builder.h @@ -36,8 +36,8 @@ class ModuleBuilder { bool isOutputScalar(size_t index) const; // This needs to return the number of addressable devices from the StableHLO - // code. - size_t getNumberOfAddressibleDevices() const; + // code. Currently hardcoded to one, as we only support one-chip execution. + size_t getNumAddressableDevices() const { return 1; } private: // Creates VHLO module from the input program code. 
diff --git a/src/common/pjrt_implementation/client_instance.cc b/src/common/pjrt_implementation/client_instance.cc index 1bc2dee..d7a2331 100644 --- a/src/common/pjrt_implementation/client_instance.cc +++ b/src/common/pjrt_implementation/client_instance.cc @@ -197,7 +197,7 @@ PJRT_Error *ClientInstance::Compile(const PJRT_Program *program, std::string(program->code, program->code_size), module_builder_->getNumInputs(), module_builder_->getNumOutputs(), - module_builder_->getNumberOfAddressibleDevices()), + module_builder_->getNumAddressableDevices()), addressable_devices_); *out_executable = executable.release(); diff --git a/src/common/pjrt_implementation/loaded_executable_instance.cc b/src/common/pjrt_implementation/loaded_executable_instance.cc index d981622..ae4483a 100644 --- a/src/common/pjrt_implementation/loaded_executable_instance.cc +++ b/src/common/pjrt_implementation/loaded_executable_instance.cc @@ -34,12 +34,12 @@ void LoadedExecutableInstance::BindApi(PJRT_Api *api) { DLOG_F( LOG_DEBUG, "LoadedExecutableInstance::PJRT_LoadedExecutable_AddressableDevices"); - auto &addressable_devices = + const std::vector &addressable_devices = LoadedExecutableInstance::Unwrap(args->executable) ->addressable_devices(); int num_addressable_devices = LoadedExecutableInstance::Unwrap(args->executable) - ->image_->get_num_addresible_devices(); + ->image_->get_num_addressable_devices(); args->addressable_devices = const_cast( reinterpret_cast(addressable_devices.data())); args->num_addressable_devices = num_addressable_devices; @@ -83,7 +83,9 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) { auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc(); + // Sanity check, as we only support execution on one chip currently. 
assert(args->num_devices == 1); + int dev_index = 0; tt::runtime::Binary binary(image_->get_binary()); @@ -96,8 +98,9 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) { BufferInstance *buffer = BufferInstance::Unwrap(args->argument_lists[dev_index][i]); rt_inputs.emplace_back(buffer->tensor()); - device_ids.insert( - chip_ids[buffer->device().device_description()->device_id()]); + int64_t buffer_device_id = + buffer->device().device_description()->device_id(); + device_ids.insert(chip_ids[buffer_device_id]); DLOG_F(INFO, "Runtime input id: %d", buffer->unique_id()); } diff --git a/tests/infra/device_connector.py b/tests/infra/device_connector.py index dcdafe9..6384da0 100644 --- a/tests/infra/device_connector.py +++ b/tests/infra/device_connector.py @@ -69,9 +69,9 @@ def is_initialized(self) -> bool: return False - def connect_tt_device(self, num_device: int = 0) -> jax.Device: + def connect_tt_device(self, device_num: int = 0) -> jax.Device: """Returns TTDevice handle.""" - return self.connect_device(DeviceType.TT, num_device) + return self.connect_device(DeviceType.TT, device_num) def connect_cpu(self) -> jax.Device: """Returns CPUDevice handle.""" @@ -81,16 +81,22 @@ def connect_gpu(self) -> jax.Device: """Returns GPUDevice handle.""" return self.connect_device(DeviceType.GPU) - def _number_of_devices(self, device_type: DeviceType) -> int: - """Returns the number of devices of specifed type.""" - return len(jax.devices(device_type.value)) - def connect_device( - self, device_type: DeviceType, num_device: int = 0 + self, device_type: DeviceType, device_num: int = 0 ) -> jax.Device: - """Returns handle for device identified by `device_type`.""" - assert num_device < self._number_of_devices(device_type) - return jax.devices(device_type.value)[num_device] + """ + Returns handle for device identified by `device_type`. + + If there are multiple available devices of `device_type`, `device_num` makes it + possible to choose between them. 
By default, returns first available device. + """ + assert device_num < self._number_of_devices(device_type) + + return jax.devices(device_type.value)[device_num] + + def _number_of_devices(self, device_type: DeviceType) -> int: + """Returns the number of available devices of specified type.""" + return len(jax.devices(device_type.value)) def _supported_devices(self) -> Sequence[DeviceType]: """Returns list of supported device types.""" diff --git a/tests/infra/device_runner.py b/tests/infra/device_runner.py index efc07d2..28b6134 100644 --- a/tests/infra/device_runner.py +++ b/tests/infra/device_runner.py @@ -19,9 +19,9 @@ class DeviceRunner: """ @staticmethod - def run_on_tt_device(workload: Workload, num_device: int = 0) -> Tensor: + def run_on_tt_device(workload: Workload, device_num: int = 0) -> Tensor: """Runs `workload` on TT device.""" - return DeviceRunner._run_on_device(DeviceType.TT, workload, num_device) + return DeviceRunner._run_on_device(DeviceType.TT, workload, device_num) @staticmethod def run_on_cpu(workload: Workload) -> Tensor: @@ -34,9 +34,9 @@ def run_on_gpu(workload: Workload) -> Tensor: raise NotImplementedError("Support for GPUs not implemented") @staticmethod - def put_on_tt_device(workload: Workload, num_device: int = 0) -> Workload: + def put_on_tt_device(workload: Workload, device_num: int = 0) -> Workload: """Puts `workload` on TT device.""" - return DeviceRunner._put_on_device(DeviceType.TT, workload, num_device) + return DeviceRunner._put_on_device(DeviceType.TT, workload, device_num) @staticmethod def put_on_cpu(workload: Workload) -> Workload: @@ -65,21 +65,21 @@ def put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]: @staticmethod def _run_on_device( - device_type: DeviceType, workload: Workload, num_device: int = 0 + device_type: DeviceType, workload: Workload, device_num: int = 0 ) -> Tensor: """Runs `workload` on device identified by `device_type`.""" - device_workload = DeviceRunner._put_on_device(device_type, workload, 
num_device) - device = device_connector.connect_device(device_type) + device_workload = DeviceRunner._put_on_device(device_type, workload, device_num) + device = device_connector.connect_device(device_type, device_num) with jax.default_device(device): return device_workload.execute() @staticmethod def _put_on_device( - device_type: DeviceType, workload: Workload, num_device: int = 0 + device_type: DeviceType, workload: Workload, device_num: int = 0 ) -> Workload: """Puts `workload` on device and returns it.""" - device = device_connector.connect_device(device_type, num_device) + device = device_connector.connect_device(device_type, device_num) return DeviceRunner._safely_put_workload_on_device(workload, device) @staticmethod