[runtime] initial support for running model on device
- moves device-specific code from `tt_torch_device/` into `runtime/`
- adds a `TTSystem` singleton class that holds information about all
  present devices
- runs the MLIR compiler as a separate compile stage which, at the end,
  generates a flatbuffer binary
- implements a `CompiledModel` class for running inference on a compiled
  model
- `run_binary()` is the function that invokes the tt-mlir runtime (see the
  usage sketch below)

NOTE: with this commit, the following tests are passing:
- pybuda/test/test_api.py
- pybuda/test/mlir/test_ops.py::test_add
- pybuda/test/mlir/test_ops.py::test_multiply
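
A minimal usage sketch of the new compile/run split, based only on the bindings added in this commit (`run_pre_lowering_passes`, `run_mlir_compiler`, and `runtime.run_binary`); the `pybuda._C` import path and the way `graph` is obtained are assumptions for illustration, not part of the commit:

```python
# Hypothetical sketch -- not code from this commit.
import torch
from pybuda._C import run_pre_lowering_passes, run_mlir_compiler, runtime

# `graph` is assumed to be a graphlib.Graph produced by the earlier
# PyBuda compile stages (construction not shown here).
graph = ...  # placeholder

graph = run_pre_lowering_passes(graph)   # last-minute graph changes before MLIR
binary = run_mlir_compiler(graph)        # lower to TTIR, run passes, emit flatbuffer

inputs = [torch.rand(32, 32), torch.rand(32, 32)]
outputs = runtime.run_binary(binary, 0, inputs)  # run program 0 on device
```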
pilkicTT committed Jul 23, 2024
1 parent c79f5ff commit 8fb96a7
Showing 31 changed files with 686 additions and 389 deletions.
2 changes: 2 additions & 0 deletions pybuda/csrc/CMakeLists.txt
@@ -38,6 +38,7 @@ add_subdirectory(autograd)
add_subdirectory(shared_utils)
add_subdirectory(backend_api)
add_subdirectory(reportify)
add_subdirectory(runtime)
add_subdirectory(tt_torch_device)

### pybuda_csrc_objs ###
@@ -77,6 +78,7 @@ target_link_libraries(pybuda_csrc PRIVATE
backend_api
reportify
tt_torch_device
runtime
pybuda_csrc_objs
LLVM
MLIR
4 changes: 1 addition & 3 deletions pybuda/csrc/buda_passes.cpp
@@ -193,7 +193,7 @@ std::vector<std::pair<graphlib::NodeId, graphlib::NodeId>> run_post_autograd_gra
}

// ********** Run pre-lowering passes **********
-graphlib::Graph* run_lower_to_mlir_passes(graphlib::Graph *graph)
+graphlib::Graph* run_pre_lowering_passes(graphlib::Graph *graph)
{
passes::print_graph(graph, "PRE_MLIR");
// Recalculate shapes, and figure out implicit broadcasts that are missing
@@ -227,8 +227,6 @@ graphlib::Graph* run_lower_to_mlir_passes(graphlib::Graph *graph)
fold_tile_broadcast_ops_into_inputs(graph);
fold_tile_broadcast_ops_into_reduce(graph);

-std::shared_ptr<void> binary = passes::run_mlir_compiler(graph);

return graph;
}

4 changes: 2 additions & 2 deletions pybuda/csrc/buda_passes.hpp
@@ -49,7 +49,7 @@ std::unique_ptr<graphlib::Graph> run_pre_placer_buda_passes(
bool use_interactive_placer = true,
bool enable_device_tilize = false);

-// Pre-lowering passes, last-minute changes before going to buda ops
-graphlib::Graph* run_lower_to_mlir_passes(graphlib::Graph *graph);
+// Pre-lowering passes, last-minute changes before going to MLIR
+graphlib::Graph* run_pre_lowering_passes(graphlib::Graph *graph);

}
3 changes: 2 additions & 1 deletion pybuda/csrc/module.mk
@@ -28,6 +28,7 @@ include pybuda/csrc/autograd/module.mk
include pybuda/csrc/reportify/module.mk
include pybuda/csrc/backend_api/module.mk
include pybuda/csrc/tt_torch_device/module.mk
include pybuda/csrc/runtime/module.mk

PYBUDA_CSRC_LDFLAGS = -Wl,-rpath,\$$ORIGIN/../python_env/lib/$(PYTHON_VERSION)/site-packages/torch/lib -ltorch -ltorch_cpu -lc10 -ltorch_python $(PYTHON_LDFLAGS) -l$(PYTHON_VERSION) $(MLIR_LIB_DIR) $(MLIR_LIBS) $(TT_MLIR_LIBS) $(RUNTIME_LIBS) -lm -lz -lcurses -lxml2 -lflatbuffers

@@ -44,7 +45,7 @@ PYBUDA_THIRD_PARTY_DEPS = $(SUBMODULESDIR)/third_party/pybind11.checkout

-include $(PYBUDA_CSRC_DEPS)

-$(PYBUDA_CSRC_LIB): $(PYBUDA_CSRC_OBJS) $(PYBUDA_CSRC_GRAPH_LIB) $(PYBUDA_CSRC_AUTOGRAD) $(PYBUDA_CSRC_PATTERN_MATCHER_LIB) $(PYBUDA_CSRC_BALANCER_LIB) $(PYBUDA_CSRC_PLACER_LIB) $(PYBUDA_CSRC_SCHEDULER_LIB) $(PYBUDA_CSRC_REPORTIFY) $(PYBUDA_CSRC_BACKENDAPI_LIB) $(PYBUDA_CSRC_SHARED_UTILS_LIB) $(PYBUDA_CSRC_PERF_MODEL_LIB) $(PYBUDA_CSRC_TT_TORCH_DEVICE_LIB)
+$(PYBUDA_CSRC_LIB): $(PYBUDA_CSRC_OBJS) $(PYBUDA_CSRC_GRAPH_LIB) $(PYBUDA_CSRC_AUTOGRAD) $(PYBUDA_CSRC_PATTERN_MATCHER_LIB) $(PYBUDA_CSRC_BALANCER_LIB) $(PYBUDA_CSRC_PLACER_LIB) $(PYBUDA_CSRC_SCHEDULER_LIB) $(PYBUDA_CSRC_REPORTIFY) $(PYBUDA_CSRC_BACKENDAPI_LIB) $(PYBUDA_CSRC_SHARED_UTILS_LIB) $(PYBUDA_CSRC_PERF_MODEL_LIB) $(PYBUDA_CSRC_TT_TORCH_DEVICE_LIB) $(PYBUDA_CSRC_RUNTIME_LIB)
@mkdir -p $(LIBDIR)
$(CXX) $(PYBUDA_CSRC_CFLAGS) $(CXXFLAGS) $(SHARED_LIB_FLAGS) -L$(TORCH_LIB_DIR) -o $@ $^ $(LDFLAGS) $(PYBUDA_CSRC_LDFLAGS)

5 changes: 5 additions & 0 deletions pybuda/csrc/passes/lower_to_mlir.cpp
@@ -210,6 +210,11 @@ class MLIRGenerator
auto opResult = builder_.create<mlir::tt::ttir::AddOp>(get_pybuda_operation_location(graph, op_node), return_types, inputs, outputs, atributes);
return opResult.getResult(0);
}
else if (op_node->op_name() == "multiply")
{
auto opResult = builder_.create<mlir::tt::ttir::MultiplyOp>(get_pybuda_operation_location(graph, op_node), return_types, inputs, outputs, atributes);
return opResult.getResult(0);
}
else {
log_error("Unsupported operation for lowering from PyBuda to TTIR: {}", op_node->op_name());
throw std::runtime_error("Unsupported operation for lowering from PyBuda to TTIR");
12 changes: 10 additions & 2 deletions pybuda/csrc/passes/mlir_compiler.cpp
@@ -2,6 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0
#include "mlir_compiler.hpp"
#include <memory>
#include "lower_to_mlir.hpp"
#include "mlir_passes.hpp"

@@ -17,15 +18,18 @@
#pragma clang diagnostic pop

// TTMLIR headers
#include "tt/runtime/types.h"
#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TTIR/IR/TTIR.h"
#include "ttmlir/Dialect/TTNN/IR/TTNN.h"
#include "ttmlir/Dialect/TTNN/Transforms/TTNNToSerializedBinary.h"

#include "tt_torch_device/tt_device.hpp"

namespace tt::passes
{
/// Public API for lowering to MLIR, running MLIR passes, and generating a runtime binary.
-std::shared_ptr<void> run_mlir_compiler(tt::graphlib::Graph *graph)
+runtime::Binary run_mlir_compiler(tt::graphlib::Graph *graph)
{
// Register all the required dialects.
mlir::DialectRegistry registry;
@@ -34,7 +38,7 @@ namespace tt::passes
mlir::tt::TTDialect, mlir::tt::ttir::TTIRDialect,
mlir::tt::ttnn::TTNNDialect, mlir::arith::ArithDialect,
mlir::func::FuncDialect, mlir::ml_program::MLProgramDialect,
-mlir::tensor::TensorDialect, mlir::emitc::EmitCDialect>();
+mlir::tensor::TensorDialect>();

// Create a context with all registered dialects.
mlir::MLIRContext context(registry);
@@ -43,17 +47,21 @@

// Generate MLIR from the PyBuda graph.
mlir::OwningOpRef<mlir::ModuleOp> mlir_module = lower_to_mlir(graph, context);
tt::log_info("MLIR module generated successfully.");

// Run MLIR registered passes.
run_mlir_passes(mlir_module);
tt::log_info("MLIR passes run successfully.");

// Generate binary from the MLIR module.
auto binary = mlir::tt::ttnn::emitTTNNAsFlatbuffer(mlir_module);
tt::log_info("Flatbuffer binary generated successfully.");

if (binary == nullptr)
{
throw std::runtime_error("Failed to generate flatbuffer binary.");
}

return binary;
}
}
13 changes: 9 additions & 4 deletions pybuda/csrc/passes/mlir_compiler.hpp
@@ -4,13 +4,18 @@
#pragma once
#include <memory>

-namespace tt::graphlib
+#include "tt/runtime/types.h"
+
+namespace tt
{
-class Graph;
+namespace graphlib
+{
+class Graph;
+}
}

namespace tt::passes
{
/// Public API for running MLIR passes and generating binary.
-std::shared_ptr<void> run_mlir_compiler(tt::graphlib::Graph *graph);
-}
+runtime::Binary run_mlir_compiler(tt::graphlib::Graph *graph);
+}
2 changes: 1 addition & 1 deletion pybuda/csrc/passes/mlir_passes.cpp
@@ -37,4 +37,4 @@ namespace tt::passes

mlir_module.get().dump();
}
}
}
9 changes: 8 additions & 1 deletion pybuda/csrc/pybuda_bindings.cpp
@@ -26,12 +26,15 @@ namespace py = pybind11;
#include "passes/move_index_to_mm_weights.hpp"
#include "passes/passes_utils.hpp"
#include "passes/python_bindings.hpp"
#include "passes/mlir_compiler.hpp"
#include "python_bindings_common.hpp"
#include "reportify/reportify.hpp"
#include "runtime/python_bindings.hpp"
#include "shared_utils/sparse_matmul_utils.hpp"
#include "tt_torch_device/python_bindings.hpp"
#include "utils/ordered_associative_containers/ordered_map.hpp"
#include "utils/signal_handlers.hpp"

namespace tt {

PYBIND11_MODULE(_C, m) {
@@ -116,6 +119,9 @@ PYBIND11_MODULE(_C, m) {
py::module_ m_torch_device = m.def_submodule("torch_device", "TT Torch Device");
TorchDeviceModule(m_torch_device);

py::module m_runtime = m.def_submodule("runtime", "Submodule defining runtime functions");
RuntimeModule(m_runtime);

py::enum_<tt::MathFidelity>(m, "MathFidelity")
.value("LoFi", tt::MathFidelity::LoFi)
.value("HiFi2", tt::MathFidelity::HiFi2)
@@ -178,7 +184,8 @@ PYBIND11_MODULE(_C, m) {
py::arg("op_intermediates_to_save") = std::vector<std::string>{},
py::arg("use_interactive_placer") = true,
py::arg("enable_device_tilize") = false);
m.def("run_lower_to_mlir_passes", &run_lower_to_mlir_passes);
m.def("run_pre_lowering_passes", &run_pre_lowering_passes);
m.def("run_mlir_compiler", &passes::run_mlir_compiler);

m.def(
"dump_graph",
4 changes: 4 additions & 0 deletions pybuda/csrc/runtime/CMakeLists.txt
@@ -0,0 +1,4 @@
add_library(runtime STATIC runtime.cpp tt_device.cpp python_bindings.cpp)
add_dependencies(runtime build_tt_mlir)

target_compile_options(runtime PRIVATE ${STATIC_LIB_FLAGS} ${PYBUDA_CSRC_CFLAGS})
19 changes: 19 additions & 0 deletions pybuda/csrc/runtime/python_bindings.cpp
@@ -0,0 +1,19 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "runtime/python_bindings.hpp"
#include "runtime/runtime.hpp"
#include "tt/runtime/types.h"

namespace tt {

void RuntimeModule(py::module &m_runtime)
{
py::class_<runtime::Binary>(m_runtime, "Binary")
.def("get_program_inputs", &runtime::Binary::getProgramInputs)
.def("get_program_outputs", &runtime::Binary::getProgramOutputs);
m_runtime.def("run_binary", tt::run_binary);
}

} // namespace tt
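
For reference, a sketch of what this surface might look like from Python, assuming the module is importable as `pybuda._C` and that `get_program_inputs`/`get_program_outputs` take a program index the way `getProgramOutputs(program_idx)` is called in runtime.cpp further down; `binary` and `torch_inputs` are assumed to exist:

```python
# Sketch only (assumed import path); `binary` is a runtime.Binary returned
# by run_mlir_compiler(), `torch_inputs` is a list of torch.Tensor.
from pybuda._C import runtime

input_descs = binary.get_program_inputs(0)    # descriptors for program 0's inputs
output_descs = binary.get_program_outputs(0)  # descriptors used to pre-allocate outputs
outputs = runtime.run_binary(binary, 0, torch_inputs)
```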
19 changes: 19 additions & 0 deletions pybuda/csrc/runtime/python_bindings.hpp
@@ -0,0 +1,19 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
#include "pybind11/pybind11.h"
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#pragma clang diagnostic pop
namespace py = pybind11;

namespace tt {

void RuntimeModule(py::module &m_runtime);

} // namespace tt
142 changes: 142 additions & 0 deletions pybuda/csrc/runtime/runtime.cpp
@@ -0,0 +1,142 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "runtime.hpp"
#include <optional>

#include "tt_device.hpp"
#include "utils/logger.hpp"
#include "tt/runtime/runtime.h"

namespace tt {

static target::DataType torch_scalar_type_to_dt(torch::ScalarType st)
{
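// NOTE: signed torch types (Char, Short, Int, Long) are mapped to unsigned
// runtime types of equal or smaller width, so negative values and 64-bit
// integers will not round-trip exactly through this conversion.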
switch (st)
{
case torch::ScalarType::Byte: return target::DataType::UInt8;
case torch::ScalarType::Char: return target::DataType::UInt8;
case torch::ScalarType::Short: return target::DataType::UInt16;
case torch::ScalarType::Int: return target::DataType::UInt32;
case torch::ScalarType::Long: return target::DataType::UInt32;
case torch::ScalarType::Half: return target::DataType::Float16;
case torch::ScalarType::Float: return target::DataType::Float32;
// case torch::ScalarType::Double:
// case torch::ScalarType::ComplexHalf:
// case torch::ScalarType::ComplexFloat:
// case torch::ScalarType::ComplexDouble:
// case torch::ScalarType::Bool:
case torch::ScalarType::BFloat16: return target::DataType::BFloat16;
default: break;
}

log_fatal(LogTTDevice, "Unhandled dtype {}", st);
}

static torch::ScalarType dt_to_torch_scalar_type(target::DataType df)
{
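// NOTE: the unsigned 16/32-bit runtime types are mapped back to torch's
// signed Short/Int here, presumably because torch does not offer
// well-supported unsigned 16/32-bit dtypes.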
switch (df)
{
case target::DataType::UInt8: return torch::ScalarType::Byte;
case target::DataType::UInt16: return torch::ScalarType::Short;
case target::DataType::UInt32: return torch::ScalarType::Int;
case target::DataType::Float16: return torch::ScalarType::Half;
case target::DataType::Float32: return torch::ScalarType::Float;
case target::DataType::BFloat16: return torch::ScalarType::BFloat16;
default: break;
}

log_fatal(LogTTDevice, "Unhandled dtype {}", df);
}

template <typename T>
std::vector<int64_t> as_vec_int64(std::vector<T> const& vec)
{
std::vector<int64_t> result;
result.reserve(vec.size());
for (auto const& v : vec)
{
result.push_back(v);
}
return result;
}

static runtime::Tensor create_tensor(const torch::Tensor& tensor)
{
auto data = std::shared_ptr<void>(
tensor.data_ptr(),
[tensor](void*) { (void)tensor; } // Capture tensor by value to increase ref count and keep it alive
);

auto shape = std::vector<uint32_t>(tensor.sizes().begin(), tensor.sizes().end());
auto stride = std::vector<uint32_t>(tensor.strides().begin(), tensor.strides().end());

return runtime::createTensor(
data,
shape,
stride,
tensor.element_size(),
torch_scalar_type_to_dt(tensor.scalar_type()));
}

runtime::Binary load_binary_from_file(std::string const& filename)
{
runtime::Binary binary = tt::runtime::Binary::loadFromPath(filename.c_str()).handle;
return binary;
}

std::vector<torch::Tensor> run_binary_from_file(std::string const& filename, int program_idx, std::vector<torch::Tensor> const& inputs)
{
auto binary = load_binary_from_file(filename);

return run_binary(binary, program_idx, inputs);
}

std::vector<torch::Tensor> run_binary(runtime::Binary &binary, int program_idx, std::vector<torch::Tensor> const& inputs)
{
auto& system = TTSystem::get_system();

for (auto &device : system.devices)
{
if (!device->is_open())
{
device->open_device();
}
}

// For now, we only support a single device.
auto& tt_device = system.devices[0];
if (!tt_device->is_open())
{
log_fatal(LogTTDevice, "Failed to open device");
}

auto& device = *tt_device->rt_device;

std::vector<runtime::Tensor> rt_inputs;
for (auto const& input : inputs)
{
rt_inputs.emplace_back(create_tensor(input));
}

std::vector<torch::Tensor> outputs;
std::vector<runtime::Tensor> rt_outputs;
std::vector<runtime::TensorDesc> output_descs = binary.getProgramOutputs(program_idx);
outputs.reserve(output_descs.size());
for (auto const& desc : output_descs)
{
std::vector<std::int64_t> shape = as_vec_int64(desc.shape);
std::vector<std::int64_t> stride = as_vec_int64(desc.stride);

torch::Tensor output = at::empty_strided(shape, stride, dt_to_torch_scalar_type(desc.dataType));
outputs.emplace_back(std::move(output));
rt_outputs.emplace_back(create_tensor(outputs.back()));
}

runtime::Event _ = runtime::submit(device, binary, program_idx, rt_inputs, rt_outputs);
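// The runtime writes results directly into the pre-allocated torch tensors
// (their storage is shared via create_tensor above), so `outputs` can be
// returned to the caller as ordinary torch tensors.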

return outputs;
}

} // namespace tt