From 35b0f5c7cd0e104f1a77f540180b8966dd3c211e Mon Sep 17 00:00:00 2001 From: Predrag Ilkic Date: Fri, 12 Jul 2024 16:19:40 +0200 Subject: [PATCH] [runtime] initial support for running a model on device - moves device-specific code from `tt_torch_device/` into `runtime/` - adds a `TTSystem` class (singleton) for holding all info on present devices - runs the MLIR compiler as a separate compile stage, which generates a flatbuffer binary at the end - implements a `CompiledModel` class for running inference on a compiled model - `run_binary()` is the function that invokes the tt-mlir runtime NOTE: with this commit, the following tests pass: - pybuda/test/test_api.py - pybuda/test/mlir/test_ops.py::test_add - pybuda/test/mlir/test_ops.py::test_multiply --- pybuda/csrc/CMakeLists.txt | 11 +- pybuda/csrc/buda_passes.cpp | 4 +- pybuda/csrc/buda_passes.hpp | 4 +- pybuda/csrc/module.mk | 3 +- pybuda/csrc/passes/lower_to_mlir.cpp | 5 + pybuda/csrc/passes/mlir_compiler.cpp | 16 +- pybuda/csrc/passes/mlir_compiler.hpp | 13 +- pybuda/csrc/passes/mlir_passes.cpp | 46 ++++- pybuda/csrc/pybuda_bindings.cpp | 9 +- pybuda/csrc/runtime/CMakeLists.txt | 4 + pybuda/csrc/runtime/python_bindings.cpp | 19 +++ pybuda/csrc/runtime/python_bindings.hpp | 19 +++ pybuda/csrc/runtime/runtime.cpp | 142 +++++++++++++++ pybuda/csrc/runtime/runtime.hpp | 19 +++ pybuda/csrc/runtime/tt_device.cpp | 64 +++++++ pybuda/csrc/runtime/tt_device.hpp | 83 +++++++++ .../csrc/tt_torch_device/python_bindings.cpp | 4 - .../tt_torch_device/torch_device_impl.cpp | 47 ++--- pybuda/csrc/tt_torch_device/tt_device.cpp | 48 +----- pybuda/csrc/tt_torch_device/tt_device.hpp | 44 +---- pybuda/pybuda/_C/__init__.pyi | 161 +++++++----------- pybuda/pybuda/_C/autograd.pyi | 10 +- pybuda/pybuda/_C/graph.pyi | 99 ++++++----- pybuda/pybuda/_C/torch_device.pyi | 43 ++--- pybuda/pybuda/compile.py | 90 +++++----- pybuda/pybuda/compiled_graph_state.py | 66 +++++-- pybuda/pybuda/config.py | 5 +- pybuda/test/mlir/test_ops.py | 15 +- pybuda/test/test_api.py | 33 +++- third_party/tt-mlir | 2 +- utils/signal_handlers.hpp | 4 +- 31 files changed, 731 insertions(+), 401 deletions(-) create mode 100644 pybuda/csrc/runtime/CMakeLists.txt create mode 100644 pybuda/csrc/runtime/python_bindings.cpp create mode 100644 pybuda/csrc/runtime/python_bindings.hpp create mode 100644 pybuda/csrc/runtime/runtime.cpp create mode 100644 pybuda/csrc/runtime/runtime.hpp create mode 100644 pybuda/csrc/runtime/tt_device.cpp create mode 100644 pybuda/csrc/runtime/tt_device.hpp diff --git a/pybuda/csrc/CMakeLists.txt b/pybuda/csrc/CMakeLists.txt index 7a1859944..1ee9d983e 100644 --- a/pybuda/csrc/CMakeLists.txt +++ b/pybuda/csrc/CMakeLists.txt @@ -38,6 +38,7 @@ add_subdirectory(autograd) add_subdirectory(shared_utils) add_subdirectory(backend_api) add_subdirectory(reportify) +add_subdirectory(runtime) add_subdirectory(tt_torch_device) ### pybuda_csrc_objs ### @@ -77,17 +78,23 @@ target_link_libraries(pybuda_csrc PRIVATE backend_api reportify tt_torch_device + runtime pybuda_csrc_objs + + # NOTE: the ordering of the libraries affects linking LLVM MLIR + TTNNTargetFlatbuffer MLIRTTDialect MLIRTTIRDialect MLIRTTNNDialect - MLIRTTIRTransforms - MLIRTTNNTransforms MLIRTTKernelDialect MLIRTTMetalDialect + MLIRTTIRTransforms + MLIRTTNNTransforms MLIRTTIRAnalysis + MLIRTTNNPipelines + TTMLIRTTNNToEmitC TTRuntime TTRuntimeTTNN tt_metal diff --git a/pybuda/csrc/buda_passes.cpp b/pybuda/csrc/buda_passes.cpp index 93a3fde19..4c5dddf0c 100644 --- a/pybuda/csrc/buda_passes.cpp +++
b/pybuda/csrc/buda_passes.cpp @@ -193,7 +193,7 @@ std::vector> run_post_autograd_gra } // ********** Run pre-lowering passes ********** -graphlib::Graph* run_lower_to_mlir_passes(graphlib::Graph *graph) +graphlib::Graph* run_pre_lowering_passes(graphlib::Graph *graph) { passes::print_graph(graph, "PRE_MLIR"); // Recalculate shapes, and figure out implicit broadcasts that are missing @@ -227,8 +227,6 @@ graphlib::Graph* run_lower_to_mlir_passes(graphlib::Graph *graph) fold_tile_broadcast_ops_into_inputs(graph); fold_tile_broadcast_ops_into_reduce(graph); - std::shared_ptr binary = passes::run_mlir_compiler(graph); - return graph; } diff --git a/pybuda/csrc/buda_passes.hpp b/pybuda/csrc/buda_passes.hpp index 9f995dad7..819769d75 100644 --- a/pybuda/csrc/buda_passes.hpp +++ b/pybuda/csrc/buda_passes.hpp @@ -49,7 +49,7 @@ std::unique_ptr run_pre_placer_buda_passes( bool use_interactive_placer = true, bool enable_device_tilize = false); -// Pre-lowering passes, last-minute changes before going to buda ops -graphlib::Graph* run_lower_to_mlir_passes(graphlib::Graph *graph); +// Pre-lowering passes, last-minute changes before going to MLIR +graphlib::Graph* run_pre_lowering_passes(graphlib::Graph *graph); } diff --git a/pybuda/csrc/module.mk b/pybuda/csrc/module.mk index b1d8d1004..4739f2d51 100644 --- a/pybuda/csrc/module.mk +++ b/pybuda/csrc/module.mk @@ -28,6 +28,7 @@ include pybuda/csrc/autograd/module.mk include pybuda/csrc/reportify/module.mk include pybuda/csrc/backend_api/module.mk include pybuda/csrc/tt_torch_device/module.mk +include pybuda/csrc/runtime/module.mk PYBUDA_CSRC_LDFLAGS = -Wl,-rpath,\$$ORIGIN/../python_env/lib/$(PYTHON_VERSION)/site-packages/torch/lib -ltorch -ltorch_cpu -lc10 -ltorch_python $(PYTHON_LDFLAGS) -l$(PYTHON_VERSION) $(MLIR_LIB_DIR) $(MLIR_LIBS) $(TT_MLIR_LIBS) $(RUNTIME_LIBS) -lm -lz -lcurses -lxml2 -lflatbuffers @@ -44,7 +45,7 @@ PYBUDA_THIRD_PARTY_DEPS = $(SUBMODULESDIR)/third_party/pybind11.checkout -include $(PYBUDA_CSRC_DEPS) -$(PYBUDA_CSRC_LIB): $(PYBUDA_CSRC_OBJS) $(PYBUDA_CSRC_GRAPH_LIB) $(PYBUDA_CSRC_AUTOGRAD) $(PYBUDA_CSRC_PATTERN_MATCHER_LIB) $(PYBUDA_CSRC_BALANCER_LIB) $(PYBUDA_CSRC_PLACER_LIB) $(PYBUDA_CSRC_SCHEDULER_LIB) $(PYBUDA_CSRC_REPORTIFY) $(PYBUDA_CSRC_BACKENDAPI_LIB) $(PYBUDA_CSRC_SHARED_UTILS_LIB) $(PYBUDA_CSRC_PERF_MODEL_LIB) $(PYBUDA_CSRC_TT_TORCH_DEVICE_LIB) +$(PYBUDA_CSRC_LIB): $(PYBUDA_CSRC_OBJS) $(PYBUDA_CSRC_GRAPH_LIB) $(PYBUDA_CSRC_AUTOGRAD) $(PYBUDA_CSRC_PATTERN_MATCHER_LIB) $(PYBUDA_CSRC_BALANCER_LIB) $(PYBUDA_CSRC_PLACER_LIB) $(PYBUDA_CSRC_SCHEDULER_LIB) $(PYBUDA_CSRC_REPORTIFY) $(PYBUDA_CSRC_BACKENDAPI_LIB) $(PYBUDA_CSRC_SHARED_UTILS_LIB) $(PYBUDA_CSRC_PERF_MODEL_LIB) $(PYBUDA_CSRC_TT_TORCH_DEVICE_LIB) $(PYBUDA_CSRC_RUNTIME_LIB) @mkdir -p $(LIBDIR) $(CXX) $(PYBUDA_CSRC_CFLAGS) $(CXXFLAGS) $(SHARED_LIB_FLAGS) -L$(TORCH_LIB_DIR) -o $@ $^ $(LDFLAGS) $(PYBUDA_CSRC_LDFLAGS) diff --git a/pybuda/csrc/passes/lower_to_mlir.cpp b/pybuda/csrc/passes/lower_to_mlir.cpp index cecbb71b3..5317ebeaa 100644 --- a/pybuda/csrc/passes/lower_to_mlir.cpp +++ b/pybuda/csrc/passes/lower_to_mlir.cpp @@ -210,6 +210,11 @@ class MLIRGenerator auto opResult = builder_.create(get_pybuda_operation_location(graph, op_node), return_types, inputs, outputs, atributes); return opResult.getResult(0); } + else if (op_node->op_name() == "multiply") + { + auto opResult = builder_.create(get_pybuda_operation_location(graph, op_node), return_types, inputs, outputs, atributes); + return opResult.getResult(0); + } else { log_error("Unsupported operation for lowering from 
PyBuda to TTIR: {}", op_node->op_name()); throw std::runtime_error("Unsupported operation for lowering from PyBuda to TTIR"); diff --git a/pybuda/csrc/passes/mlir_compiler.cpp b/pybuda/csrc/passes/mlir_compiler.cpp index 2269f6205..bda168640 100644 --- a/pybuda/csrc/passes/mlir_compiler.cpp +++ b/pybuda/csrc/passes/mlir_compiler.cpp @@ -2,6 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 #include "mlir_compiler.hpp" +#include #include "lower_to_mlir.hpp" #include "mlir_passes.hpp" @@ -17,15 +18,18 @@ #pragma clang diagnostic pop // TTMLIR headers +#include "tt/runtime/types.h" #include "ttmlir/Dialect/TT/IR/TT.h" #include "ttmlir/Dialect/TTIR/IR/TTIR.h" #include "ttmlir/Dialect/TTNN/IR/TTNN.h" -#include "ttmlir/Dialect/TTNN/Transforms/TTNNToSerializedBinary.h" +#include "ttmlir/Target/TTNN/TTNNToFlatbuffer.h" + +#include "tt_torch_device/tt_device.hpp" namespace tt::passes { /// Public API for lowering to MLIR, running MLIR passes and generate runtime binary. - std::shared_ptr run_mlir_compiler(tt::graphlib::Graph *graph) + runtime::Binary run_mlir_compiler(tt::graphlib::Graph *graph) { // Register all the required dialects. mlir::DialectRegistry registry; @@ -34,7 +38,7 @@ namespace tt::passes mlir::tt::TTDialect, mlir::tt::ttir::TTIRDialect, mlir::tt::ttnn::TTNNDialect, mlir::arith::ArithDialect, mlir::func::FuncDialect, mlir::ml_program::MLProgramDialect, - mlir::tensor::TensorDialect, mlir::emitc::EmitCDialect>(); + mlir::tensor::TensorDialect>(); // Create a context with all registered dialects. mlir::MLIRContext context(registry); @@ -43,17 +47,21 @@ namespace tt::passes // Generate MLIR from the PyBuda graph. mlir::OwningOpRef mlir_module = lower_to_mlir(graph, context); + tt::log_info("MLIR module generated successfully."); // Run MLIR registered passes. run_mlir_passes(mlir_module); + tt::log_info("MLIR passes run successfully."); // Generate binary from the MLIR module. - auto binary = mlir::tt::ttnn::emitTTNNAsFlatbuffer(mlir_module); + auto binary = mlir::tt::ttnn::ttnnToFlatbuffer(mlir_module.get()); + tt::log_info("Flatbuffer binary generated successfully."); if (binary == nullptr) { throw std::runtime_error("Failed to generate flatbuffer binary."); } + return binary; } } diff --git a/pybuda/csrc/passes/mlir_compiler.hpp b/pybuda/csrc/passes/mlir_compiler.hpp index 345117f66..eed44b24a 100644 --- a/pybuda/csrc/passes/mlir_compiler.hpp +++ b/pybuda/csrc/passes/mlir_compiler.hpp @@ -4,13 +4,18 @@ #pragma once #include -namespace tt::graphlib +#include "tt/runtime/types.h" + +namespace tt { - class Graph; + namespace graphlib + { + class Graph; + } } namespace tt::passes { /// Public API for running MLIR passes and generating binary. - std::shared_ptr run_mlir_compiler(tt::graphlib::Graph *graph); -} \ No newline at end of file + runtime::Binary run_mlir_compiler(tt::graphlib::Graph *graph); +} diff --git a/pybuda/csrc/passes/mlir_passes.cpp b/pybuda/csrc/passes/mlir_passes.cpp index 6890df8c9..851ef3703 100644 --- a/pybuda/csrc/passes/mlir_passes.cpp +++ b/pybuda/csrc/passes/mlir_passes.cpp @@ -10,24 +10,52 @@ #include "mlir/IR/BuiltinOps.h" // TTMLIR headers -#include "ttmlir/Dialect/TTIR/Passes.h" -#include "ttmlir/Dialect/TTNN/Passes.h" -#include "ttmlir/Dialect/TTNN/IR/TTNN.h" +#include "ttmlir/Dialect/TTIR/Transforms/Passes.h" +#include "ttmlir/Dialect/TTNN/Transforms/Passes.h" +#include "ttmlir/Dialect/TTNN/Pipelines/Passes.h" +#include "utils/logger.hpp" namespace tt::passes { /// Public API for running MLIR passes and generating binary. 
void run_mlir_passes(mlir::OwningOpRef &mlir_module) { - // Register required passes - mlir::tt::ttir::registerPasses(); - mlir::tt::ttnn::registerPasses(); + static bool _ = []() { + // Register required passes + mlir::tt::ttir::registerPasses(); + mlir::tt::ttnn::registerPasses(); + + // Register pass pipelines + // This will internally register the pipelines in the MLIR pipeline registry. Then, + // the registry can be used to lookup the pipeline by its name and add it to the pass manager. + mlir::tt::ttnn::registerTTNNPipelines(); + + return true; + }(); + (void)_; // Create a pass manager. mlir::PassManager pm(mlir_module.get()->getName()); - // Create a pass pipeline - mlir::tt::ttnn::createTTIRToTTNNBackendPipeline(pm); + // Get the pipeline info for the wanted pipeline. + const auto pipelineInfo = mlir::PassPipelineInfo::lookup("ttir-to-ttnn-backend-pipeline"); + + // This error handler is necessary when adding the pipeline to the pass manager (via PassPipelineInfo). + // It's supposed to be called when there's an error during parsing of the pipeline options. + // However, I think it's wrongly implemented in the MLIR library, so it doesn't get called. + mlir::function_ref err_handler = [](const mlir::Twine &location) { + log_error(LogMLIRGenerator, "Error during parsing pipeline options: {}", location.str()); + return mlir::failure(); + }; + + // Pipeline options are empty for now. + std::string options{""}; + + auto result = pipelineInfo->addToPipeline(pm, options, err_handler); + if (mlir::failed(result)) + { + throw std::runtime_error("Failed to add the pipeline to the pass manager!"); + } // Run the pass manager. if (mlir::failed(pm.run(mlir_module.get()))) @@ -37,4 +65,4 @@ namespace tt::passes mlir_module.get().dump(); } -} \ No newline at end of file +} diff --git a/pybuda/csrc/pybuda_bindings.cpp b/pybuda/csrc/pybuda_bindings.cpp index 0b6b88f0e..1e5718e59 100644 --- a/pybuda/csrc/pybuda_bindings.cpp +++ b/pybuda/csrc/pybuda_bindings.cpp @@ -26,12 +26,15 @@ namespace py = pybind11; #include "passes/move_index_to_mm_weights.hpp" #include "passes/passes_utils.hpp" #include "passes/python_bindings.hpp" +#include "passes/mlir_compiler.hpp" #include "python_bindings_common.hpp" #include "reportify/reportify.hpp" +#include "runtime/python_bindings.hpp" #include "shared_utils/sparse_matmul_utils.hpp" #include "tt_torch_device/python_bindings.hpp" #include "utils/ordered_associative_containers/ordered_map.hpp" #include "utils/signal_handlers.hpp" + namespace tt { PYBIND11_MODULE(_C, m) { @@ -116,6 +119,9 @@ PYBIND11_MODULE(_C, m) { py::module_ m_torch_device = m.def_submodule("torch_device", "TT Torch Device"); TorchDeviceModule(m_torch_device); + py::module m_runtime = m.def_submodule("runtime", "Submodule defining runtime functions"); + RuntimeModule(m_runtime); + py::enum_(m, "MathFidelity") .value("LoFi", tt::MathFidelity::LoFi) .value("HiFi2", tt::MathFidelity::HiFi2) @@ -178,7 +184,8 @@ PYBIND11_MODULE(_C, m) { py::arg("op_intermediates_to_save") = std::vector{}, py::arg("use_interactive_placer") = true, py::arg("enable_device_tilize") = false); - m.def("run_lower_to_mlir_passes", &run_lower_to_mlir_passes); + m.def("run_pre_lowering_passes", &run_pre_lowering_passes); + m.def("run_mlir_compiler", &passes::run_mlir_compiler); m.def( "dump_graph", diff --git a/pybuda/csrc/runtime/CMakeLists.txt b/pybuda/csrc/runtime/CMakeLists.txt new file mode 100644 index 000000000..ea9c4fe32 --- /dev/null +++ b/pybuda/csrc/runtime/CMakeLists.txt @@ -0,0 +1,4 @@ +add_library(runtime STATIC 
runtime.cpp tt_device.cpp python_bindings.cpp) +add_dependencies(runtime build_tt_mlir) + +target_compile_options(runtime PRIVATE ${STATIC_LIB_FLAGS} ${PYBUDA_CSRC_CFLAGS}) diff --git a/pybuda/csrc/runtime/python_bindings.cpp b/pybuda/csrc/runtime/python_bindings.cpp new file mode 100644 index 000000000..b19c73e57 --- /dev/null +++ b/pybuda/csrc/runtime/python_bindings.cpp @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "runtime/python_bindings.hpp" +#include "runtime/runtime.hpp" +#include "tt/runtime/types.h" + +namespace tt { + +void RuntimeModule(py::module &m_runtime) +{ + py::class_(m_runtime, "Binary") + .def("get_program_inputs", &runtime::Binary::getProgramInputs) + .def("get_program_outputs", &runtime::Binary::getProgramOutputs); + m_runtime.def("run_binary", tt::run_binary); +} + +} // namespace tt diff --git a/pybuda/csrc/runtime/python_bindings.hpp b/pybuda/csrc/runtime/python_bindings.hpp new file mode 100644 index 000000000..ef0fa90e5 --- /dev/null +++ b/pybuda/csrc/runtime/python_bindings.hpp @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" +#include "pybind11/pybind11.h" +#include +#include +#pragma clang diagnostic pop +namespace py = pybind11; + +namespace tt { + +void RuntimeModule(py::module &m_runtime); + +} // namespace tt diff --git a/pybuda/csrc/runtime/runtime.cpp b/pybuda/csrc/runtime/runtime.cpp new file mode 100644 index 000000000..f17fff4b7 --- /dev/null +++ b/pybuda/csrc/runtime/runtime.cpp @@ -0,0 +1,142 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "runtime.hpp" +#include + +#include "tt_device.hpp" +#include "utils/logger.hpp" +#include "tt/runtime/runtime.h" + +namespace tt { + +static target::DataType torch_scalar_type_to_dt(torch::ScalarType st) +{ + switch (st) + { + case torch::ScalarType::Byte: return target::DataType::UInt8; + case torch::ScalarType::Char: return target::DataType::UInt8; + case torch::ScalarType::Short: return target::DataType::UInt16; + case torch::ScalarType::Int: return target::DataType::UInt32; + case torch::ScalarType::Long: return target::DataType::UInt32; + case torch::ScalarType::Half: return target::DataType::Float16; + case torch::ScalarType::Float: return target::DataType::Float32; + // case torch::ScalarType::Double: + // case torch::ScalarType::ComplexHalf: + // case torch::ScalarType::ComplexFloat: + // case torch::ScalarType::ComplexDouble: + // case torch::ScalarType::Bool: + case torch::ScalarType::BFloat16: return target::DataType::BFloat16; + default: break; + } + + log_fatal(LogTTDevice, "Unhandled dtype {}", st); +} + +static torch::ScalarType dt_to_torch_scalar_type(target::DataType df) +{ + switch (df) + { + case target::DataType::UInt8: return torch::ScalarType::Byte; + case target::DataType::UInt16: return torch::ScalarType::Short; + case target::DataType::UInt32: return torch::ScalarType::Int; + case target::DataType::Float16: return torch::ScalarType::Half; + case target::DataType::Float32: return torch::ScalarType::Float; + case target::DataType::BFloat16: return torch::ScalarType::BFloat16; + default: break; + } + + log_fatal(LogTTDevice, "Unhandled dtype {}", df); +} + +template +std::vector as_vec_int64(std::vector const& vec) +{ + std::vector result; + 
result.reserve(vec.size()); + for (auto const& v : vec) + { + result.push_back(v); + } + return result; +} + +static runtime::Tensor create_tensor(const torch::Tensor& tensor) +{ + auto data = std::shared_ptr( + tensor.data_ptr(), + [tensor](void*) { (void)tensor; } // Capture tensor by value to increase ref count and keep it alive + ); + + auto shape = std::vector(tensor.sizes().begin(), tensor.sizes().end()); + auto stride = std::vector(tensor.strides().begin(), tensor.strides().end()); + + return runtime::createTensor( + data, + shape, + stride, + tensor.element_size(), + torch_scalar_type_to_dt(tensor.scalar_type())); +} + +runtime::Binary load_binary_from_file(std::string const& filename) +{ + runtime::Binary binary = tt::runtime::Binary::loadFromPath(filename.c_str()).handle; + return binary; +} + +std::vector run_binary_from_file(std::string const& filename, int program_idx, std::vector const& inputs) +{ + auto binary = load_binary_from_file(filename); + + return run_binary(binary, program_idx, inputs); +} + +std::vector run_binary(runtime::Binary &binary, int program_idx, std::vector const& inputs) +{ + auto& system = TTSystem::get_system(); + + for (auto &device : system.devices) + { + if (!device->is_open()) + { + device->open_device(); + } + } + + // For now, we only support a single device. + auto& tt_device = system.devices[0]; + if (!tt_device->is_open()) + { + log_fatal(LogTTDevice, "Failed to open device"); + } + + auto& device = *tt_device->rt_device; + + std::vector rt_inputs; + for (auto const& input : inputs) + { + rt_inputs.emplace_back(create_tensor(input)); + } + + std::vector outputs; + std::vector rt_outputs; + std::vector output_descs = binary.getProgramOutputs(program_idx); + outputs.reserve(output_descs.size()); + for (auto const& desc : output_descs) + { + std::vector shape = as_vec_int64(desc.shape); + std::vector stride = as_vec_int64(desc.stride); + + torch::Tensor output = at::empty_strided(shape, stride, dt_to_torch_scalar_type(desc.dataType)); + outputs.emplace_back(std::move(output)); + rt_outputs.emplace_back(create_tensor(outputs.back())); + } + + runtime::Event _ = runtime::submit(device, binary, program_idx, rt_inputs, rt_outputs); + + return outputs; +} + +} // namespace tt diff --git a/pybuda/csrc/runtime/runtime.hpp b/pybuda/csrc/runtime/runtime.hpp new file mode 100644 index 000000000..42953cd36 --- /dev/null +++ b/pybuda/csrc/runtime/runtime.hpp @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 +#pragma once + +#include +#include +#include "tt/runtime/types.h" + +namespace tt { + +// Entry point for invoking tt-mlir runtime and running the binary on the device. +std::vector run_binary(runtime::Binary& binary, int program_idx, std::vector const& inputs); + +// Helper function to run the binary from the file - might be useful for testing/debugging. 
+std::vector run_binary_from_file(std::string const& filename, int program_idx, std::vector const& inputs); + +} // namespace tt + diff --git a/pybuda/csrc/runtime/tt_device.cpp b/pybuda/csrc/runtime/tt_device.cpp new file mode 100644 index 000000000..f6a47d919 --- /dev/null +++ b/pybuda/csrc/runtime/tt_device.cpp @@ -0,0 +1,64 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "tt_device.hpp" +#include "utils/assert.hpp" +#include "utils/logger.hpp" + +#include "tt/runtime/runtime.h" + +namespace tt { + +TTSystem detect_available_devices() { + auto [system_desc, chip_ids] = runtime::getCurrentSystemDesc(); + + std::vector> devices; + int logical_device_index = 0; + ARCH arch = ARCH::Invalid; + for (std::uint32_t chip_desc_index : *system_desc->chip_desc_indices()) + { + target::ChipDesc const* chip_desc = system_desc->chip_descs()->Get(chip_desc_index); + target::ChipCapability chip_capabilities = system_desc->chip_capabilities()->Get(logical_device_index); + + bool mmio = bool(chip_capabilities & target::ChipCapability::HostMMIO); + if (not mmio) + { + continue; + } + + switch(chip_desc->arch()) + { + case target::Arch::Grayskull: arch = ARCH::GRAYSKULL; break; + case target::Arch::Wormhole_b0: arch = ARCH::WORMHOLE_B0; break; + case target::Arch::Blackhole: arch = ARCH::BLACKHOLE; break; + default: log_fatal(LogTTDevice, "Unknown chip type {}", chip_desc->arch()); + } + + auto device = std::make_shared(std::nullopt, system_desc, arch, mmio, logical_device_index); + devices.push_back(device); + ++logical_device_index; + } + + return TTSystem{system_desc, chip_ids, devices}; +} + +TTSystem& TTSystem::get_system() { + static TTSystem system = detect_available_devices(); + return system; +} + +void TTDevice::open_device() { + TT_ASSERT(!is_open()); + rt_device = runtime::openDevice({index}); +} + +void TTDevice::close_device() { + TT_ASSERT(is_open()); + runtime::closeDevice(rt_device.value()); + rt_device.reset(); +} + +} // namespace tt diff --git a/pybuda/csrc/runtime/tt_device.hpp b/pybuda/csrc/runtime/tt_device.hpp new file mode 100644 index 000000000..338453022 --- /dev/null +++ b/pybuda/csrc/runtime/tt_device.hpp @@ -0,0 +1,83 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 +#pragma once + +#include + +#include "pybuda/csrc/backend_api/arch_type.hpp" +#include "tt/runtime/types.h" + +namespace tt +{ + +struct TTDevice +{ + std::optional rt_device; + ARCH arch; + bool mmio; + int index; + + // TODO: These don't seem to belong here + std::map> input_runtime_transforms; + std::map>> input_tile_bcast_dims; + std::map> output_runtime_transforms; + std::unordered_map> subgraph_to_tensor_uid_on_device; + + TTDevice( + std::optional rt_device, + runtime::SystemDesc system_desc, + ARCH arch, + bool mmio, + int index) : + rt_device(rt_device), + arch(arch), + mmio(mmio), + index(index) + { + } + + TTDevice(const TTDevice&) = delete; + TTDevice& operator=(const TTDevice&) = delete; + + bool is_open() const + { + return rt_device.has_value(); + } + + void open_device(); + void close_device(); +}; + +struct TTSystem +{ + runtime::SystemDesc system_desc; + std::vector chip_ids; + std::vector> devices; + + TTSystem(const TTSystem&) = delete; + TTSystem& operator=(const TTSystem&) = delete; + + ~TTSystem() + { + close_devices(); + } + + void close_devices() + { + for (auto& device : devices) + { + if (device->is_open()) + { + device->close_device(); + } + } + } + + static 
TTSystem& get_system(); +}; + +TTSystem detect_available_devices(); + +} // namespace tt + diff --git a/pybuda/csrc/tt_torch_device/python_bindings.cpp b/pybuda/csrc/tt_torch_device/python_bindings.cpp index e2d567dc4..bf86a31f7 100644 --- a/pybuda/csrc/tt_torch_device/python_bindings.cpp +++ b/pybuda/csrc/tt_torch_device/python_bindings.cpp @@ -39,11 +39,9 @@ void TorchDeviceModule(py::module &m_torch_device) py::class_tt_device (m_torch_device, "TTDevice"); tt_device.def_readonly("arch", &tt::TTDevice::arch) .def_readonly("mmio", &tt::TTDevice::mmio) - .def_readonly("chip_ids", &tt::TTDevice::chip_ids) .def_readonly("input_runtime_transforms", &tt::TTDevice::input_runtime_transforms) .def_readonly("input_tile_bcast_dims", &tt::TTDevice::input_tile_bcast_dims) .def_readonly("output_runtime_transforms", &tt::TTDevice::output_runtime_transforms) - .def_readonly("system_desc", &tt::TTDevice::system_desc) .def_property_readonly("cluster_yaml", &tt::get_device_cluster_yaml) .def("torch_device", &tt::torch_device) .def("str", &tt::to_string) @@ -55,6 +53,4 @@ void TorchDeviceModule(py::module &m_torch_device) m_torch_device.def("unique_id", tt::unique_id); } - - } diff --git a/pybuda/csrc/tt_torch_device/torch_device_impl.cpp b/pybuda/csrc/tt_torch_device/torch_device_impl.cpp index b38456f50..f926097a3 100644 --- a/pybuda/csrc/tt_torch_device/torch_device_impl.cpp +++ b/pybuda/csrc/tt_torch_device/torch_device_impl.cpp @@ -29,12 +29,7 @@ constexpr inline c10::DispatchKey DispatchKeyTT = c10::DispatchKey::PrivateUse1; class TorchDeviceImpl final : public c10::impl::DeviceGuardImplInterface { public: - TorchDeviceImpl(std::vector const& tt_devices) : tt_devices(tt_devices) {} - - // TODO: check if this is ok... not sure if we should open devices in this class... 
- ~TorchDeviceImpl() override { - close_devices(); - } + TorchDeviceImpl(const TTSystem& system) : tt_devices(system.devices) {} // Torch overrides virtual c10::DeviceType type() const override { return TT; } @@ -71,7 +66,7 @@ class TorchDeviceImpl final : public c10::impl::DeviceGuardImplInterface // TT specific static TorchDeviceImpl& get() { - static TorchDeviceImpl tt_device_impl(query_available_tt_devices()); + static TorchDeviceImpl tt_device_impl(TTSystem::get_system()); return tt_device_impl; } @@ -79,27 +74,19 @@ class TorchDeviceImpl final : public c10::impl::DeviceGuardImplInterface int get_next_unique_id() { return next_id++; } - TTDevice getTTDevice() const + std::shared_ptr getTTDevice() const { TT_ASSERT(current_device.index() < (int)tt_devices.size()); return tt_devices[current_device.index()]; } - const TTDevice& getDefaultTTDevice() const + const std::shared_ptr& getDefaultTTDevice() const { TT_ASSERT(not tt_devices.empty()); return tt_devices.front(); } - void close_devices() - { - for (auto& tt_device : tt_devices) - { - runtime::closeDevice(tt_device.rt_device); - } - } - - std::vector getTTDevices() const { return tt_devices; } + std::vector> getTTDevices() const { return tt_devices; } std::map registered_output_transforms; std::vector ordered_input_trasforms; @@ -108,15 +95,15 @@ class TorchDeviceImpl final : public c10::impl::DeviceGuardImplInterface private: mutable c10::Device current_device = c10::Device(TT, 0); mutable c10::Stream current_stream = c10::Stream(c10::Stream::UNSAFE, c10::Device(TT, 0), 0); - std::vector tt_devices; + std::vector> tt_devices; int next_id = 0; }; // register backend c10::impl::DeviceGuardImplRegistrar tt_device_impl_reg(TT, &TorchDeviceImpl::get()); -const TTDevice& get_default_tt_device() { return TorchDeviceImpl::get().getDefaultTTDevice();} -std::vector get_available_tt_devices() { return TorchDeviceImpl::get().getTTDevices(); } +const std::shared_ptr& get_default_tt_device() { return TorchDeviceImpl::get().getDefaultTTDevice();} +std::vector> get_available_tt_devices() { return TorchDeviceImpl::get().getTTDevices(); } struct Mallocator final : public c10::Allocator { @@ -546,16 +533,16 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) if (ops_registered) return; ops_registered = true; - m.impl("aten::empty.memory_format", &tt::empty); - m.impl("aten::empty_strided", &tt::empty_strided); - m.impl("aten::_copy_from", &tt::_copy_from); + // m.impl("aten::empty.memory_format", &tt::empty); + // m.impl("aten::empty_strided", &tt::empty_strided); + // m.impl("aten::_copy_from", &tt::_copy_from); // m.impl("aten::_to_copy", &tt::_to_copy); // m.impl("aten::to", &tt::to); - m.impl("aten::_copy_from_and_resize", &tt::_copy_from_and_resize); - m.impl("aten::_reshape_alias", &tt::_reshape_alias); + // m.impl("aten::_copy_from_and_resize", &tt::_copy_from_and_resize); + // m.impl("aten::_reshape_alias", &tt::_reshape_alias); // m.impl("aten::as_strided", &tt::as_strided); - m.impl("aten::index.Tensor_out", &tt::index_outf); - m.impl("aten::view", &tt::view); + // m.impl("aten::index.Tensor_out", &tt::index_outf); + // m.impl("aten::view", &tt::view); } bool fallback_registered = false; @@ -567,7 +554,3 @@ TORCH_LIBRARY_IMPL(_, PrivateUse1, m) m.fallback(torch::CppFunction::makeFromBoxedFunction<&tt::fallback>()); } -void tt::close_devices() -{ - tt::TorchDeviceImpl::get().close_devices(); -} diff --git a/pybuda/csrc/tt_torch_device/tt_device.cpp b/pybuda/csrc/tt_torch_device/tt_device.cpp index 002a2ac1d..72a87c20e 100644 --- 
a/pybuda/csrc/tt_torch_device/tt_device.cpp +++ b/pybuda/csrc/tt_torch_device/tt_device.cpp @@ -11,6 +11,7 @@ #include "pybuda/csrc/lower_to_buda/common.hpp" #include "tt/runtime/runtime.h" +#include "tt/runtime/types.h" #include "utils/assert.hpp" #include "utils/env.hpp" #include "utils/logger.hpp" @@ -191,7 +192,12 @@ std::vector dispatch( rt_outputs.emplace_back(create_tensor(outputs.back())); } - runtime::Event event = runtime::submit(device.rt_device, binary, program_idx, rt_inputs, rt_outputs); + if (!device.rt_device.has_value()) + { + device.open_device(); + } + + runtime::Event event = runtime::submit(device.rt_device.value(), binary, program_idx, rt_inputs, rt_outputs); (void)event; // Clear old tensor uids and update with new ones @@ -240,45 +246,6 @@ std::vector dispatch( return outputs; } -std::vector query_available_tt_devices() -{ - static std::shared_ptr context = std::make_shared(); - std::vector d; - - auto [system_desc, device_ids] = runtime::getCurrentSystemDesc(); - - int logical_device_index = 0; - ARCH arch = ARCH::Invalid; - for (std::uint32_t chip_desc_index : *system_desc->chip_desc_indices()) - { - target::ChipDesc const* chip_desc = system_desc->chip_descs()->Get(chip_desc_index); - target::ChipCapability chip_capabilities = system_desc->chip_capabilities()->Get(logical_device_index); - bool mmio = bool(chip_capabilities & target::ChipCapability::HostMMIO); - if (not mmio) - { - continue; - } - switch(chip_desc->arch()) - { - case target::Arch::Grayskull: arch = ARCH::GRAYSKULL; break; - case target::Arch::Wormhole_b0: arch = ARCH::WORMHOLE_B0; break; - case target::Arch::Blackhole: arch = ARCH::BLACKHOLE; break; - default: log_fatal(LogTTDevice, "Unknown chip type {}", chip_desc->arch()); - } - ++logical_device_index; - } - - if (arch == ARCH::Invalid) - log_fatal(LogTTDevice, "No available devices detected (To run with golden device, set PYBUDA_DEVMODE=1)"); - - runtime::Device rt_device = runtime::openDevice(device_ids); - d.emplace_back(rt_device, system_desc, device_ids, arch, true, 0, context); - - log_debug(LogTTDevice, "Available devices:"); - for (int i = 0; i < (int)d.size(); ++i) log_debug(LogTTDevice, " [{}] {}", i, d[i].arch); - return d; -} - std::string get_device_cluster_yaml(TTDevice const&) { return "";} //TODO } std::string to_string(TTDevice const& d) @@ -339,6 +306,7 @@ std::vector original_shape(const torch::Tensor& tensor) return shape; } + int unique_id(const torch::Tensor& tensor) { auto impl = tensor.unsafeGetTensorImpl(); diff --git a/pybuda/csrc/tt_torch_device/tt_device.hpp b/pybuda/csrc/tt_torch_device/tt_device.hpp index 8b1011869..e638b908b 100644 --- a/pybuda/csrc/tt_torch_device/tt_device.hpp +++ b/pybuda/csrc/tt_torch_device/tt_device.hpp @@ -16,6 +16,7 @@ #include #include "pybuda/csrc/backend_api/arch_type.hpp" +#include "runtime/tt_device.hpp" #include "tt/runtime/types.h" #include "utils/assert.hpp" #include "utils/env.hpp" @@ -88,48 +89,13 @@ struct TTContext using Fence = std::uint64_t; using ResourceID = std::uint64_t; -// 1to1 mapping of physical devices plugged into this machine and TTDevice -struct TTDevice -{ - runtime::Device rt_device; - runtime::SystemDesc system_desc; - std::vector chip_ids; - ARCH arch; - bool mmio; - int index; - std::shared_ptr context; - std::map> input_runtime_transforms; - std::map>> input_tile_bcast_dims; - std::map> output_runtime_transforms; - bool initialized = false; - std::unordered_map> subgraph_to_tensor_uid_on_device; - - TTDevice( - runtime::Device rt_device, - 
runtime::SystemDesc system_desc, - std::vector chip_ids, - ARCH arch, - bool mmio, - int index, - std::shared_ptr context) : - rt_device(rt_device), - system_desc(system_desc), - chip_ids(chip_ids), - arch(arch), - mmio(mmio), - index(index), - context(context) - { - } -}; - using FreePytorchTensorDescFn = void(void*); void register_output_runtime_transform(torch::Tensor const& tensor, std::string transform); void register__ordered_input_runtime_transforms(std::vector input_transforms); std::string get_runtime_transform(torch::Tensor const& tensor); std::vector query_available_tt_devices(); -const TTDevice& get_default_tt_device(); -std::vector get_available_tt_devices(); +const std::shared_ptr& get_default_tt_device(); +std::vector> get_available_tt_devices(); std::string device_type_name(c10::DeviceType type, bool lower_case = false); torch::Device torch_device_at_index(std::int64_t index); torch::Tensor empty_strided( @@ -156,6 +122,8 @@ int unique_id(const torch::Tensor& tensor); torch::Tensor narrow_to_pytorch(const torch::Tensor& tensor, torch::IntArrayRef original_shape); std::vector original_shape(const torch::Tensor& tensor); +std::shared_ptr load_binary_from_file(std::string const& filename); + template inline T align_up_tile(T d) { @@ -163,6 +131,4 @@ inline T align_up_tile(T d) return static_cast(d - (d % kTileDim) + kTileDim); } -void close_devices(); - } // namespace tt diff --git a/pybuda/pybuda/_C/__init__.pyi b/pybuda/pybuda/_C/__init__.pyi index 39d84ff84..cbddcae24 100644 --- a/pybuda/pybuda/_C/__init__.pyi +++ b/pybuda/pybuda/_C/__init__.pyi @@ -1,6 +1,7 @@ -from . import autograd as autograd, backend_api as backend_api, balancer as balancer, graph as graph, pattern_matcher as pattern_matcher, scheduler as scheduler, torch_device as torch_device -from typing import ClassVar, Dict, List, Optional, Tuple, Union +from . import autograd as autograd, graph as graph, torch_device as torch_device +from typing import ClassVar +BLACKHOLE: Arch Backward: NodeEpochType Bfp2: DataFormat Bfp2_b: DataFormat @@ -12,11 +13,14 @@ Float16: DataFormat Float16_b: DataFormat Float32: DataFormat Forward: NodeEpochType +GRAYSKULL: Arch HiFi2: MathFidelity HiFi3: MathFidelity HiFi4: MathFidelity +Int32: DataFormat Int8: DataFormat Invalid: MathFidelity +JAWBRIDGE: Arch Lf8: DataFormat LoFi: MathFidelity Optimizer: NodeEpochType @@ -25,48 +29,54 @@ RawUInt32: DataFormat RawUInt8: DataFormat UInt16: DataFormat VERSION: int +WORMHOLE: Arch +WORMHOLE_B0: Arch k_dim: int class AMPNodeProperties: - def __init__(self, op_type: Optional[str] = ..., epoch_type: Optional[NodeEpochType] = ..., output_df: Optional[DataFormat] = ..., intermediate_df: Optional[DataFormat] = ..., accumulate_df: Optional[DataFormat] = ..., math_fidelity: Optional[MathFidelity] = ..., name_regex_match: Optional[str] = ..., input_df: Optional[Union[Dict[int,Tuple[DataFormat,bool]],DataFormat]] = ..., is_gradient_op: Optional[bool] = ..., input_parameter_indices_to_optimize: Optional[List[Tuple[int,int]]] = ...) -> None: ... + def __init__(self, op_type: str | None = ..., epoch_type: NodeEpochType | None = ..., output_df: DataFormat | None = ..., intermediate_df: DataFormat | None = ..., accumulate_df: DataFormat | None = ..., math_fidelity: MathFidelity | None = ..., name_regex_match: str | None = ..., input_df: dict[int, tuple[DataFormat, bool]] | DataFormat | None | None = ..., is_gradient_op: bool | None = ..., input_parameter_indices_to_optimize: list[tuple[int, int]] | None = ...) -> None: ... 
def from_json(self) -> AMPNodeProperties: ... def to_json(self) -> json: ... - def __getstate__(self) -> tuple: ... - def __setstate__(self, arg0: tuple) -> None: ... @property - def accumulate_df(self) -> Optional[DataFormat]: ... + def accumulate_df(self) -> DataFormat | None: ... @property - def epoch_type(self) -> Optional[NodeEpochType]: ... + def epoch_type(self) -> NodeEpochType | None: ... @property - def input_df(self) -> Optional[Union[Dict[int,Tuple[DataFormat,bool]],DataFormat]]: ... + def input_df(self) -> dict[int, tuple[DataFormat, bool]] | DataFormat | None | None: ... @property - def input_parameter_indices_to_optimize(self) -> Optional[List[Tuple[int,int]]]: ... + def input_parameter_indices_to_optimize(self) -> list[tuple[int, int]] | None: ... @property - def intermediate_df(self) -> Optional[DataFormat]: ... + def intermediate_df(self) -> DataFormat | None: ... @property - def is_gradient_op(self) -> Optional[bool]: ... + def is_gradient_op(self) -> bool | None: ... @property - def math_fidelity(self) -> Optional[MathFidelity]: ... + def math_fidelity(self) -> MathFidelity | None: ... @property - def name_regex_match(self) -> Optional[str]: ... + def name_regex_match(self) -> str | None: ... @property - def op_type(self) -> Optional[str]: ... + def op_type(self) -> str | None: ... @property - def output_df(self) -> Optional[DataFormat]: ... + def output_df(self) -> DataFormat | None: ... -class Block: - def __init__(self) -> None: ... - -class Blocks: - def __init__(self) -> None: ... - -class BudaNetlist: - def __init__(self) -> None: ... - def append_comment(self, arg0: str) -> None: ... - def dump_to_yaml(self) -> str: ... - -class BudaNetlistConfig: - def __init__(self) -> None: ... +class Arch: + __members__: ClassVar[dict] = ... # read-only + BLACKHOLE: ClassVar[Arch] = ... + GRAYSKULL: ClassVar[Arch] = ... + Invalid: ClassVar[Arch] = ... + JAWBRIDGE: ClassVar[Arch] = ... + WORMHOLE: ClassVar[Arch] = ... + WORMHOLE_B0: ClassVar[Arch] = ... + __entries: ClassVar[dict] = ... + def __init__(self, value: int) -> None: ... + def __eq__(self, other: object) -> bool: ... + def __hash__(self) -> int: ... + def __index__(self) -> int: ... + def __int__(self) -> int: ... + def __ne__(self, other: object) -> bool: ... + @property + def name(self) -> str: ... + @property + def value(self) -> int: ... class DataFormat: __members__: ClassVar[dict] = ... # read-only @@ -79,6 +89,7 @@ class DataFormat: Float16: ClassVar[DataFormat] = ... Float16_b: ClassVar[DataFormat] = ... Float32: ClassVar[DataFormat] = ... + Int32: ClassVar[DataFormat] = ... Int8: ClassVar[DataFormat] = ... Invalid: ClassVar[DataFormat] = ... Lf8: ClassVar[DataFormat] = ... @@ -91,29 +102,15 @@ class DataFormat: def from_json(self) -> DataFormat: ... def to_json(self) -> str: ... def __eq__(self, other: object) -> bool: ... - def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: int) -> None: ... @property def name(self) -> str: ... @property def value(self) -> int: ... -class DramQueueConfigOverride: - def __init__(self, arg0: Optional[int], arg1: Optional[int]) -> None: ... - def from_json(self) -> DramQueueConfigOverride: ... - def to_json(self) -> Dict[str,Optional[int]]: ... - def __getstate__(self) -> tuple: ... - def __setstate__(self, arg0: tuple) -> None: ... 
- -class InsertionInstruction: - def __init__(self, src: str, dest: str, hoist_tms: bool, input_id: Optional[int] = ..., fork_id: Optional[int] = ..., user_defined: bool = ...) -> None: ... - def insert(self, arg0: graph.Graph) -> None: ... - def unique_id(self) -> Tuple[str,str,int,int,bool]: ... - class MathFidelity: __members__: ClassVar[dict] = ... # read-only HiFi2: ClassVar[MathFidelity] = ... @@ -126,12 +123,10 @@ class MathFidelity: def from_json(self) -> MathFidelity: ... def to_json(self) -> str: ... def __eq__(self, other: object) -> bool: ... - def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: int) -> None: ... @property def name(self) -> str: ... @property @@ -145,87 +140,53 @@ class NodeEpochType: __entries: ClassVar[dict] = ... def __init__(self, value: int) -> None: ... def __eq__(self, other: object) -> bool: ... - def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: int) -> None: ... @property def name(self) -> str: ... @property def value(self) -> int: ... -class NopInsertionInstruction(InsertionInstruction): - def __init__(self, src: str, dest: str, hoist_tms: bool, nop_count: int = ..., input_id: Optional[int] = ..., fork_id: Optional[int] = ..., user_defined: bool = ..., mergeable: bool = ..., daisy_chain: bool = ..., request_merge: bool = ...) -> None: ... - def from_json(self) -> NopInsertionInstruction: ... - def to_json(self) -> Dict[str,Union[str,bool,int,Optional[int]]]: ... - def unique_id(self) -> Tuple[str,str,int,int,bool]: ... - def __getstate__(self) -> tuple: ... - def __setstate__(self, arg0: tuple) -> None: ... - -class PostPlacerConfig: - def __init__(self, device_config: backend_api.DeviceConfig, microbatch_size: int, microbatch_count: int, enable_t_streaming: bool, input_queues_on_host: bool, output_queues_on_host: bool, manual_dram_queue_placement: Dict[str,DramQueueConfigOverride], fork_join_tiles_treshold: int, output_queue_multiplier: int, input_queue_multiplier: int, enable_cross_chip_buffering: bool, placement_algorithm: placer.DRAMPlacementAlgorithm) -> None: ... - -class PostPlacerResults: - def __init__(self, *args, **kwargs) -> None: ... - @property - def allocated_blocks(self) -> List[List[Blocks]]: ... - @property - def current_host_address(self) -> int: ... - @property - def ins_instructions(self) -> Dict[Tuple[str,str,int,int,bool],InsertionInstruction]: ... - @property - def perf_model_results(self) -> Dict[str,float]: ... - -class QueueInsertionInstruction(InsertionInstruction): - def __init__(self, src: str, dest: str, hoist_tms: bool, num_entries: int, queue_size: int, input_id: Optional[int] = ..., fork_id: Optional[int] = ..., user_defined: bool = ...) -> None: ... - def unique_id(self) -> Tuple[str,str,int,int,bool]: ... - def __getstate__(self) -> tuple: ... - def __setstate__(self, arg0: tuple) -> None: ... - class SparseBUDA: def __init__(self, *args, **kwargs) -> None: ... def get_sparse_tile_ptr_bits(self, arg0: int, arg1: int, arg2: int) -> int: ... - def get_sparse_tiles_and_encodings(self, arg0: int) -> Tuple[List[List[float]],List[List[int]],List[int],List[int],List[int]]: ... + def get_sparse_tiles_and_encodings(self, arg0: int) -> tuple[list[list[float]], list[list[int]], list[int], list[int], list[int]]: ... 
+ def get_sparse_ublock_idx_bits(self, arg0: int, arg1: int, arg2: int) -> int: ... @property def bcast_factor(self) -> int: ... @property - def sparse_indices(self) -> Any: ... + def sparse_indices(self): ... @property - def sparse_shape(self) -> List[int]: ... + def sparse_shape(self) -> list[int]: ... @property def zdim(self) -> int: ... class SparseCOO: - def __init__(self, rows: List[int], cols: List[int], vals: List[float], shape: List[int]) -> None: ... + def __init__(self, rows: list[int], cols: list[int], vals: list[float], shape: list[int]) -> None: ... @property - def cols(self) -> List[int]: ... + def cols(self) -> list[int]: ... @property - def rows(self) -> List[int]: ... + def rows(self) -> list[int]: ... @property - def shape(self) -> List[int]: ... + def shape(self) -> list[int]: ... @property - def vals(self) -> List[float]: ... + def vals(self) -> list[float]: ... class UnsupportedHWOpsError(Exception): ... -def calculate_splice_buda_attrs(org_buda_attrs: List[int] = ..., splice_type: str = ..., input_shape_z: int = ..., input_shape_rt: int = ..., input_shape_ct: int = ..., dim: int = ..., ublock_order_cuts_dim: bool = ..., index: int = ..., length: int = ..., stride: int = ..., grid_r: int = ..., grid_c: int = ..., ublock_rt: int = ..., ublock_ct: int = ..., t_stream_factor_t_dim: int = ...) -> Tuple[int,int,int]: ... -def compress_sparse_tensor_and_strip_info(arg0: List[SparseCOO], arg1: int, arg2: int) -> SparseBUDA: ... -def dump_epoch_id_graphs(graph: graph.Graph, test_name: str, graph_name: str, placer_solution: placer.PlacerSolution, balancer_solution: balancer.BalancerSolution = ...) -> None: ... -def dump_epoch_type_graphs(graph: graph.Graph, test_name: str, graph_name: str, placer_solution: placer.PlacerSolution = ..., balancer_solution: balancer.BalancerSolution = ...) -> None: ... -def dump_graph(graph: graph.Graph, test_name: str, graph_name: str, placer_solution: placer.PlacerSolution = ..., balancer_solution: balancer.BalancerSolution = ...) -> None: ... -def is_subset_of_instructions(ins_instructions: Dict[Tuple[str,str,int,int,bool],InsertionInstruction] = ..., previous_instructions: Dict[Tuple[str,str,int,int,bool],InsertionInstruction] = ...) -> Tuple[bool,int,int]: ... -def link_past_cache_ios(arg0: graph.Graph) -> Dict[str,int]: ... -def lower_to_buda_netlist(arg0: graph.Graph, arg1: str, arg2: placer.PlacerSolution, arg3: balancer.BalancerSolution, arg4: List[int], arg5: backend_api.DeviceConfig) -> BudaNetlist: ... -def merge_netlists(arg0: List[BudaNetlist]) -> BudaNetlist: ... +def compress_sparse_tensor_and_strip_info(arg0: list[SparseCOO], arg1: int, arg2: int) -> SparseBUDA: ... +def dump_epoch_id_graphs(graph: graph.Graph, test_name: str, graph_name: str) -> None: ... +def dump_epoch_type_graphs(graph: graph.Graph, test_name: str, graph_name: str) -> None: ... +def dump_graph(graph: graph.Graph, test_name: str, graph_name: str) -> None: ... +def link_past_cache_ios(arg0: graph.Graph) -> dict[str, int]: ... +def move_index_to_mm_weights(arg0: graph.Graph) -> None: ... def run_consteval_graph_pass(arg0: graph.Graph) -> None: ... -def run_optimization_graph_passes(arg0: graph.Graph, arg1: backend_api.DeviceConfig) -> None: ... -def run_placer_buda_passes(arg0: graph.Graph, arg1: balancer.BalancerConfig, arg2: Dict[str, int], arg3: dict) -> Tuple[balancer.BalancerSolution, bool]: ... -def run_post_autograd_graph_passes(arg0: graph.Graph, arg1: object) -> List[Tuple[int, int]]: ... 
-def run_post_initial_graph_passes(arg0: graph.Graph, arg1: object, arg2: List[Tuple[List[Tuple[str, List[int], List[int]]], Dict[str, List[int]]]]) -> Tuple[List[Tuple[int, int]], Dict[str, int]]: ... -def run_post_optimize_decompose_graph_passes(arg0: graph.Graph, arg1: object) -> List[Tuple[int, int]]: ... -def run_post_placer_buda_passes(arg0: graph.Graph, arg1: str, arg2: backend_api.DeviceConfig, arg3: placer.PlacerSolution, arg4: PostPlacerConfig, arg5: balancer.BalancerSolution, arg6: Dict[Tuple[str, str, int, int, bool], InsertionInstruction], arg7: List[List[Blocks]], arg8: int) -> PostPlacerResults: ... -def run_lower_to_mlir_passes(arg0: graph.Graph) -> None: ... -def run_pre_netlist_generation_buda_passes(arg0: graph.Graph, arg1: str, arg2: backend_api.DeviceConfig, arg3: Dict[str, object], arg4: placer.PlacerSolution, arg5: PostPlacerConfig, arg6: balancer.BalancerSolution, arg7: List[List[Blocks]], arg8: int) -> None: ... -def run_pre_placer_buda_passes(graph: graph.Graph, scheduler_config: scheduler.SchedulerConfig, device_config: backend_api.DeviceConfig, chip_ids: List[int] = ..., op_names_to_chip_break: List[List[str]] = ..., op_names_to_epoch_break: List[List[str]] = ..., op_names_dont_fuse: List[str] = ..., op_names_manual_fuse: List[str] = ..., fracture_chip_id_assignments: Dict[str, int] = ..., default_df_override: Optional[DataFormat] = ..., default_accumulate_df: Optional[DataFormat] = ..., enable_broadcast_splitting: bool = ..., fp32_fallback: DataFormat = ..., default_math_fidelity: MathFidelity = ..., enable_auto_fusing: bool = ..., amp_level: int = ..., enable_recompute: bool = ..., output_queues_on_host: bool = ..., input_queues_on_host: bool = ..., ins_instructions: Dict[Tuple[str, str, int, int, bool], InsertionInstruction] = ..., insert_queues: List[Tuple[str, str, int]] = ..., amp_properties=..., op_intermediates_to_save: List[str] = ..., use_interactive_placer: bool = ..., enable_device_tilize: bool = ...) -> Tuple[graph.Graph, placer.PlacerConfigUpdate]: ... +def run_mlir_compiler(arg0: graph.Graph) -> runtime.Binary: ... +def run_optimization_graph_passes(arg0: graph.Graph) -> None: ... +def run_post_autograd_graph_passes(arg0: graph.Graph, arg1: object) -> list[tuple[int, int]]: ... +def run_post_initial_graph_passes(arg0: graph.Graph, arg1: object, arg2: list[tuple[list[tuple[str, list[int], list[int]]], dict[str, list[int]]]]) -> tuple[list[tuple[int, int]], dict[str, int]]: ... +def run_post_optimize_decompose_graph_passes(arg0: graph.Graph, arg1: object) -> list[tuple[int, int]]: ... +def run_pre_lowering_passes(arg0: graph.Graph) -> graph.Graph: ... +def run_pre_placer_buda_passes(graph: graph.Graph, device_config, chip_ids: list[int] = ..., op_names_dont_fuse: list[str] = ..., op_names_manual_fuse: list[str] = ..., fracture_chip_id_assignments: dict[str, int] = ..., default_df_override: DataFormat | None = ..., default_accumulate_df: DataFormat | None = ..., enable_broadcast_splitting: bool = ..., fp32_fallback: DataFormat = ..., default_math_fidelity: MathFidelity = ..., enable_auto_fusing: bool = ..., amp_level: int = ..., enable_recompute: bool = ..., output_queues_on_host: bool = ..., input_queues_on_host: bool = ..., insert_queues: list[tuple[str, str, int]] = ..., amp_properties=..., op_intermediates_to_save: list[str] = ..., use_interactive_placer: bool = ..., enable_device_tilize: bool = ...) -> graph.Graph: ... 
diff --git a/pybuda/pybuda/_C/autograd.pyi b/pybuda/pybuda/_C/autograd.pyi index ee8a7f069..f53083955 100644 --- a/pybuda/pybuda/_C/autograd.pyi +++ b/pybuda/pybuda/_C/autograd.pyi @@ -1,5 +1,5 @@ import pybuda._C.graph -from typing import List, Union, overload +from typing import overload class AutogradConfig: def __init__(self, recompute: bool = ..., optimizer: object = ...) -> None: ... @@ -10,13 +10,13 @@ class AutogradContext: def constant(self, arg0: int) -> pybuda._C.graph.NodeContext: ... @overload def constant(self, arg0: float) -> pybuda._C.graph.NodeContext: ... - def create_optimizer_op(self, type: str, operands: List[pybuda._C.graph.NodeContext], attributes=...) -> pybuda._C.graph.NodeContext: ... - def get_operands(self, arg0: pybuda._C.graph.NodeContext) -> List[pybuda._C.graph.NodeContext]: ... + def create_optimizer_op(self, type: str, operands: list[pybuda._C.graph.NodeContext], attributes=...) -> pybuda._C.graph.NodeContext: ... + def get_operands(self, arg0: pybuda._C.graph.NodeContext) -> list[pybuda._C.graph.NodeContext]: ... def get_pytorch_tensor(self, arg0: pybuda._C.graph.NodeContext) -> object: ... - def get_shape(self, arg0: pybuda._C.graph.NodeContext) -> List[int]: ... + def get_shape(self, arg0: pybuda._C.graph.NodeContext) -> list[int]: ... def input(self, *args, **kwargs): ... def loopback(self, arg0: pybuda._C.graph.NodeContext, arg1: pybuda._C.graph.NodeContext) -> None: ... - def op(self, type: Union[str, object], operands: List[pybuda._C.graph.NodeContext], attributes=...) -> pybuda._C.graph.NodeContext: ... + def op(self, type: str | object, operands: list[pybuda._C.graph.NodeContext], attributes=...) -> pybuda._C.graph.NodeContext: ... def tensor(self, arg0: object) -> pybuda._C.graph.NodeContext: ... class AutogradEngine: diff --git a/pybuda/pybuda/_C/graph.pyi b/pybuda/pybuda/_C/graph.pyi index 10381ffbc..757c7c79d 100644 --- a/pybuda/pybuda/_C/graph.pyi +++ b/pybuda/pybuda/_C/graph.pyi @@ -1,5 +1,5 @@ import pybuda._C -from typing import ClassVar, Dict, Iterator, List, Optional, Tuple, Union, overload +from typing import ClassVar, Iterator, overload C: UBlockOrder Concatenate: RuntimeTensorTransformType @@ -21,45 +21,44 @@ class Graph: def __init__(self, arg0: str) -> None: ... def clone(self) -> Graph: ... def enable_training(self) -> bool: ... - def get_constant_input_runtime_tensor_transform_constants(self) -> List[Tuple[str, object]]: ... - def get_constant_names(self) -> List[str]: ... + def get_constant_input_runtime_tensor_transform_constants(self) -> list[tuple[str, object]]: ... + def get_constant_names(self) -> list[str]: ... def get_constant_nodes(self, *args, **kwargs): ... - def get_fused_ops(self) -> List[Tuple[int, List[List[str]]]]: ... def get_input_runtime_tensor_transforms(self, *args, **kwargs): ... def get_microbatch(self) -> int: ... def get_name(self) -> str: ... def get_node_id(self, arg0: str) -> int: ... def get_node_name(self, arg0: int) -> str: ... - def get_ordered_constant_tile_dims(self) -> List[List[int]]: ... - def get_ordered_input_gradient_names(self) -> List[str]: ... - def get_ordered_input_names(self) -> List[str]: ... - def get_ordered_input_requires_grad(self) -> List[bool]: ... - def get_ordered_input_shapes(self) -> List[List[int]]: ... - def get_ordered_input_subgraph_indices(self) -> List[int]: ... - def get_ordered_input_tile_dims(self) -> List[List[int]]: ... - def get_ordered_intermediate_names(self) -> List[str]: ... - def get_ordered_intermediate_shapes(self) -> List[List[int]]: ... 
- def get_ordered_output_gradient_names(self) -> List[str]: ... - def get_ordered_output_names(self) -> List[str]: ... - def get_ordered_output_requires_grad(self) -> List[bool]: ... - def get_ordered_output_shapes(self) -> List[List[int]]: ... - def get_ordered_output_subgraph_indices(self) -> List[int]: ... - def get_ordered_parameter_tile_dims(self) -> List[List[int]]: ... - def get_ordered_target_names(self) -> List[str]: ... - def get_ordered_target_shapes(self) -> List[List[int]]: ... - def get_ordered_target_subgraph_indices(self) -> List[int]: ... + def get_ordered_constant_tile_dims(self) -> list[list[int]]: ... + def get_ordered_input_gradient_names(self) -> list[str]: ... + def get_ordered_input_names(self) -> list[str]: ... + def get_ordered_input_requires_grad(self) -> list[bool]: ... + def get_ordered_input_shapes(self) -> list[list[int]]: ... + def get_ordered_input_subgraph_indices(self) -> list[int]: ... + def get_ordered_input_tile_dims(self) -> list[list[int]]: ... + def get_ordered_intermediate_names(self) -> list[str]: ... + def get_ordered_intermediate_shapes(self) -> list[list[int]]: ... + def get_ordered_output_gradient_names(self) -> list[str]: ... + def get_ordered_output_names(self) -> list[str]: ... + def get_ordered_output_requires_grad(self) -> list[bool]: ... + def get_ordered_output_shapes(self) -> list[list[int]]: ... + def get_ordered_output_subgraph_indices(self) -> list[int]: ... + def get_ordered_parameter_tile_dims(self) -> list[list[int]]: ... + def get_ordered_target_names(self) -> list[str]: ... + def get_ordered_target_shapes(self) -> list[list[int]]: ... + def get_ordered_target_subgraph_indices(self) -> list[int]: ... def get_output_runtime_tensor_transforms(self, *args, **kwargs): ... def get_parameter_nodes(self, *args, **kwargs): ... def get_subgraph_id_for_node(self, arg0: int) -> int: ... - def get_tile_broadcast_dims_for_bw_input(self, arg0: int) -> List[int]: ... - def get_tile_broadcast_dims_for_input(self, arg0: int) -> List[int]: ... - def get_tile_broadcast_dims_for_target(self, arg0: int) -> List[int]: ... + def get_tile_broadcast_dims_for_bw_input(self, arg0: int) -> list[int]: ... + def get_tile_broadcast_dims_for_input(self, arg0: int) -> list[int]: ... + def get_tile_broadcast_dims_for_target(self, arg0: int) -> list[int]: ... def has_node_with_id(self, arg0: int) -> bool: ... - def nodes(self) -> List[str]: ... + def nodes(self) -> list[str]: ... def output_node_redirected(self) -> bool: ... - def register_module_inputs(self, arg0: List[int]) -> None: ... - def register_module_outputs(self, arg0: List[int], arg1: List[bool]) -> None: ... - def register_module_targets(self, arg0: List[int]) -> None: ... + def register_module_inputs(self, module_inputs: list[int], append: bool = ...) -> None: ... + def register_module_outputs(self, module_outputs: list[int], requires_grad: list[bool], append: bool = ...) -> None: ... + def register_module_targets(self, arg0: list[int]) -> None: ... def set_enable_training(self, arg0: bool) -> None: ... def set_microbatch(self, arg0: int) -> None: ... @@ -195,15 +194,15 @@ class Shape: BUDA: ClassVar[Shape.Type] = ... FREE: ClassVar[Shape.Type] = ... def __init__(self, *args, **kwargs) -> None: ... - def as_list(self) -> List[int]: ... - @classmethod - def create(cls, values: List[int]) -> Shape: ... - @classmethod - def create_buda(cls, arg0: List[int], arg1: int, arg2: int) -> Shape: ... - @classmethod - def create_with_type_from_other(cls, other: Shape, values: List[int]) -> Shape: ... 
- @classmethod - def from_json(cls, arg0: json) -> Shape: ... + def as_list(self) -> list[int]: ... + @staticmethod + def create(values: list[int]) -> Shape: ... + @staticmethod + def create_buda(arg0: list[int], arg1: int, arg2: int) -> Shape: ... + @staticmethod + def create_with_type_from_other(other: Shape, values: list[int]) -> Shape: ... + @staticmethod + def from_json(arg0: json) -> Shape: ... def get_tile_dim(self, *args, **kwargs): ... def get_tile_height(self) -> int: ... def get_tile_width(self) -> int: ... @@ -247,21 +246,21 @@ class UBlockOrder: def add_partial_datacopy_edge(arg0: Graph, arg1: int, arg2: int, arg3: int, arg4: int) -> None: ... def add_subgraph_io_link_edge(arg0: Graph, arg1: int, arg2: int, arg3: int, arg4: int) -> None: ... -def create_activation_input(arg0: Graph, arg1: str, arg2: List[int], arg3: bool, arg4: pybuda._C.DataFormat, arg5: int) -> int: ... +def create_activation_input(arg0: Graph, arg1: str, arg2: list[int], arg3: bool, arg4: pybuda._C.DataFormat, arg5: int) -> int: ... @overload def create_constant_input(arg0: Graph, arg1: str, arg2: float, arg3: pybuda._C.DataFormat, arg4: int) -> int: ... @overload -def create_constant_input(arg0: Graph, arg1: str, arg2: object, arg3: List[int], arg4: pybuda._C.DataFormat, arg5: int) -> int: ... +def create_constant_input(arg0: Graph, arg1: str, arg2: object, arg3: list[int], arg4: pybuda._C.DataFormat, arg5: int) -> int: ... def create_control_edge(arg0: Graph, arg1: int, arg2: int, arg3: int, arg4: int) -> None: ... -def create_data_edge(arg0: Graph, arg1: int, arg2: int, arg3: int, arg4: int, arg5: List[tuple]) -> None: ... -def create_op_node(arg0: Graph, arg1: str, arg2: OpType, arg3: List[int], arg4: pybuda._C.DataFormat, arg5: int, arg6: Dict[str, Union[bool, int, str]]) -> int: ... -def create_output(arg0: Graph, arg1: str, arg2: List[int], arg3: pybuda._C.DataFormat, arg4: bool, arg5: int) -> int: ... -def create_parameter_input(arg0: Graph, arg1: str, arg2: List[int], arg3: bool, arg4: pybuda._C.DataFormat, arg5: int) -> int: ... -def create_target_input(arg0: Graph, arg1: str, arg2: List[int], arg3: bool, arg4: pybuda._C.DataFormat, arg5: int) -> int: ... -def eval(graph: Graph, inputs: List[object], parameters: Dict[str, object], tt_device: object, relative_atol: float, pcc: float, intermediate_golden_tensors: Dict[int, object] = ..., losses: List[object] = ..., targets: List[object] = ..., balancer_solution=..., dump_tensors_path: str = ..., allow_modified_shapes: bool = ...) -> Tuple[List[object], Dict[str, object], List[object], Dict[str, object]]: ... +def create_data_edge(arg0: Graph, arg1: int, arg2: int, arg3: int, arg4: int, arg5: list[tuple]) -> None: ... +def create_op_node(arg0: Graph, arg1: str, arg2: OpType, arg3: list[int], arg4: pybuda._C.DataFormat, arg5: int, arg6: dict[str, bool | int | str]) -> int: ... +def create_output(arg0: Graph, arg1: str, arg2: list[int], arg3: pybuda._C.DataFormat, arg4: bool, arg5: int) -> int: ... +def create_parameter_input(arg0: Graph, arg1: str, arg2: list[int], arg3: bool, arg4: pybuda._C.DataFormat, arg5: int) -> int: ... +def create_target_input(arg0: Graph, arg1: str, arg2: list[int], arg3: bool, arg4: pybuda._C.DataFormat, arg5: int) -> int: ... 
+def eval(graph: Graph, inputs: list[object], parameters: dict[str, object], tt_device: object, relative_atol: float, pcc: float, intermediate_golden_tensors: dict[int, object] = ..., losses: list[object] = ..., targets: list[object] = ..., dump_tensors_path: str = ..., allow_modified_shapes: bool = ...) -> tuple[list[object], dict[str, object], list[object], dict[str, object]]: ... def get_constant_input_value(arg0: Node, arg1: bool) -> object: ... -def get_intermediate_tensors(graph: Graph, inputs: List[object], parameters: Dict[str, object], tt_device: object, relative_atol: float, pcc: float, intermediate_golden_tensors: Dict[int, object] = ..., losses: List[object] = ..., targets: List[object] = ..., balancer_solution=..., dump_tensors_path: str = ..., allow_modified_shapes: bool = ...) -> Dict[str, object]: ... -def get_optimizer_param_info(arg0: Graph, arg1: str) -> List[Tuple[InputNode, str]]: ... -def get_shape_for_node(arg0: Graph, arg1: str) -> List[int]: ... -def record_consteval_operations(arg0: Graph) -> Dict[str, Optional[json]]: ... +def get_intermediate_tensors(graph: Graph, inputs: list[object], parameters: dict[str, object], tt_device: object, relative_atol: float, pcc: float, intermediate_golden_tensors: dict[int, object] = ..., losses: list[object] = ..., targets: list[object] = ..., dump_tensors_path: str = ..., allow_modified_shapes: bool = ...) -> dict[str, object]: ... +def get_optimizer_param_info(arg0: Graph, arg1: str) -> list[tuple[InputNode, str]]: ... +def get_shape_for_node(arg0: Graph, arg1: str) -> list[int]: ... +def record_consteval_operations(arg0: Graph) -> dict[str, json | None]: ... def remove_node(arg0: Graph, arg1: int) -> None: ... diff --git a/pybuda/pybuda/_C/torch_device.pyi b/pybuda/pybuda/_C/torch_device.pyi index 3aa06a5ed..b1e9ac4b7 100644 --- a/pybuda/pybuda/_C/torch_device.pyi +++ b/pybuda/pybuda/_C/torch_device.pyi @@ -1,59 +1,48 @@ -import pybuda._C.backend_api -import pybuda._C.balancer +import pybuda._C import torch -from typing import Dict, List, Optional - -class CompileRequest: - def __init__(self, netlist_path: str, output_dir: str, backend_config: pybuda._C.backend_api.BackendConfig, inputs: List[PyBudaTensorDesc], input_runtime_transforms: List[str], constants: List[PyBudaTensorDesc], parameters: List[PyBudaTensorDesc], outputs: List[PyBudaTensorDesc], output_runtime_transforms: List[str]) -> None: ... - -class Program: - def __init__(self, name: str, params: Dict[str, str]) -> None: ... class PyBudaTensorDesc: - def __init__(self, name: str, shape: List[int], ptr: int = ..., constant: Optional[torch.Tensor] = ...) -> None: ... + def __init__(self, name: str, shape: list[int], ptr: int = ..., constant: torch.Tensor | None = ...) -> None: ... @property - def constant(self) -> Optional[torch.Tensor]: ... + def constant(self) -> torch.Tensor | None: ... @property def name(self) -> str: ... @property def ptr(self) -> int: ... @property - def shape(self) -> List[int]: ... + def shape(self) -> list[int]: ... class TTDevice: def __init__(self, *args, **kwargs) -> None: ... - def compile(self, arg0: CompileRequest) -> Workload: ... - def dispatch(self, arg0: Workload, arg1: List[Program], arg2: List[torch.Tensor], arg3: Dict[str, pybuda._C.balancer.OutputHostTM]) -> List[torch.Tensor]: ... + def dispatch(self, arg0: Workload, arg1: int, arg2: list[torch.Tensor], arg3: bool) -> list[torch.Tensor]: ... def str(self) -> str: ... def torch_device(self) -> torch.device: ... 
@property - def arch(self) -> pybuda._C.backend_api.BackendDevice: ... + def arch(self) -> pybuda._C.Arch: ... @property def cluster_yaml(self) -> str: ... @property - def index(self) -> int: ... + def input_runtime_transforms(self) -> dict[int, list[str]]: ... @property - def mmio(self) -> bool: ... + def input_tile_bcast_dims(self) -> dict[int, list[list[int]]]: ... @property - def soc_desc_yaml(self) -> str: ... + def mmio(self) -> bool: ... @property - def type(self) -> pybuda._C.backend_api.BackendType: ... + def output_runtime_transforms(self) -> dict[int, list[str]]: ... class Workload: def __init__(self, *args, **kwargs) -> None: ... @property - def backend(self) -> pybuda._C.backend_api.BackendApi: ... - @property - def constants(self) -> List[PyBudaTensorDesc]: ... + def constants(self) -> list[PyBudaTensorDesc]: ... @property - def inputs(self) -> List[PyBudaTensorDesc]: ... + def inputs(self) -> dict[int, list[PyBudaTensorDesc]]: ... @property - def outputs(self) -> List[PyBudaTensorDesc]: ... + def outputs(self) -> dict[int, list[PyBudaTensorDesc]]: ... @property - def parameters(self) -> List[PyBudaTensorDesc]: ... + def parameters(self) -> list[PyBudaTensorDesc]: ... def get_available_devices(*args, **kwargs): ... def get_default_device(*args, **kwargs): ... def is_created_on_device(arg0: torch.Tensor) -> bool: ... -def original_shape(arg0: torch.Tensor) -> List[int]: ... -def push_tensor(arg0: pybuda._C.backend_api.BackendApi, arg1: PyBudaTensorDesc, arg2: torch.Tensor, arg3: str) -> None: ... +def original_shape(arg0: torch.Tensor) -> list[int]: ... +def unique_id(arg0: torch.Tensor) -> int: ... diff --git a/pybuda/pybuda/compile.py b/pybuda/pybuda/compile.py index f1952fad8..5d806f8b2 100644 --- a/pybuda/pybuda/compile.py +++ b/pybuda/pybuda/compile.py @@ -2,16 +2,20 @@ # SPDX-License-Identifier: Apache-2.0 import os -from sys import intern from typing import Optional, List, Dict, Any, Tuple, Union -from enum import Enum from dataclasses import dataclass, field import torch import tensorflow as tf from loguru import logger -from .tensor import Tensor +import pybuda +from pybuda.compiled_graph_state import CompiledGraphState, CompiledModel, CompileResults +from pybuda.config import ( + CompilerConfig, + CompileDepth, + _get_global_compiler_config, +) from pybuda._C import ( link_past_cache_ios, move_index_to_mm_weights, @@ -20,27 +24,20 @@ run_post_optimize_decompose_graph_passes, run_consteval_graph_pass, run_post_autograd_graph_passes, - run_lower_to_mlir_passes, + run_pre_lowering_passes, dump_graph, ) -import pybuda -from .parameter import Parameter import pybuda._C.autograd as pyautograd +import pybuda._C.graph as pygraph from pybuda._C.graph import Graph +import pybuda.ci as ci +from pybuda.module import PyBudaModule, wrap_module +from pybuda.parameter import Parameter +from pybuda.pybudaglobal import state_changed, clear_state_changed import pybuda.query as query -from .verify import VerifyConfig, do_verify, verify_golden, _generate_random_losses, _run_pytorch_backward, get_intermediate_tensors -import pybuda._C.graph as pygraph -from .config import ( - CompilerConfig, - CompileDepth, - _get_global_compiler_config, -) -from .pybudaglobal import state_changed, clear_state_changed -from pybuda import PyBudaModule -from pybuda.module import wrap_module -from .tensor import Tensor, to_pt_tensors, to_buda_tensors -from . 
import ci, utils -from pybuda.tools.net2reportify import net2placement +from pybuda.tensor import Tensor, to_pt_tensors +from pybuda.verify import VerifyConfig, do_verify, _generate_random_losses, _run_pytorch_backward + LAST_SUCCESSFUL_STAGE = None def init_log_last_successful_compile_stage(): @@ -95,22 +92,9 @@ def generate_override_config(graph, balancer_solution, placer_solution, nop_inst with open(path, "w") as fd: yaml.dump(overrides, fd, indent=2) -class CompileResults: - """ - Wrapper for result from the graph compiler. Contains initial and final graphs, output tensors, - and, optionally golden results for final output and intermediates, if desired. - """ - outputs: List[Tensor] - golden_outputs: List[torch.Tensor] - golden_intermediates: Dict[str, torch.Tensor] - initial_graph: Graph - final_graph: Graph - - pass_specific_output_kwargs: Dict[str, Any] = {} - @dataclass class CompileContext: - modules: List[torch.nn.Module] + modules: List[PyBudaModule] graph_name: str compiler_cfg: CompilerConfig verify_cfg: VerifyConfig @@ -215,7 +199,7 @@ def compile_main( return pybuda_compile_from_context(compile_context) -def pybuda_compile_from_context(context: CompileContext) -> CompileResults: +def pybuda_compile_from_context(context: CompileContext) -> CompiledModel: """ Run front-end compile passes and generate a Buda netlist, with a given compile context. @@ -241,6 +225,7 @@ def pybuda_compile_from_context(context: CompileContext) -> CompileResults: CompileDepth.AUTOGRAD: run_autograd_pass, CompileDepth.POST_AUTOGRAD_PASS: run_post_autograd_pass, CompileDepth.PRE_LOWERING_PASS: run_pre_lowering_pass, + CompileDepth.RUN_MLIR_COMPILER: run_mlir_compiler, CompileDepth.FINISH_COMPILE: finish_compile, } @@ -271,7 +256,25 @@ def pybuda_compile_from_context(context: CompileContext) -> CompileResults: context.stage = next_stage - return generate_compile_results(context.verify_cfg, context.initial_graph_copy, context.outputs, context.intermediate_tensors, context.final_graph, pass_specific_output_kwargs=context.output_kwargs) + compile_results = generate_compile_results( + verify_cfg, + context.initial_graph_copy, context.outputs, + context.intermediate_tensors, + final_graph=context.final_graph, + pass_specific_output_kwargs = context.output_kwargs + ) + + compiled_graph_state = CompiledGraphState.from_compiled_graph(context.modules[0], compile_results) + + compiled_module = CompiledModel( + compiled_graph_state, + context.output_kwargs["binary"] + ) + + logger.info("Compilation completed.") + + return compiled_module + def pybuda_compile_torch( module_name: str, @@ -776,10 +779,18 @@ def run_pre_lowering_pass(context: CompileContext) -> CompileDepth: graph_name = context.graph_name graph = context.graph - run_lower_to_mlir_passes(graph) + graph = run_pre_lowering_passes(graph) dump_graph(graph, graph_name, "pre_lowering") context.final_graph = graph + return CompileDepth.RUN_MLIR_COMPILER + +def run_mlir_compiler(context: CompileContext) -> CompileDepth: + graph = context.graph + + binary = pybuda._C.run_mlir_compiler(graph) + context.output_kwargs["binary"] = binary + return CompileDepth.FINISH_COMPILE @@ -799,13 +810,6 @@ def finish_compile(context: CompileContext) -> CompileDepth: verify_cfg = context.verify_cfg context.output_kwargs["consteval_trace"] = pygraph.record_consteval_operations(context.final_graph) - compile_results = generate_compile_results( - verify_cfg, - context.initial_graph_copy, context.outputs, - context.intermediate_tensors, - final_graph=context.final_graph, - 
pass_specific_output_kwargs = context.output_kwargs - ) return CompileDepth.FULL diff --git a/pybuda/pybuda/compiled_graph_state.py b/pybuda/pybuda/compiled_graph_state.py index dd1f73b4c..18ef7f76c 100644 --- a/pybuda/pybuda/compiled_graph_state.py +++ b/pybuda/pybuda/compiled_graph_state.py @@ -3,27 +3,21 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Dict, List, Any, Tuple, Optional -from dataclasses import dataclass, field -from enum import Enum -import inspect -import os -import json - -import importlib +from loguru import logger -from pybuda.compile import CompileResults - -from pybuda._C import DataFormat -from pybuda._C.graph import Graph, get_constant_input_value, get_optimizer_param_info, RuntimeTensorTransform, RuntimeTensorTransformType, Shape - -import dataclasses +from dataclasses import dataclass, field from dataclasses_json import dataclass_json, config -from pybuda.utils import as_json, dict_as_json, list_as_json, detach_tensors -from pybuda.tensor import get_device_constant_and_parameters, get_post_const_eval_tensors +from pybuda._C import DataFormat +from pybuda._C.graph import Graph, RuntimeTensorTransform +from pybuda._C.runtime import run_binary, Binary +from pybuda.utils import list_as_json +from pybuda.tensor import Tensor, get_post_const_eval_tensors from pybuda.module import Module + import torch + def no_encoding(obj): return obj # perform json-encoding later def no_decoding(obj): @@ -33,6 +27,19 @@ def optional_no_encoding(obj): def optional_no_decoding(obj): return None if obj is None else obj +class CompileResults: + """ + Wrapper for result from the graph compiler. Contains initial and final graphs, output tensors, + and, optionally golden results for final output and intermediates, if desired. + """ + outputs: List[Tensor] + golden_outputs: List[torch.Tensor] + golden_intermediates: Dict[str, torch.Tensor] + initial_graph: Graph + final_graph: Graph + + pass_specific_output_kwargs: Dict[str, Any] = {} + @dataclass_json @dataclass() class CompiledGraphState: @@ -267,3 +274,32 @@ def get_ordered_output_shapes_for_subgraph(self, subgraph_idx): def get_ordered_output_runtime_transforms_for_subgraph(self, subgraph_idx): return [transform for i, transform in enumerate(self.ordered_output_runtime_tensor_transforms) if self.ordered_output_subgraph_indices[i] == subgraph_idx] + +class CompiledModel: + """ + Callable object for running inference on the compiled model. + """ + compiled_graph_state: CompiledGraphState + binary: Binary + + def __init__(self, compiled_graph_state: CompiledGraphState, binary: Binary): + self.compiled_graph_state = compiled_graph_state + self.binary = binary + + def __call__(self, *inputs: torch.Tensor) -> List[torch.Tensor]: + """ + Run inference on the compiled model. + + Parameters + ---------- + inputs: Tuple[Tensor, ...] 
+ Input tensors + + Returns + ------- + List[Tensor] + Output tensors + """ + logger.info(f"Running model {self.compiled_graph_state.graph_name} on device...") + return run_binary(self.binary, 0, [*inputs]) + diff --git a/pybuda/pybuda/config.py b/pybuda/pybuda/config.py index e4d4213b5..8ac9077f7 100644 --- a/pybuda/pybuda/config.py +++ b/pybuda/pybuda/config.py @@ -26,8 +26,9 @@ class CompileDepth(Enum): AUTOGRAD = 6 POST_AUTOGRAD_PASS = 7 PRE_LOWERING_PASS = 8 - FINISH_COMPILE = 9 - FULL = 10 + RUN_MLIR_COMPILER = 9 + FINISH_COMPILE = 10 + FULL = 11 @classmethod def has_value(cls, value): diff --git a/pybuda/test/mlir/test_ops.py b/pybuda/test/mlir/test_ops.py index 610b8e952..41b1d3afe 100644 --- a/pybuda/test/mlir/test_ops.py +++ b/pybuda/test/mlir/test_ops.py @@ -3,6 +3,7 @@ import torch from torch import nn +import pybuda def test_add(): class Add(nn.Module): @@ -17,7 +18,7 @@ def forward(self, a, b): framework_model = Add() fw_out = framework_model(*inputs) - compiled_model = torch.compile(framework_model, backend="tt") + compiled_model = pybuda.compile(framework_model, sample_inputs=inputs) co_out = compiled_model(*inputs) co_out = [co.to("cpu") for co in co_out] @@ -37,7 +38,7 @@ def forward(self, a, b): framework_model = Subtract() fw_out = framework_model(*inputs) - compiled_model = torch.compile(framework_model, backend="tt") + compiled_model = pybuda.compile(framework_model, sample_inputs=inputs) co_out = compiled_model(*inputs) co_out = [co.to("cpu") for co in co_out] @@ -57,7 +58,7 @@ def forward(self, a, b): framework_model = Multiply() fw_out = framework_model(*inputs) - compiled_model = torch.compile(framework_model, backend="tt") + compiled_model = pybuda.compile(framework_model, sample_inputs=inputs) co_out = compiled_model(*inputs) co_out = [co.to("cpu") for co in co_out] @@ -78,7 +79,7 @@ def forward(self, a): framework_model = ReLU() fw_out = framework_model(*inputs) - compiled_model = torch.compile(framework_model, backend="tt") + compiled_model = pybuda.compile(framework_model, sample_inputs=inputs) co_out = compiled_model(*inputs) co_out = [co.to("cpu") for co in co_out] @@ -99,8 +100,8 @@ def forward(self, a): framework_model = Linear() fw_out = framework_model(*inputs) - compiled_model = torch.compile(framework_model.to("tt"), backend="tt") - co_out = compiled_model(*[i.to("tt") for i in inputs]) + compiled_model = pybuda.compile(framework_model, sample_inputs=inputs) + co_out = compiled_model(*inputs) co_out = [co.to("cpu") for co in co_out] assert [torch.allclose(fo, co) for fo, co in zip(fw_out, co_out)] @@ -120,7 +121,7 @@ def forward(self, a): framework_model = Softmax() fw_out = framework_model(*inputs) - compiled_model = torch.compile(framework_model, backend="tt") + compiled_model = pybuda.compile(framework_model, sample_inputs=inputs) co_out = compiled_model(*inputs) co_out = [co.to("cpu") for co in co_out] diff --git a/pybuda/test/test_api.py b/pybuda/test/test_api.py index 1fea97b36..f181f2b32 100644 --- a/pybuda/test/test_api.py +++ b/pybuda/test/test_api.py @@ -17,12 +17,19 @@ def forward(self, x1, x2): return torch.add(x1, x2) model = Add() - inputs = [torch.rand(1, 32, 32), torch.rand(1, 32, 32)] + shape = (1, 1024, 32) + inputs = [torch.rand(shape), torch.rand(shape)] - compiled_model = pybuda.compile(model, sample_inputs=[torch.rand(1, 32, 32), torch.rand(1, 32, 32)]) + golden = model(*inputs) - # TODO: Run inference on the compiled model, in the following way: - # compiled_model(*inputs) + compiled_model = pybuda.compile(model, 
sample_inputs=[torch.rand(shape), torch.rand(shape)]) + + output = compiled_model(*inputs) + + print(f"golden: {golden}") + print(f"output: {output}") + if not torch.allclose(output[0], golden, rtol=1e-1): + raise ValueError("Output does not match the golden output") def test_tf(): class TFAdd(tf.keras.Model): @@ -33,8 +40,18 @@ def call(self, x1, x2): return x1 + x2 model = TFAdd() - inputs = [torch.rand(1, 32, 32), torch.rand(1, 32, 32)] - pybuda.compile(model, sample_inputs=[torch.rand(1, 32, 32), torch.rand(1, 32, 32)]) + shape = (1, 1024, 32) + inputs = [torch.rand(shape), torch.rand(shape)] + + inputs_tf = [tf.convert_to_tensor(x) for x in inputs] + golden = model(inputs_tf[0], inputs_tf[1]) + golden = torch.tensor(golden.numpy()) + + compiled_model = pybuda.compile(model, sample_inputs=[torch.rand(shape), torch.rand(shape)]) + + output = compiled_model(*inputs) - # TODO: Run inference on the compiled model, in the following way: - # compiled_model(*inputs) + print(f"golden: {golden}") + print(f"output: {output}") + if not torch.allclose(output[0], golden, rtol=1e-1): + raise ValueError("Output does not match the golden output") diff --git a/third_party/tt-mlir b/third_party/tt-mlir index 02ee9fd83..7345e481e 160000 --- a/third_party/tt-mlir +++ b/third_party/tt-mlir @@ -1 +1 @@ -Subproject commit 02ee9fd83f7d4472b7e4e3c2ade9d285ca65d905 +Subproject commit 7345e481e4a0c503e81780c2ee5094242e1bc4fd diff --git a/utils/signal_handlers.hpp b/utils/signal_handlers.hpp index a4cf4d5d3..5320cd0ca 100644 --- a/utils/signal_handlers.hpp +++ b/utils/signal_handlers.hpp @@ -8,7 +8,7 @@ #include #include "utils/assert.hpp" -#include "tt_torch_device/tt_device.hpp" +#include "runtime/tt_device.hpp" inline void pybuda_signal_handler(int sig) { @@ -60,7 +60,7 @@ inline void pybuda_signal_handler(int sig) std::cerr << prefix << frame << std::endl; } - tt::close_devices(); + tt::TTSystem::get_system().close_devices(); // Restore the default signal handler and raise the signal again. // The default signal handler will generate a core dump (if enabled).
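A minimal usage sketch, mirroring the updated pybuda/test/test_api.py: pybuda.compile() runs the front-end passes plus the new RUN_MLIR_COMPILER stage and returns a CompiledModel; calling the CompiledModel dispatches the generated flatbuffer binary through run_binary(). Module, shapes, and tolerance below are illustrative only.

    import torch
    import pybuda

    class Add(torch.nn.Module):
        def forward(self, x1, x2):
            return torch.add(x1, x2)

    model = Add()
    shape = (1, 1024, 32)  # illustrative shape, matching the updated test
    inputs = [torch.rand(shape), torch.rand(shape)]

    golden = model(*inputs)

    # Front-end passes + MLIR compile; returns a CompiledModel wrapping the flatbuffer binary.
    compiled_model = pybuda.compile(model, sample_inputs=inputs)

    # CompiledModel.__call__ runs the binary on device via run_binary(binary, 0, inputs).
    output = compiled_model(*inputs)

    assert torch.allclose(output[0], golden, rtol=1e-1)

The hard-coded program index 0 in the last comment matches CompiledModel.__call__ as introduced in this patch, which currently always invokes run_binary(self.binary, 0, [*inputs]).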