Merge pull request #8 from tenstorrent/pilkic/device-runtime
[runtime] initial support for running model on device
pilkicTT authored Jul 24, 2024
2 parents c79f5ff + 35b0f5c commit e05b7e2
Showing 31 changed files with 731 additions and 401 deletions.
11 changes: 9 additions & 2 deletions pybuda/csrc/CMakeLists.txt
@@ -38,6 +38,7 @@ add_subdirectory(autograd)
add_subdirectory(shared_utils)
add_subdirectory(backend_api)
add_subdirectory(reportify)
add_subdirectory(runtime)
add_subdirectory(tt_torch_device)

### pybuda_csrc_objs ###
@@ -77,17 +78,23 @@ target_link_libraries(pybuda_csrc PRIVATE
backend_api
reportify
tt_torch_device
runtime
pybuda_csrc_objs

# NOTE: ordering of the libraries will affect the linking
LLVM
MLIR
TTNNTargetFlatbuffer
MLIRTTDialect
MLIRTTIRDialect
MLIRTTNNDialect
MLIRTTIRTransforms
MLIRTTNNTransforms
MLIRTTKernelDialect
MLIRTTMetalDialect
MLIRTTIRTransforms
MLIRTTNNTransforms
MLIRTTIRAnalysis
MLIRTTNNPipelines
TTMLIRTTNNToEmitC
TTRuntime
TTRuntimeTTNN
tt_metal
4 changes: 1 addition & 3 deletions pybuda/csrc/buda_passes.cpp
@@ -193,7 +193,7 @@ std::vector<std::pair<graphlib::NodeId, graphlib::NodeId>> run_post_autograd_gra
}

// ********** Run pre-lowering passes **********
graphlib::Graph* run_lower_to_mlir_passes(graphlib::Graph *graph)
graphlib::Graph* run_pre_lowering_passes(graphlib::Graph *graph)
{
passes::print_graph(graph, "PRE_MLIR");
// Recalculate shapes, and figure out implicit broadcasts that are missing
@@ -227,8 +227,6 @@ graphlib::Graph* run_lower_to_mlir_passes(graphlib::Graph *graph)
fold_tile_broadcast_ops_into_inputs(graph);
fold_tile_broadcast_ops_into_reduce(graph);

std::shared_ptr<void> binary = passes::run_mlir_compiler(graph);

return graph;
}

4 changes: 2 additions & 2 deletions pybuda/csrc/buda_passes.hpp
@@ -49,7 +49,7 @@ std::unique_ptr<graphlib::Graph> run_pre_placer_buda_passes(
bool use_interactive_placer = true,
bool enable_device_tilize = false);

// Pre-lowering passes, last-minute changes before going to buda ops
graphlib::Graph* run_lower_to_mlir_passes(graphlib::Graph *graph);
// Pre-lowering passes, last-minute changes before going to MLIR
graphlib::Graph* run_pre_lowering_passes(graphlib::Graph *graph);

}
3 changes: 2 additions & 1 deletion pybuda/csrc/module.mk
@@ -28,6 +28,7 @@ include pybuda/csrc/autograd/module.mk
include pybuda/csrc/reportify/module.mk
include pybuda/csrc/backend_api/module.mk
include pybuda/csrc/tt_torch_device/module.mk
include pybuda/csrc/runtime/module.mk

PYBUDA_CSRC_LDFLAGS = -Wl,-rpath,\$$ORIGIN/../python_env/lib/$(PYTHON_VERSION)/site-packages/torch/lib -ltorch -ltorch_cpu -lc10 -ltorch_python $(PYTHON_LDFLAGS) -l$(PYTHON_VERSION) $(MLIR_LIB_DIR) $(MLIR_LIBS) $(TT_MLIR_LIBS) $(RUNTIME_LIBS) -lm -lz -lcurses -lxml2 -lflatbuffers

@@ -44,7 +45,7 @@ PYBUDA_THIRD_PARTY_DEPS = $(SUBMODULESDIR)/third_party/pybind11.checkout

-include $(PYBUDA_CSRC_DEPS)

$(PYBUDA_CSRC_LIB): $(PYBUDA_CSRC_OBJS) $(PYBUDA_CSRC_GRAPH_LIB) $(PYBUDA_CSRC_AUTOGRAD) $(PYBUDA_CSRC_PATTERN_MATCHER_LIB) $(PYBUDA_CSRC_BALANCER_LIB) $(PYBUDA_CSRC_PLACER_LIB) $(PYBUDA_CSRC_SCHEDULER_LIB) $(PYBUDA_CSRC_REPORTIFY) $(PYBUDA_CSRC_BACKENDAPI_LIB) $(PYBUDA_CSRC_SHARED_UTILS_LIB) $(PYBUDA_CSRC_PERF_MODEL_LIB) $(PYBUDA_CSRC_TT_TORCH_DEVICE_LIB)
$(PYBUDA_CSRC_LIB): $(PYBUDA_CSRC_OBJS) $(PYBUDA_CSRC_GRAPH_LIB) $(PYBUDA_CSRC_AUTOGRAD) $(PYBUDA_CSRC_PATTERN_MATCHER_LIB) $(PYBUDA_CSRC_BALANCER_LIB) $(PYBUDA_CSRC_PLACER_LIB) $(PYBUDA_CSRC_SCHEDULER_LIB) $(PYBUDA_CSRC_REPORTIFY) $(PYBUDA_CSRC_BACKENDAPI_LIB) $(PYBUDA_CSRC_SHARED_UTILS_LIB) $(PYBUDA_CSRC_PERF_MODEL_LIB) $(PYBUDA_CSRC_TT_TORCH_DEVICE_LIB) $(PYBUDA_CSRC_RUNTIME_LIB)
@mkdir -p $(LIBDIR)
$(CXX) $(PYBUDA_CSRC_CFLAGS) $(CXXFLAGS) $(SHARED_LIB_FLAGS) -L$(TORCH_LIB_DIR) -o $@ $^ $(LDFLAGS) $(PYBUDA_CSRC_LDFLAGS)

5 changes: 5 additions & 0 deletions pybuda/csrc/passes/lower_to_mlir.cpp
@@ -210,6 +210,11 @@ class MLIRGenerator
auto opResult = builder_.create<mlir::tt::ttir::AddOp>(get_pybuda_operation_location(graph, op_node), return_types, inputs, outputs, atributes);
return opResult.getResult(0);
}
else if (op_node->op_name() == "multiply")
{
auto opResult = builder_.create<mlir::tt::ttir::MultiplyOp>(get_pybuda_operation_location(graph, op_node), return_types, inputs, outputs, atributes);
return opResult.getResult(0);
}
else {
log_error("Unsupported operation for lowering from PyBuda to TTIR: {}", op_node->op_name());
throw std::runtime_error("Unsupported operation for lowering from PyBuda to TTIR");
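For reference, a minimal PyTorch-style sketch (not part of this commit) of the kind of graph the new branch handles: an element-wise multiply between two activations, which previously fell through to the "Unsupported operation for lowering from PyBuda to TTIR" error and now maps to ttir::MultiplyOp.

import torch

class MultiplyModule(torch.nn.Module):
    # Element-wise multiply; PyBuda traces this as a "multiply" op node,
    # which the hunk above lowers to mlir::tt::ttir::MultiplyOp.
    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return a * b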
16 changes: 12 additions & 4 deletions pybuda/csrc/passes/mlir_compiler.cpp
@@ -2,6 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0
#include "mlir_compiler.hpp"
#include <memory>
#include "lower_to_mlir.hpp"
#include "mlir_passes.hpp"

@@ -17,15 +18,18 @@
#pragma clang diagnostic pop

// TTMLIR headers
#include "tt/runtime/types.h"
#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TTIR/IR/TTIR.h"
#include "ttmlir/Dialect/TTNN/IR/TTNN.h"
#include "ttmlir/Dialect/TTNN/Transforms/TTNNToSerializedBinary.h"
#include "ttmlir/Target/TTNN/TTNNToFlatbuffer.h"

#include "tt_torch_device/tt_device.hpp"

namespace tt::passes
{
/// Public API for lowering to MLIR, running MLIR passes and generate runtime binary.
std::shared_ptr<void> run_mlir_compiler(tt::graphlib::Graph *graph)
runtime::Binary run_mlir_compiler(tt::graphlib::Graph *graph)
{
// Register all the required dialects.
mlir::DialectRegistry registry;
@@ -34,7 +38,7 @@ namespace tt::passes
mlir::tt::TTDialect, mlir::tt::ttir::TTIRDialect,
mlir::tt::ttnn::TTNNDialect, mlir::arith::ArithDialect,
mlir::func::FuncDialect, mlir::ml_program::MLProgramDialect,
mlir::tensor::TensorDialect, mlir::emitc::EmitCDialect>();
mlir::tensor::TensorDialect>();

// Create a context with all registered dialects.
mlir::MLIRContext context(registry);
@@ -43,17 +47,21 @@

// Generate MLIR from the PyBuda graph.
mlir::OwningOpRef<mlir::ModuleOp> mlir_module = lower_to_mlir(graph, context);
tt::log_info("MLIR module generated successfully.");

// Run MLIR registered passes.
run_mlir_passes(mlir_module);
tt::log_info("MLIR passes run successfully.");

// Generate binary from the MLIR module.
auto binary = mlir::tt::ttnn::emitTTNNAsFlatbuffer(mlir_module);
auto binary = mlir::tt::ttnn::ttnnToFlatbuffer(mlir_module.get());
tt::log_info("Flatbuffer binary generated successfully.");

if (binary == nullptr)
{
throw std::runtime_error("Failed to generate flatbuffer binary.");
}

return binary;
}
}
13 changes: 9 additions & 4 deletions pybuda/csrc/passes/mlir_compiler.hpp
@@ -4,13 +4,18 @@
#pragma once
#include <memory>

namespace tt::graphlib
#include "tt/runtime/types.h"

namespace tt
{
class Graph;
namespace graphlib
{
class Graph;
}
}

namespace tt::passes
{
/// Public API for running MLIR passes and generating binary.
std::shared_ptr<void> run_mlir_compiler(tt::graphlib::Graph *graph);
}
runtime::Binary run_mlir_compiler(tt::graphlib::Graph *graph);
}
46 changes: 37 additions & 9 deletions pybuda/csrc/passes/mlir_passes.cpp
@@ -10,24 +10,52 @@
#include "mlir/IR/BuiltinOps.h"

// TTMLIR headers
#include "ttmlir/Dialect/TTIR/Passes.h"
#include "ttmlir/Dialect/TTNN/Passes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNN.h"
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"
#include "ttmlir/Dialect/TTNN/Transforms/Passes.h"
#include "ttmlir/Dialect/TTNN/Pipelines/Passes.h"
#include "utils/logger.hpp"

namespace tt::passes
{
/// Public API for running MLIR passes and generating binary.
void run_mlir_passes(mlir::OwningOpRef<mlir::ModuleOp> &mlir_module)
{
// Register required passes
mlir::tt::ttir::registerPasses();
mlir::tt::ttnn::registerPasses();
static bool _ = []() {
// Register required passes
mlir::tt::ttir::registerPasses();
mlir::tt::ttnn::registerPasses();

// Register pass pipelines
// This will internally register the pipelines in the MLIR pipeline registry. Then,
// the registry can be used to lookup the pipeline by its name and add it to the pass manager.
mlir::tt::ttnn::registerTTNNPipelines();

return true;
}();
(void)_;

// Create a pass manager.
mlir::PassManager pm(mlir_module.get()->getName());

// Create a pass pipeline
mlir::tt::ttnn::createTTIRToTTNNBackendPipeline(pm);
// Get the pipeline info for the wanted pipeline.
const auto pipelineInfo = mlir::PassPipelineInfo::lookup("ttir-to-ttnn-backend-pipeline");

// This error handler is necessary when adding the pipeline to the pass manager (via PassPipelineInfo).
// It's supposed to be called when there's an error during parsing of the pipeline options.
// However, I think it's wrongly implemented in the MLIR library, so it doesn't get called.
mlir::function_ref<mlir::LogicalResult(const mlir::Twine &)> err_handler = [](const mlir::Twine &location) {
log_error(LogMLIRGenerator, "Error during parsing pipeline options: {}", location.str());
return mlir::failure();
};

// Pipeline options are empty for now.
std::string options{""};

auto result = pipelineInfo->addToPipeline(pm, options, err_handler);
if (mlir::failed(result))
{
throw std::runtime_error("Failed to add the pipeline to the pass manager!");
}

// Run the pass manager.
if (mlir::failed(pm.run(mlir_module.get())))
@@ -37,4 +65,4 @@ namespace tt::passes

mlir_module.get().dump();
}
}
}
9 changes: 8 additions & 1 deletion pybuda/csrc/pybuda_bindings.cpp
@@ -26,12 +26,15 @@ namespace py = pybind11;
#include "passes/move_index_to_mm_weights.hpp"
#include "passes/passes_utils.hpp"
#include "passes/python_bindings.hpp"
#include "passes/mlir_compiler.hpp"
#include "python_bindings_common.hpp"
#include "reportify/reportify.hpp"
#include "runtime/python_bindings.hpp"
#include "shared_utils/sparse_matmul_utils.hpp"
#include "tt_torch_device/python_bindings.hpp"
#include "utils/ordered_associative_containers/ordered_map.hpp"
#include "utils/signal_handlers.hpp"

namespace tt {

PYBIND11_MODULE(_C, m) {
@@ -116,6 +119,9 @@ PYBIND11_MODULE(_C, m) {
py::module_ m_torch_device = m.def_submodule("torch_device", "TT Torch Device");
TorchDeviceModule(m_torch_device);

py::module m_runtime = m.def_submodule("runtime", "Submodule defining runtime functions");
RuntimeModule(m_runtime);

py::enum_<tt::MathFidelity>(m, "MathFidelity")
.value("LoFi", tt::MathFidelity::LoFi)
.value("HiFi2", tt::MathFidelity::HiFi2)
@@ -178,7 +184,8 @@ PYBIND11_MODULE(_C, m) {
py::arg("op_intermediates_to_save") = std::vector<std::string>{},
py::arg("use_interactive_placer") = true,
py::arg("enable_device_tilize") = false);
m.def("run_lower_to_mlir_passes", &run_lower_to_mlir_passes);
m.def("run_pre_lowering_passes", &run_pre_lowering_passes);
m.def("run_mlir_compiler", &passes::run_mlir_compiler);

m.def(
"dump_graph",
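A hedged sketch of how the renamed and newly exposed entry points might be driven from Python. The module path (pybuda._C) and the source of the graph object are assumptions; only the binding names come from the diff above.

from pybuda._C import run_pre_lowering_passes, run_mlir_compiler  # module path assumed

graph = ...  # tt::graphlib::Graph produced by the earlier PyBuda compile passes (assumed)
graph = run_pre_lowering_passes(graph)  # last-minute graph changes before going to MLIR
binary = run_mlir_compiler(graph)       # lower to TTIR/TTNN and emit a flatbuffer runtime::Binary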
4 changes: 4 additions & 0 deletions pybuda/csrc/runtime/CMakeLists.txt
@@ -0,0 +1,4 @@
add_library(runtime STATIC runtime.cpp tt_device.cpp python_bindings.cpp)
add_dependencies(runtime build_tt_mlir)

target_compile_options(runtime PRIVATE ${STATIC_LIB_FLAGS} ${PYBUDA_CSRC_CFLAGS})
19 changes: 19 additions & 0 deletions pybuda/csrc/runtime/python_bindings.cpp
@@ -0,0 +1,19 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "runtime/python_bindings.hpp"
#include "runtime/runtime.hpp"
#include "tt/runtime/types.h"

namespace tt {

void RuntimeModule(py::module &m_runtime)
{
py::class_<runtime::Binary>(m_runtime, "Binary")
.def("get_program_inputs", &runtime::Binary::getProgramInputs)
.def("get_program_outputs", &runtime::Binary::getProgramOutputs);
m_runtime.def("run_binary", tt::run_binary);
}

} // namespace tt
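A hypothetical usage sketch of the new runtime submodule. The argument lists are assumptions: the signatures of run_binary and the Binary accessors live in runtime/runtime.hpp and the tt-mlir runtime, neither of which is shown in this excerpt.

from pybuda._C import runtime  # module path assumed

binary = ...  # runtime::Binary returned by run_mlir_compiler(graph), as in the sketch above
input_descs = binary.get_program_inputs(0)    # program index argument assumed
output_descs = binary.get_program_outputs(0)  # program index argument assumed
outputs = runtime.run_binary(binary, ...)     # exact arguments (program index, input tensors) assumed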
19 changes: 19 additions & 0 deletions pybuda/csrc/runtime/python_bindings.hpp
@@ -0,0 +1,19 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
#include "pybind11/pybind11.h"
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#pragma clang diagnostic pop
namespace py = pybind11;

namespace tt {

void RuntimeModule(py::module &m_runtime);

} // namespace tt