From 112e32aed13a42f805c0c168779c62eed5ac3e5a Mon Sep 17 00:00:00 2001
From: ajakovljevicTT
Date: Fri, 10 Jan 2025 14:24:56 +0000
Subject: [PATCH] Addressed comments

---
 .../pjrt_implementation/device_description.h |  1 -
 .../pjrt_implementation/executable_image.h   |  2 +-
 src/common/module_builder.cc                 |  3 ---
 src/common/module_builder.h                  |  4 +--
 .../pjrt_implementation/client_instance.cc   |  2 +-
 .../loaded_executable_instance.cc            |  9 ++++---
 tests/infra/device_connector.py              | 26 ++++++++++++-------
 tests/infra/device_runner.py                 | 18 ++++++-------
 8 files changed, 34 insertions(+), 31 deletions(-)

diff --git a/inc/common/pjrt_implementation/device_description.h b/inc/common/pjrt_implementation/device_description.h
index d2a4c38..e103410 100644
--- a/inc/common/pjrt_implementation/device_description.h
+++ b/inc/common/pjrt_implementation/device_description.h
@@ -42,7 +42,6 @@ class DeviceDescription {
     return user_string_;
   }
 
-  // TODO
   int64_t device_id() { return device_id_; }
 
   int client_id() { return client_id_; }
diff --git a/inc/common/pjrt_implementation/executable_image.h b/inc/common/pjrt_implementation/executable_image.h
index b18cd6f..abf21b6 100644
--- a/inc/common/pjrt_implementation/executable_image.h
+++ b/inc/common/pjrt_implementation/executable_image.h
@@ -51,7 +51,7 @@ class ExecutableImage {
 
   const std::string &get_code() const { return code; }
 
-  const size_t get_num_addresible_devices() const {
+  const size_t get_num_addressable_devices() const {
     return num_addressable_devices;
   }
 
diff --git a/src/common/module_builder.cc b/src/common/module_builder.cc
index df3c0c5..f39276b 100644
--- a/src/common/module_builder.cc
+++ b/src/common/module_builder.cc
@@ -174,9 +174,6 @@ bool ModuleBuilder::isScalarType(mlir::Type type) {
   return false;
 }
 
-// Currently hardcoded to one, as we only support one-chip execution.
-size_t ModuleBuilder::getNumberOfAddressibleDevices() const { return 1; }
-
 void ModuleBuilder::convertFromSHLOToTTIR(
     mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
   // Implicit nesting required to call the stablehlo.composite --> func.call
diff --git a/src/common/module_builder.h b/src/common/module_builder.h
index c37fc44..605dbc2 100644
--- a/src/common/module_builder.h
+++ b/src/common/module_builder.h
@@ -36,8 +36,8 @@ class ModuleBuilder {
   bool isOutputScalar(size_t index) const;
 
   // This needs to return the number of addressable devices from the StableHLO
-  // code.
-  size_t getNumberOfAddressibleDevices() const;
+  // code. Currently hardcoded to one, as we only support one-chip execution.
+  size_t getNumAddressableDevices() const { return 1; }
 
 private:
   // Creates VHLO module from the input program code.
diff --git a/src/common/pjrt_implementation/client_instance.cc b/src/common/pjrt_implementation/client_instance.cc
index 1bc2dee..d7a2331 100644
--- a/src/common/pjrt_implementation/client_instance.cc
+++ b/src/common/pjrt_implementation/client_instance.cc
@@ -197,7 +197,7 @@ PJRT_Error *ClientInstance::Compile(const PJRT_Program *program,
           std::string(program->code, program->code_size),
           module_builder_->getNumInputs(),
           module_builder_->getNumOutputs(),
-          module_builder_->getNumberOfAddressibleDevices()),
+          module_builder_->getNumAddressableDevices()),
       addressable_devices_);
   *out_executable = executable.release();
 
diff --git a/src/common/pjrt_implementation/loaded_executable_instance.cc b/src/common/pjrt_implementation/loaded_executable_instance.cc
index d981622..4901b84 100644
--- a/src/common/pjrt_implementation/loaded_executable_instance.cc
+++ b/src/common/pjrt_implementation/loaded_executable_instance.cc
@@ -34,12 +34,12 @@ void LoadedExecutableInstance::BindApi(PJRT_Api *api) {
     DLOG_F(
         LOG_DEBUG,
         "LoadedExecutableInstance::PJRT_LoadedExecutable_AddressableDevices");
-    auto &addressable_devices =
+    const std::vector<DeviceInstance *> &addressable_devices =
         LoadedExecutableInstance::Unwrap(args->executable)
             ->addressable_devices();
     int num_addressable_devices =
         LoadedExecutableInstance::Unwrap(args->executable)
-            ->image_->get_num_addresible_devices();
+            ->image_->get_num_addressable_devices();
     args->addressable_devices = const_cast<PJRT_Device **>(
         reinterpret_cast<PJRT_Device *const *>(addressable_devices.data()));
     args->num_addressable_devices = num_addressable_devices;
@@ -96,8 +96,9 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) {
       BufferInstance *buffer =
           BufferInstance::Unwrap(args->argument_lists[dev_index][i]);
       rt_inputs.emplace_back(buffer->tensor());
-      device_ids.insert(
-          chip_ids[buffer->device().device_description()->device_id()]);
+      int64_t buffer_device_id =
+          buffer->device().device_description()->device_id();
+      device_ids.insert(chip_ids[buffer_device_id]);
       DLOG_F(INFO, "Runtime input id: %d", buffer->unique_id());
     }
 
diff --git a/tests/infra/device_connector.py b/tests/infra/device_connector.py
index dcdafe9..6384da0 100644
--- a/tests/infra/device_connector.py
+++ b/tests/infra/device_connector.py
@@ -69,9 +69,9 @@ def is_initialized(self) -> bool:
 
         return False
 
-    def connect_tt_device(self, num_device: int = 0) -> jax.Device:
+    def connect_tt_device(self, device_num: int = 0) -> jax.Device:
         """Returns TTDevice handle."""
-        return self.connect_device(DeviceType.TT, num_device)
+        return self.connect_device(DeviceType.TT, device_num)
 
     def connect_cpu(self) -> jax.Device:
         """Returns CPUDevice handle."""
@@ -81,16 +81,22 @@ def connect_gpu(self) -> jax.Device:
         """Returns GPUDevice handle."""
         return self.connect_device(DeviceType.GPU)
 
-    def _number_of_devices(self, device_type: DeviceType) -> int:
-        """Returns the number of devices of specifed type."""
-        return len(jax.devices(device_type.value))
-
     def connect_device(
-        self, device_type: DeviceType, num_device: int = 0
+        self, device_type: DeviceType, device_num: int = 0
     ) -> jax.Device:
-        """Returns handle for device identified by `device_type`."""
-        assert num_device < self._number_of_devices(device_type)
-        return jax.devices(device_type.value)[num_device]
+        """
+        Returns handle for device identified by `device_type`.
+
+        If there are multiple available devices of `device_type`, `device_num` makes it
+        possible to choose between them. By default, returns first available device.
+ """ + assert device_num < self._number_of_devices(device_type) + + return jax.devices(device_type.value)[device_num] + + def _number_of_devices(self, device_type: DeviceType) -> int: + """Returns the number of available devices of specified type.""" + return len(jax.devices(device_type.value)) def _supported_devices(self) -> Sequence[DeviceType]: """Returns list of supported device types.""" diff --git a/tests/infra/device_runner.py b/tests/infra/device_runner.py index efc07d2..28b6134 100644 --- a/tests/infra/device_runner.py +++ b/tests/infra/device_runner.py @@ -19,9 +19,9 @@ class DeviceRunner: """ @staticmethod - def run_on_tt_device(workload: Workload, num_device: int = 0) -> Tensor: + def run_on_tt_device(workload: Workload, device_num: int = 0) -> Tensor: """Runs `workload` on TT device.""" - return DeviceRunner._run_on_device(DeviceType.TT, workload, num_device) + return DeviceRunner._run_on_device(DeviceType.TT, workload, device_num) @staticmethod def run_on_cpu(workload: Workload) -> Tensor: @@ -34,9 +34,9 @@ def run_on_gpu(workload: Workload) -> Tensor: raise NotImplementedError("Support for GPUs not implemented") @staticmethod - def put_on_tt_device(workload: Workload, num_device: int = 0) -> Workload: + def put_on_tt_device(workload: Workload, device_num: int = 0) -> Workload: """Puts `workload` on TT device.""" - return DeviceRunner._put_on_device(DeviceType.TT, workload, num_device) + return DeviceRunner._put_on_device(DeviceType.TT, workload, device_num) @staticmethod def put_on_cpu(workload: Workload) -> Workload: @@ -65,21 +65,21 @@ def put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]: @staticmethod def _run_on_device( - device_type: DeviceType, workload: Workload, num_device: int = 0 + device_type: DeviceType, workload: Workload, device_num: int = 0 ) -> Tensor: """Runs `workload` on device identified by `device_type`.""" - device_workload = DeviceRunner._put_on_device(device_type, workload, num_device) - device = device_connector.connect_device(device_type) + device_workload = DeviceRunner._put_on_device(device_type, workload, device_num) + device = device_connector.connect_device(device_type, device_num) with jax.default_device(device): return device_workload.execute() @staticmethod def _put_on_device( - device_type: DeviceType, workload: Workload, num_device: int = 0 + device_type: DeviceType, workload: Workload, device_num: int = 0 ) -> Workload: """Puts `workload` on device and returns it.""" - device = device_connector.connect_device(device_type, num_device) + device = device_connector.connect_device(device_type, device_num) return DeviceRunner._safely_put_workload_on_device(workload, device) @staticmethod