diff --git a/inc/common/pjrt_implementation/device_description.h b/inc/common/pjrt_implementation/device_description.h index 3139a53..d2a4c38 100644 --- a/inc/common/pjrt_implementation/device_description.h +++ b/inc/common/pjrt_implementation/device_description.h @@ -20,7 +20,8 @@ namespace tt::pjrt { class DeviceDescription { public: - DeviceDescription(int32_t client_id) : client_id_(client_id) {}; + DeviceDescription(int32_t client_id) + : client_id_(client_id), device_id_(static_device_id++) {}; ~DeviceDescription(); operator PJRT_DeviceDescription *() { return reinterpret_cast(this); @@ -42,14 +43,16 @@ class DeviceDescription { } // TODO - int64_t device_id() { return 0; } + int64_t device_id() { return device_id_; } int client_id() { return client_id_; } int process_index() { return 0; } private: + static int static_device_id; int client_id_; + int device_id_; // TODO We should understand better how these are used. // See https://github.com/tenstorrent/tt-xla/issues/125 diff --git a/inc/common/pjrt_implementation/executable_image.h b/inc/common/pjrt_implementation/executable_image.h index 319ca6e..b18cd6f 100644 --- a/inc/common/pjrt_implementation/executable_image.h +++ b/inc/common/pjrt_implementation/executable_image.h @@ -23,9 +23,11 @@ class ExecutableImage { public: ExecutableImage(std::shared_ptr binary, std::string code, - size_t arg_count, size_t result_count) + size_t arg_count, size_t result_count, + size_t num_addressable_devices) : ref_count(1), binary(std::move(binary)), code(code), - arg_count(arg_count), result_count(result_count) {} + arg_count(arg_count), result_count(result_count), + num_addressable_devices(num_addressable_devices) {} operator PJRT_Executable *() { return reinterpret_cast(this); } @@ -49,6 +51,10 @@ class ExecutableImage { const std::string &get_code() const { return code; } + const size_t get_num_addresible_devices() const { + return num_addressable_devices; + } + private: // The reference count. Must be disposed when reaching zero. std::atomic ref_count; @@ -61,6 +67,7 @@ class ExecutableImage { size_t arg_count; size_t result_count; + size_t num_addressable_devices; }; } // namespace tt::pjrt diff --git a/src/common/module_builder.cc b/src/common/module_builder.cc index f39276b..df3c0c5 100644 --- a/src/common/module_builder.cc +++ b/src/common/module_builder.cc @@ -174,6 +174,9 @@ bool ModuleBuilder::isScalarType(mlir::Type type) { return false; } +// Currently hardcoded to one, as we only support one-chip execution. +size_t ModuleBuilder::getNumberOfAddressibleDevices() const { return 1; } + void ModuleBuilder::convertFromSHLOToTTIR( mlir::OwningOpRef &mlir_module) { // Implicit nesting required to call the stablehlo.composite --> func.call diff --git a/src/common/module_builder.h b/src/common/module_builder.h index 9ea1f57..c37fc44 100644 --- a/src/common/module_builder.h +++ b/src/common/module_builder.h @@ -35,6 +35,10 @@ class ModuleBuilder { bool isOutputScalar(size_t index) const; + // This needs to return the number of addressable devices from the StableHLO + // code. + size_t getNumberOfAddressibleDevices() const; + private: // Creates VHLO module from the input program code. mlir::OwningOpRef diff --git a/src/common/pjrt_implementation/client_instance.cc b/src/common/pjrt_implementation/client_instance.cc index 36bb29d..1bc2dee 100644 --- a/src/common/pjrt_implementation/client_instance.cc +++ b/src/common/pjrt_implementation/client_instance.cc @@ -164,9 +164,7 @@ void ClientInstance::BindApi(PJRT_Api *api) { tt_pjrt_status ClientInstance::PopulateDevices() { DLOG_F(LOG_DEBUG, "ClientInstance::PopulateDevices"); auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc(); - int device_info_count_ = - 1; // TODO: revert to chip_ids.size(); once - // https://github.com/tenstorrent/tt-xla/issues/9 is fixed + int device_info_count_ = chip_ids.size(); devices_.resize(device_info_count_); for (size_t i = 0; i < device_info_count_; ++i) { @@ -198,7 +196,8 @@ PJRT_Error *ClientInstance::Compile(const PJRT_Program *program, new ExecutableImage(module_builder_->getBinary(), std::string(program->code, program->code_size), module_builder_->getNumInputs(), - module_builder_->getNumOutputs()), + module_builder_->getNumOutputs(), + module_builder_->getNumberOfAddressibleDevices()), addressable_devices_); *out_executable = executable.release(); diff --git a/src/common/pjrt_implementation/device_description.cc b/src/common/pjrt_implementation/device_description.cc index a2a6296..c7931f4 100644 --- a/src/common/pjrt_implementation/device_description.cc +++ b/src/common/pjrt_implementation/device_description.cc @@ -16,6 +16,8 @@ namespace tt::pjrt { DeviceDescription::~DeviceDescription() = default; +int DeviceDescription::static_device_id = 0; + void DeviceDescription::BindApi(PJRT_Api *api) { DLOG_F(LOG_DEBUG, "DeviceDescription::BindApi"); api->PJRT_DeviceDescription_Id = diff --git a/src/common/pjrt_implementation/loaded_executable_instance.cc b/src/common/pjrt_implementation/loaded_executable_instance.cc index e8a6e72..d981622 100644 --- a/src/common/pjrt_implementation/loaded_executable_instance.cc +++ b/src/common/pjrt_implementation/loaded_executable_instance.cc @@ -10,6 +10,8 @@ #include "common/pjrt_implementation/loaded_executable_instance.h" +#include + #include "common/pjrt_implementation/buffer_instance.h" #include "common/pjrt_implementation/client_instance.h" #include "common/pjrt_implementation/error_instance.h" @@ -32,12 +34,15 @@ void LoadedExecutableInstance::BindApi(PJRT_Api *api) { DLOG_F( LOG_DEBUG, "LoadedExecutableInstance::PJRT_LoadedExecutable_AddressableDevices"); - const std::vector &devices = + auto &addressable_devices = LoadedExecutableInstance::Unwrap(args->executable) ->addressable_devices(); + int num_addressable_devices = + LoadedExecutableInstance::Unwrap(args->executable) + ->image_->get_num_addresible_devices(); args->addressable_devices = const_cast( - reinterpret_cast(devices.data())); - args->num_addressable_devices = devices.size(); + reinterpret_cast(addressable_devices.data())); + args->num_addressable_devices = num_addressable_devices; return nullptr; }; api->PJRT_LoadedExecutable_Delete = @@ -77,8 +82,6 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) { DLOG_F(LOG_DEBUG, "LoadedExecutableInstance::Execute"); auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc(); - int dev_0 = chip_ids[0]; - tt::runtime::Device device = tt::runtime::openDevice({dev_0}); assert(args->num_devices == 1); int dev_index = 0; @@ -87,13 +90,23 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) { std::vector rt_inputs; rt_inputs.reserve(args->num_args); + std::unordered_set device_ids; + for (size_t i = 0; i < args->num_args; ++i) { BufferInstance *buffer = BufferInstance::Unwrap(args->argument_lists[dev_index][i]); rt_inputs.emplace_back(buffer->tensor()); + device_ids.insert( + chip_ids[buffer->device().device_description()->device_id()]); DLOG_F(INFO, "Runtime input id: %d", buffer->unique_id()); } + assert(device_ids.size() == 1); + + std::vector device_ids_vector(device_ids.begin(), device_ids.end()); + + tt::runtime::Device device = tt::runtime::openDevice(device_ids_vector); + std::vector rt_outputs = tt::runtime::submit(device, binary, 0, rt_inputs); std::vector output_specs =