Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enabling non-hardcoded run of our test infrastructure #145

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions inc/common/pjrt_implementation/device_description.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ namespace tt::pjrt {
class DeviceDescription {

public:
DeviceDescription(int32_t client_id) : client_id_(client_id) {};
DeviceDescription(int32_t client_id)
: client_id_(client_id), device_id_(static_device_id++) {};
~DeviceDescription();
operator PJRT_DeviceDescription *() {
return reinterpret_cast<PJRT_DeviceDescription *>(this);
Expand All @@ -42,14 +43,16 @@ class DeviceDescription {
}

// TODO
ajakovljevicTT marked this conversation as resolved.
Show resolved Hide resolved
int64_t device_id() { return 0; }
int64_t device_id() { return device_id_; }

int client_id() { return client_id_; }

int process_index() { return 0; }

private:
static int static_device_id;
int client_id_;
int device_id_;

// TODO We should understand better how these are used.
// See https://github.com/tenstorrent/tt-xla/issues/125
Expand Down
11 changes: 9 additions & 2 deletions inc/common/pjrt_implementation/executable_image.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ class ExecutableImage {

public:
ExecutableImage(std::shared_ptr<void> binary, std::string code,
size_t arg_count, size_t result_count)
size_t arg_count, size_t result_count,
size_t num_addressable_devices)
: ref_count(1), binary(std::move(binary)), code(code),
arg_count(arg_count), result_count(result_count) {}
arg_count(arg_count), result_count(result_count),
num_addressable_devices(num_addressable_devices) {}
operator PJRT_Executable *() {
return reinterpret_cast<PJRT_Executable *>(this);
}
Expand All @@ -49,6 +51,10 @@ class ExecutableImage {

const std::string &get_code() const { return code; }

const size_t get_num_addresible_devices() const {
return num_addressable_devices;
}

private:
// The reference count. Must be disposed when reaching zero.
std::atomic<int> ref_count;
Expand All @@ -61,6 +67,7 @@ class ExecutableImage {

size_t arg_count;
size_t result_count;
size_t num_addressable_devices;
};

} // namespace tt::pjrt
Expand Down
3 changes: 3 additions & 0 deletions src/common/module_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ bool ModuleBuilder::isScalarType(mlir::Type type) {
return false;
}

// Returns the number of addressable devices. Only one-chip execution is
// supported for now, so the value is fixed at one.
size_t ModuleBuilder::getNumberOfAddressibleDevices() const {
  constexpr size_t kSingleChipDeviceCount = 1;
  return kSingleChipDeviceCount;
}

void ModuleBuilder::convertFromSHLOToTTIR(
mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
// Implicit nesting required to call the stablehlo.composite --> func.call
Expand Down
4 changes: 4 additions & 0 deletions src/common/module_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class ModuleBuilder {

bool isOutputScalar(size_t index) const;

// Returns the number of addressable devices. This should eventually be
// derived from the StableHLO code.
size_t getNumberOfAddressibleDevices() const;
ajakovljevicTT marked this conversation as resolved.
Show resolved Hide resolved

private:
// Creates VHLO module from the input program code.
mlir::OwningOpRef<mlir::ModuleOp>
Expand Down
7 changes: 3 additions & 4 deletions src/common/pjrt_implementation/client_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,7 @@ void ClientInstance::BindApi(PJRT_Api *api) {
tt_pjrt_status ClientInstance::PopulateDevices() {
DLOG_F(LOG_DEBUG, "ClientInstance::PopulateDevices");
auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc();
int device_info_count_ =
1; // TODO: revert to chip_ids.size(); once
// https://github.com/tenstorrent/tt-xla/issues/9 is fixed
int device_info_count_ = chip_ids.size();

devices_.resize(device_info_count_);
for (size_t i = 0; i < device_info_count_; ++i) {
Expand Down Expand Up @@ -198,7 +196,8 @@ PJRT_Error *ClientInstance::Compile(const PJRT_Program *program,
new ExecutableImage(module_builder_->getBinary(),
std::string(program->code, program->code_size),
module_builder_->getNumInputs(),
module_builder_->getNumOutputs()),
module_builder_->getNumOutputs(),
module_builder_->getNumberOfAddressibleDevices()),
addressable_devices_);
*out_executable = executable.release();

Expand Down
2 changes: 2 additions & 0 deletions src/common/pjrt_implementation/device_description.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ namespace tt::pjrt {

// Out-of-line defaulted destructor.
DeviceDescription::~DeviceDescription() = default;

// Definition of the class-wide counter declared in the header. The
// DeviceDescription constructor uses the current value as the new device id
// and then increments it, so ids start at 0 and follow construction order.
// NOTE(review): the counter is a plain int and the increment is not
// synchronized - confirm that descriptions are only created from one thread.
int DeviceDescription::static_device_id = 0;

void DeviceDescription::BindApi(PJRT_Api *api) {
DLOG_F(LOG_DEBUG, "DeviceDescription::BindApi");
api->PJRT_DeviceDescription_Id =
Expand Down
23 changes: 18 additions & 5 deletions src/common/pjrt_implementation/loaded_executable_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include "common/pjrt_implementation/loaded_executable_instance.h"

#include <unordered_set>

#include "common/pjrt_implementation/buffer_instance.h"
#include "common/pjrt_implementation/client_instance.h"
#include "common/pjrt_implementation/error_instance.h"
Expand All @@ -32,12 +34,15 @@ void LoadedExecutableInstance::BindApi(PJRT_Api *api) {
DLOG_F(
LOG_DEBUG,
"LoadedExecutableInstance::PJRT_LoadedExecutable_AddressableDevices");
const std::vector<DeviceInstance *> &devices =
auto &addressable_devices =
ajakovljevicTT marked this conversation as resolved.
Show resolved Hide resolved
LoadedExecutableInstance::Unwrap(args->executable)
->addressable_devices();
int num_addressable_devices =
LoadedExecutableInstance::Unwrap(args->executable)
->image_->get_num_addresible_devices();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: `addresible` should be `addressable`.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add assert num_addressable_devices == addressable_devices.size().

args->addressable_devices = const_cast<PJRT_Device **>(
reinterpret_cast<PJRT_Device *const *>(devices.data()));
args->num_addressable_devices = devices.size();
reinterpret_cast<PJRT_Device *const *>(addressable_devices.data()));
args->num_addressable_devices = num_addressable_devices;
return nullptr;
};
api->PJRT_LoadedExecutable_Delete =
Expand Down Expand Up @@ -77,8 +82,6 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) {
DLOG_F(LOG_DEBUG, "LoadedExecutableInstance::Execute");

auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc();
int dev_0 = chip_ids[0];
tt::runtime::Device device = tt::runtime::openDevice({dev_0});

assert(args->num_devices == 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the difference between num_devices and num_addressable_devices?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

num_devices, passed as an argument of this function, is the total number of devices used for execution. It was inferred from num_addressable_devices in the XLA bridge code.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inferred how? Equal to it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, equal to it

int dev_index = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this hardcoded?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still have only one process that executes, which is what this index represents.

Expand All @@ -87,13 +90,23 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) {
std::vector<tt::runtime::Tensor> rt_inputs;
rt_inputs.reserve(args->num_args);

std::unordered_set<int> device_ids;

for (size_t i = 0; i < args->num_args; ++i) {
BufferInstance *buffer =
BufferInstance::Unwrap(args->argument_lists[dev_index][i]);
rt_inputs.emplace_back(buffer->tensor());
device_ids.insert(
chip_ids[buffer->device().device_description()->device_id()]);
ajakovljevicTT marked this conversation as resolved.
Show resolved Hide resolved
DLOG_F(INFO, "Runtime input id: %d", buffer->unique_id());
}

assert(device_ids.size() == 1);
ajakovljevicTT marked this conversation as resolved.
Show resolved Hide resolved

std::vector<int> device_ids_vector(device_ids.begin(), device_ids.end());

tt::runtime::Device device = tt::runtime::openDevice(device_ids_vector);

std::vector<tt::runtime::Tensor> rt_outputs =
tt::runtime::submit(device, binary, 0, rt_inputs);
std::vector<tt::runtime::TensorDesc> output_specs =
Expand Down
15 changes: 11 additions & 4 deletions tests/infra/device_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ def is_initialized(self) -> bool:

return False

def connect_tt_device(self) -> jax.Device:
def connect_tt_device(self, num_device: int = 0) -> jax.Device:
ajakovljevicTT marked this conversation as resolved.
Show resolved Hide resolved
"""Returns TTDevice handle."""
return self.connect_device(DeviceType.TT)
return self.connect_device(DeviceType.TT, num_device)

def connect_cpu(self) -> jax.Device:
"""Returns CPUDevice handle."""
Expand All @@ -81,9 +81,16 @@ def connect_gpu(self) -> jax.Device:
"""Returns GPUDevice handle."""
return self.connect_device(DeviceType.GPU)

def connect_device(self, device_type: DeviceType) -> jax.Device:
def _number_of_devices(self, device_type: DeviceType) -> int:
    """Returns the count of devices JAX reports for the specified type."""
    reported_devices = jax.devices(device_type.value)
    return len(reported_devices)
ajakovljevicTT marked this conversation as resolved.
Show resolved Hide resolved

def connect_device(
self, device_type: DeviceType, num_device: int = 0
) -> jax.Device:
"""Returns handle for device identified by `device_type`."""
ajakovljevicTT marked this conversation as resolved.
Show resolved Hide resolved
return jax.devices(device_type.value)[0]
assert num_device < self._number_of_devices(device_type)
return jax.devices(device_type.value)[num_device]

def _supported_devices(self) -> Sequence[DeviceType]:
"""Returns list of supported device types."""
Expand Down
20 changes: 12 additions & 8 deletions tests/infra/device_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ class DeviceRunner:
"""

@staticmethod
def run_on_tt_device(workload: Workload) -> Tensor:
def run_on_tt_device(workload: Workload, num_device: int = 0) -> Tensor:
"""Runs `workload` on TT device."""
return DeviceRunner._run_on_device(DeviceType.TT, workload)
return DeviceRunner._run_on_device(DeviceType.TT, workload, num_device)

@staticmethod
def run_on_cpu(workload: Workload) -> Tensor:
Expand All @@ -34,9 +34,9 @@ def run_on_gpu(workload: Workload) -> Tensor:
raise NotImplementedError("Support for GPUs not implemented")

@staticmethod
def put_on_tt_device(workload: Workload) -> Workload:
def put_on_tt_device(workload: Workload, num_device: int = 0) -> Workload:
"""Puts `workload` on TT device."""
return DeviceRunner._put_on_device(DeviceType.TT, workload)
return DeviceRunner._put_on_device(DeviceType.TT, workload, num_device)

@staticmethod
def put_on_cpu(workload: Workload) -> Workload:
Expand Down Expand Up @@ -64,18 +64,22 @@ def put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]:
raise NotImplementedError("Support for GPUs not implemented")

@staticmethod
def _run_on_device(device_type: DeviceType, workload: Workload) -> Tensor:
def _run_on_device(
device_type: DeviceType, workload: Workload, num_device: int = 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

) -> Tensor:
"""Runs `workload` on device identified by `device_type`."""
device_workload = DeviceRunner._put_on_device(device_type, workload)
device_workload = DeviceRunner._put_on_device(device_type, workload, num_device)
device = device_connector.connect_device(device_type)
ajakovljevicTT marked this conversation as resolved.
Show resolved Hide resolved

with jax.default_device(device):
return device_workload.execute()

@staticmethod
def _put_on_device(device_type: DeviceType, workload: Workload) -> Workload:
def _put_on_device(
device_type: DeviceType, workload: Workload, num_device: int = 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe rearrange a bit to keep device_type and device_num next to each other.

) -> Workload:
"""Puts `workload` on device and returns it."""
device = device_connector.connect_device(device_type)
device = device_connector.connect_device(device_type, num_device)
return DeviceRunner._safely_put_workload_on_device(workload, device)

@staticmethod
Expand Down
Loading