Addressed comments
ajakovljevicTT committed Jan 10, 2025
1 parent 418f9d4 commit 112e32a
Showing 8 changed files with 34 additions and 31 deletions.
1 change: 0 additions & 1 deletion inc/common/pjrt_implementation/device_description.h
@@ -42,7 +42,6 @@ class DeviceDescription {
     return user_string_;
   }
 
-  // TODO
   int64_t device_id() { return device_id_; }
 
   int client_id() { return client_id_; }
2 changes: 1 addition & 1 deletion inc/common/pjrt_implementation/executable_image.h
@@ -51,7 +51,7 @@ class ExecutableImage {

   const std::string &get_code() const { return code; }
 
-  const size_t get_num_addresible_devices() const {
+  const size_t get_num_addressable_devices() const {
     return num_addressable_devices;
   }
 
3 changes: 0 additions & 3 deletions src/common/module_builder.cc
@@ -174,9 +174,6 @@ bool ModuleBuilder::isScalarType(mlir::Type type) {
   return false;
 }
 
-// Currently hardcoded to one, as we only support one-chip execution.
-size_t ModuleBuilder::getNumberOfAddressibleDevices() const { return 1; }
-
 void ModuleBuilder::convertFromSHLOToTTIR(
     mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
   // Implicit nesting required to call the stablehlo.composite --> func.call
4 changes: 2 additions & 2 deletions src/common/module_builder.h
@@ -36,8 +36,8 @@ class ModuleBuilder {
   bool isOutputScalar(size_t index) const;
 
   // This needs to return the number of addressable devices from the StableHLO
-  // code.
-  size_t getNumberOfAddressibleDevices() const;
+  // code. Currently hardcoded to one, as we only support one-chip execution.
+  size_t getNumAddressableDevices() const { return 1; }
 
 private:
   // Creates VHLO module from the input program code.
2 changes: 1 addition & 1 deletion src/common/pjrt_implementation/client_instance.cc
@@ -197,7 +197,7 @@ PJRT_Error *ClientInstance::Compile(const PJRT_Program *program,
           std::string(program->code, program->code_size),
           module_builder_->getNumInputs(),
           module_builder_->getNumOutputs(),
-          module_builder_->getNumberOfAddressibleDevices()),
+          module_builder_->getNumAddressableDevices()),
       addressable_devices_);
   *out_executable = executable.release();
 
9 changes: 5 additions & 4 deletions src/common/pjrt_implementation/loaded_executable_instance.cc
@@ -34,12 +34,12 @@ void LoadedExecutableInstance::BindApi(PJRT_Api *api) {
         DLOG_F(
             LOG_DEBUG,
             "LoadedExecutableInstance::PJRT_LoadedExecutable_AddressableDevices");
-        auto &addressable_devices =
+        const std::vector<DeviceInstance *> &addressable_devices =
             LoadedExecutableInstance::Unwrap(args->executable)
                 ->addressable_devices();
         int num_addressable_devices =
             LoadedExecutableInstance::Unwrap(args->executable)
-                ->image_->get_num_addresible_devices();
+                ->image_->get_num_addressable_devices();
         args->addressable_devices = const_cast<PJRT_Device **>(
             reinterpret_cast<PJRT_Device *const *>(addressable_devices.data()));
         args->num_addressable_devices = num_addressable_devices;
@@ -96,8 +96,9 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) {
       BufferInstance *buffer =
           BufferInstance::Unwrap(args->argument_lists[dev_index][i]);
       rt_inputs.emplace_back(buffer->tensor());
-      device_ids.insert(
-          chip_ids[buffer->device().device_description()->device_id()]);
+      int64_t buffer_device_id =
+          buffer->device().device_description()->device_id();
+      device_ids.insert(chip_ids[buffer_device_id]);
       DLOG_F(INFO, "Runtime input id: %d", buffer->unique_id());
     }

26 changes: 16 additions & 10 deletions tests/infra/device_connector.py
@@ -69,9 +69,9 @@ def is_initialized(self) -> bool:

         return False
 
-    def connect_tt_device(self, num_device: int = 0) -> jax.Device:
+    def connect_tt_device(self, device_num: int = 0) -> jax.Device:
         """Returns TTDevice handle."""
-        return self.connect_device(DeviceType.TT, num_device)
+        return self.connect_device(DeviceType.TT, device_num)
 
     def connect_cpu(self) -> jax.Device:
         """Returns CPUDevice handle."""
@@ -81,16 +81,22 @@ def connect_gpu(self) -> jax.Device:
"""Returns GPUDevice handle."""
return self.connect_device(DeviceType.GPU)

def _number_of_devices(self, device_type: DeviceType) -> int:
"""Returns the number of devices of specifed type."""
return len(jax.devices(device_type.value))

def connect_device(
self, device_type: DeviceType, num_device: int = 0
self, device_type: DeviceType, device_num: int = 0
) -> jax.Device:
"""Returns handle for device identified by `device_type`."""
assert num_device < self._number_of_devices(device_type)
return jax.devices(device_type.value)[num_device]
"""
Returns handle for device identified by `device_type`.
If there are multiple available devices of `device_type`, `device_num` makes it
possible to choose between them. By default, returns first available device.
"""
assert device_num < self._number_of_devices(device_type)

return jax.devices(device_type.value)[device_num]

def _number_of_devices(self, device_type: DeviceType) -> int:
"""Returns the number of available devices of specified type."""
return len(jax.devices(device_type.value))

def _supported_devices(self) -> Sequence[DeviceType]:
"""Returns list of supported device types."""
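Aside (not part of this commit): a minimal usage sketch of the connector API after the rename, assuming the module exposes a `device_connector` singleton (as `device_runner.py` below suggests) importable from `tests.infra.device_connector`; the import path and singleton name are assumptions for illustration.

    # Hypothetical usage; import path and singleton name are assumed.
    from tests.infra.device_connector import device_connector

    # Default: connect to the first available TT device.
    tt_device = device_connector.connect_tt_device()

    # With several TT chips available, `device_num` selects among them; the
    # assert inside connect_device rejects an out-of-range index.
    second_tt = device_connector.connect_tt_device(device_num=1)

    # CPU handle via the dedicated helper.
    cpu = device_connector.connect_cpu()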
18 changes: 9 additions & 9 deletions tests/infra/device_runner.py
@@ -19,9 +19,9 @@ class DeviceRunner:
"""

@staticmethod
def run_on_tt_device(workload: Workload, num_device: int = 0) -> Tensor:
def run_on_tt_device(workload: Workload, device_num: int = 0) -> Tensor:
"""Runs `workload` on TT device."""
return DeviceRunner._run_on_device(DeviceType.TT, workload, num_device)
return DeviceRunner._run_on_device(DeviceType.TT, workload, device_num)

@staticmethod
def run_on_cpu(workload: Workload) -> Tensor:
@@ -34,9 +34,9 @@ def run_on_gpu(workload: Workload) -> Tensor:
         raise NotImplementedError("Support for GPUs not implemented")
 
     @staticmethod
-    def put_on_tt_device(workload: Workload, num_device: int = 0) -> Workload:
+    def put_on_tt_device(workload: Workload, device_num: int = 0) -> Workload:
         """Puts `workload` on TT device."""
-        return DeviceRunner._put_on_device(DeviceType.TT, workload, num_device)
+        return DeviceRunner._put_on_device(DeviceType.TT, workload, device_num)
 
     @staticmethod
     def put_on_cpu(workload: Workload) -> Workload:
@@ -65,21 +65,21 @@ def put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]:

     @staticmethod
     def _run_on_device(
-        device_type: DeviceType, workload: Workload, num_device: int = 0
+        device_type: DeviceType, workload: Workload, device_num: int = 0
     ) -> Tensor:
         """Runs `workload` on device identified by `device_type`."""
-        device_workload = DeviceRunner._put_on_device(device_type, workload, num_device)
-        device = device_connector.connect_device(device_type)
+        device_workload = DeviceRunner._put_on_device(device_type, workload, device_num)
+        device = device_connector.connect_device(device_type, device_num)
 
         with jax.default_device(device):
             return device_workload.execute()
 
     @staticmethod
     def _put_on_device(
-        device_type: DeviceType, workload: Workload, num_device: int = 0
+        device_type: DeviceType, workload: Workload, device_num: int = 0
     ) -> Workload:
         """Puts `workload` on device and returns it."""
-        device = device_connector.connect_device(device_type, num_device)
+        device = device_connector.connect_device(device_type, device_num)
         return DeviceRunner._safely_put_workload_on_device(workload, device)
 
     @staticmethod
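Aside (not part of this commit): the `_run_on_device` change is a behavior fix, not only a rename — before it, inputs were placed on device `device_num` while execution still connected to device 0, because `connect_device` was called without the index. A sketch of the fixed flow, assuming `Workload` wraps a callable plus its arguments (suggested by `device_workload.execute()`); the constructor shape and import paths are assumptions.

    import jax.numpy as jnp

    from tests.infra.device_runner import DeviceRunner  # assumed import path
    from tests.infra.workload import Workload           # assumed import path

    # Workload(fn, args) is an assumed constructor shape.
    workload = Workload(lambda a, b: a + b, [jnp.ones(32), jnp.ones(32)])

    # Inputs go to TT device 1 and, after this fix, execution is bound to the
    # same device via jax.default_device.
    result = DeviceRunner.run_on_tt_device(workload, device_num=1)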
