Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enabling non-hardcoded run of our test infrastructure #145

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions inc/common/pjrt_implementation/device_description.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ namespace tt::pjrt {
class DeviceDescription {

public:
// Constructs a description for a new device, assigning it the next unique,
// monotonically increasing device id.
// NOTE(review): static_device_id++ is not synchronized — this assumes device
// descriptions are only created from a single thread; confirm with callers.
// Fix: dropped the redundant semicolon after the function body ({};) which
// triggers -Wextra-semi.
DeviceDescription(int32_t client_id)
    : client_id_(client_id), device_id_(static_device_id++) {}
~DeviceDescription();
operator PJRT_DeviceDescription *() {
return reinterpret_cast<PJRT_DeviceDescription *>(this);
Expand All @@ -41,15 +42,16 @@ class DeviceDescription {
return user_string_;
}

// TODO
int64_t device_id() { return 0; }
int64_t device_id() { return device_id_; }

int client_id() { return client_id_; }

int process_index() { return 0; }

private:
static int static_device_id;
int client_id_;
int device_id_;

// TODO We should understand better how these are used.
// See https://github.com/tenstorrent/tt-xla/issues/125
Expand Down
11 changes: 9 additions & 2 deletions inc/common/pjrt_implementation/executable_image.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ class ExecutableImage {

public:
ExecutableImage(std::shared_ptr<void> binary, std::string code,
size_t arg_count, size_t result_count)
size_t arg_count, size_t result_count,
size_t num_addressable_devices)
: ref_count(1), binary(std::move(binary)), code(code),
arg_count(arg_count), result_count(result_count) {}
arg_count(arg_count), result_count(result_count),
num_addressable_devices(num_addressable_devices) {}
operator PJRT_Executable *() {
return reinterpret_cast<PJRT_Executable *>(this);
}
Expand All @@ -49,6 +51,10 @@ class ExecutableImage {

const std::string &get_code() const { return code; }

const size_t get_num_addressable_devices() const {
return num_addressable_devices;
}

private:
// The reference count. Must be disposed when reaching zero.
std::atomic<int> ref_count;
Expand All @@ -61,6 +67,7 @@ class ExecutableImage {

size_t arg_count;
size_t result_count;
size_t num_addressable_devices;
};

} // namespace tt::pjrt
Expand Down
4 changes: 4 additions & 0 deletions src/common/module_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class ModuleBuilder {

bool isOutputScalar(size_t index) const;

// Returns the number of devices the compiled module can address.
// TODO: derive this from the StableHLO module instead of hardcoding it;
// currently returns 1 because only single-chip execution is supported.
size_t getNumAddressableDevices() const { return 1; }

private:
// Creates VHLO module from the input program code.
mlir::OwningOpRef<mlir::ModuleOp>
Expand Down
7 changes: 3 additions & 4 deletions src/common/pjrt_implementation/client_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,7 @@ void ClientInstance::BindApi(PJRT_Api *api) {
tt_pjrt_status ClientInstance::PopulateDevices() {
DLOG_F(LOG_DEBUG, "ClientInstance::PopulateDevices");
auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc();
int device_info_count_ =
1; // TODO: revert to chip_ids.size(); once
// https://github.com/tenstorrent/tt-xla/issues/9 is fixed
int device_info_count_ = chip_ids.size();

devices_.resize(device_info_count_);
for (size_t i = 0; i < device_info_count_; ++i) {
Expand Down Expand Up @@ -198,7 +196,8 @@ PJRT_Error *ClientInstance::Compile(const PJRT_Program *program,
new ExecutableImage(module_builder_->getBinary(),
std::string(program->code, program->code_size),
module_builder_->getNumInputs(),
module_builder_->getNumOutputs()),
module_builder_->getNumOutputs(),
module_builder_->getNumAddressableDevices()),
addressable_devices_);
*out_executable = executable.release();

Expand Down
2 changes: 2 additions & 0 deletions src/common/pjrt_implementation/device_description.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ namespace tt::pjrt {

DeviceDescription::~DeviceDescription() = default;

// Counter handing out unique, monotonically increasing device ids; each
// DeviceDescription constructor takes the current value and increments it.
int DeviceDescription::static_device_id = 0;

void DeviceDescription::BindApi(PJRT_Api *api) {
DLOG_F(LOG_DEBUG, "DeviceDescription::BindApi");
api->PJRT_DeviceDescription_Id =
Expand Down
26 changes: 21 additions & 5 deletions src/common/pjrt_implementation/loaded_executable_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include "common/pjrt_implementation/loaded_executable_instance.h"

#include <unordered_set>

#include "common/pjrt_implementation/buffer_instance.h"
#include "common/pjrt_implementation/client_instance.h"
#include "common/pjrt_implementation/error_instance.h"
Expand All @@ -32,12 +34,15 @@ void LoadedExecutableInstance::BindApi(PJRT_Api *api) {
DLOG_F(
LOG_DEBUG,
"LoadedExecutableInstance::PJRT_LoadedExecutable_AddressableDevices");
const std::vector<DeviceInstance *> &devices =
const std::vector<DeviceInstance *> &addressable_devices =
LoadedExecutableInstance::Unwrap(args->executable)
->addressable_devices();
int num_addressable_devices =
LoadedExecutableInstance::Unwrap(args->executable)
->image_->get_num_addressable_devices();
args->addressable_devices = const_cast<PJRT_Device **>(
reinterpret_cast<PJRT_Device *const *>(devices.data()));
args->num_addressable_devices = devices.size();
reinterpret_cast<PJRT_Device *const *>(addressable_devices.data()));
args->num_addressable_devices = num_addressable_devices;
return nullptr;
};
api->PJRT_LoadedExecutable_Delete =
Expand Down Expand Up @@ -77,23 +82,34 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) {
DLOG_F(LOG_DEBUG, "LoadedExecutableInstance::Execute");

auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc();
int dev_0 = chip_ids[0];
tt::runtime::Device device = tt::runtime::openDevice({dev_0});

// Sanity check, as we only support execution on one chip currently.
assert(args->num_devices == 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the difference between num_devices and num_addressable_devices?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

num_devices, passed as the argument of this function is the total number of devices for execution. It was inferred from num_addressable_devices in the xla bridge code.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inferred how? Equal to it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, equal to it


int dev_index = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this hardcoded?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still have only one process that executes, which is what this index represents.

tt::runtime::Binary binary(image_->get_binary());

std::vector<tt::runtime::Tensor> rt_inputs;
rt_inputs.reserve(args->num_args);

std::unordered_set<int> device_ids;

for (size_t i = 0; i < args->num_args; ++i) {
BufferInstance *buffer =
BufferInstance::Unwrap(args->argument_lists[dev_index][i]);
rt_inputs.emplace_back(buffer->tensor());
int64_t buffer_device_id =
buffer->device().device_description()->device_id();
device_ids.insert(chip_ids[buffer_device_id]);
DLOG_F(INFO, "Runtime input id: %d", buffer->unique_id());
}

assert(device_ids.size() == 1);
ajakovljevicTT marked this conversation as resolved.
Show resolved Hide resolved

std::vector<int> device_ids_vector(device_ids.begin(), device_ids.end());

tt::runtime::Device device = tt::runtime::openDevice(device_ids_vector);

std::vector<tt::runtime::Tensor> rt_outputs =
tt::runtime::submit(device, binary, 0, rt_inputs);
std::vector<tt::runtime::TensorDesc> output_specs =
Expand Down
23 changes: 18 additions & 5 deletions tests/infra/device_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ def is_initialized(self) -> bool:

return False

def connect_tt_device(self, device_num: int = 0) -> jax.Device:
    """Returns TTDevice handle."""
    tt_device = self.connect_device(DeviceType.TT, device_num)
    return tt_device

def connect_cpu(self) -> jax.Device:
"""Returns CPUDevice handle."""
Expand All @@ -81,9 +81,22 @@ def connect_gpu(self) -> jax.Device:
"""Returns GPUDevice handle."""
return self.connect_device(DeviceType.GPU)

def connect_device(
    self, device_type: DeviceType, device_num: int = 0
) -> jax.Device:
    """
    Returns handle for device identified by `device_type`.

    If there are multiple available devices of `device_type`, `device_num` makes it
    possible to choose between them. By default, returns first available device.
    """
    # Also reject negative indices: a negative `device_num` would otherwise
    # pass the upper-bound check and silently index from the end of the
    # device list (Python negative indexing) instead of failing fast.
    assert 0 <= device_num < self._number_of_devices(device_type)

    return jax.devices(device_type.value)[device_num]

def _number_of_devices(self, device_type: DeviceType) -> int:
    """Returns the number of available devices of specified type."""
    available = jax.devices(device_type.value)
    return len(available)

def _supported_devices(self) -> Sequence[DeviceType]:
"""Returns list of supported device types."""
Expand Down
22 changes: 13 additions & 9 deletions tests/infra/device_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ class DeviceRunner:
"""

@staticmethod
def run_on_tt_device(workload: Workload, device_num: int = 0) -> Tensor:
    """Runs `workload` on TT device."""
    target_type = DeviceType.TT
    return DeviceRunner._run_on_device(target_type, workload, device_num)

@staticmethod
def run_on_cpu(workload: Workload) -> Tensor:
Expand All @@ -34,9 +34,9 @@ def run_on_gpu(workload: Workload) -> Tensor:
raise NotImplementedError("Support for GPUs not implemented")

@staticmethod
def put_on_tt_device(workload: Workload, device_num: int = 0) -> Workload:
    """Puts `workload` on TT device."""
    target_type = DeviceType.TT
    return DeviceRunner._put_on_device(target_type, workload, device_num)

@staticmethod
def put_on_cpu(workload: Workload) -> Workload:
Expand Down Expand Up @@ -64,18 +64,22 @@ def put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]:
raise NotImplementedError("Support for GPUs not implemented")

@staticmethod
def _run_on_device(
    device_type: DeviceType, workload: Workload, device_num: int = 0
) -> Tensor:
    """Runs `workload` on device identified by `device_type`."""
    # Place the workload's tensors on the chosen device first, then execute
    # with that device as JAX's default placement target.
    placed_workload = DeviceRunner._put_on_device(device_type, workload, device_num)
    target_device = device_connector.connect_device(device_type, device_num)

    with jax.default_device(target_device):
        return placed_workload.execute()

@staticmethod
def _put_on_device(
    device_type: DeviceType, workload: Workload, device_num: int = 0
) -> Workload:
    """Puts `workload` on device and returns it."""
    target_device = device_connector.connect_device(device_type, device_num)
    return DeviceRunner._safely_put_workload_on_device(workload, target_device)

@staticmethod
Expand Down
Loading