From 7808c2c8d6decb716a3e7d7e1cfcbee7111e3118 Mon Sep 17 00:00:00 2001 From: Stefan Djordjevic Date: Wed, 7 Aug 2024 15:25:03 +0000 Subject: [PATCH] Adding changes to support MNIST e2e inference --- CMakeLists.txt | 4 ++ pybuda/csrc/buda_passes.cpp | 11 ++++- pybuda/csrc/buda_passes.hpp | 5 ++- pybuda/csrc/passes/dataformat.hpp | 2 + pybuda/csrc/passes/lower_to_mlir.cpp | 57 +++++++++++++++++------- pybuda/csrc/passes/mlir_passes.cpp | 17 ++++++- pybuda/csrc/pybuda_bindings.cpp | 6 ++- pybuda/pybuda/_C/__init__.pyi | 2 +- pybuda/pybuda/compile.py | 4 +- pybuda/pybuda/compiled_graph_state.py | 9 +++- pybuda/test/mlir/mnist/test_inference.py | 13 +++++- pybuda/test/mlir/mnist/utils.py | 11 +++-- utils/logger.hpp | 2 +- 13 files changed, 110 insertions(+), 33 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4f8f6ff4c..0264952df 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,6 +36,10 @@ endif() add_compile_options(-Wall -Wextra -Wpedantic -Werror -Wno-unused-parameter) +if (CMAKE_BUILD_TYPE STREQUAL "Debug" OR CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo") + add_compile_options(-DDEBUG) +endif() + set(TTFORGE_CSRC_WARNINGS -Wall -Wextra -Wno-pragmas -Wno-unused-parameter) set(CFLAGS_NO_WARN -DFMT_HEADER_ONLY) set(TTFORGE_CSRC_CFLAGS ${CFLAGS_NO_WARN} ${TTFORGE_CSRC_WARNINGS} -DUTILS_LOGGER_PYTHON_OSTREAM_REDIRECT=1) diff --git a/pybuda/csrc/buda_passes.cpp b/pybuda/csrc/buda_passes.cpp index 7d71cf614..1de5270c5 100644 --- a/pybuda/csrc/buda_passes.cpp +++ b/pybuda/csrc/buda_passes.cpp @@ -44,6 +44,7 @@ #include "passes/replace_incommutable_patterns.hpp" #include "passes/set_tile_dim.hpp" #include "passes/squeeze_to_reshape.hpp" +#include "passes/dataformat.hpp" #include "python_bindings_common.hpp" #include "reportify/reportify.hpp" #include "utils/assert.hpp" @@ -193,7 +194,9 @@ std::vector> run_post_autograd_gra } // ********** Run pre-lowering passes ********** -graphlib::Graph* run_pre_lowering_passes(graphlib::Graph *graph) +graphlib::Graph* run_pre_lowering_passes( + graphlib::Graph *graph, + const std::optional default_df_override) { passes::print_graph(graph, "PRE_MLIR"); // Recalculate shapes, and figure out implicit broadcasts that are missing @@ -227,6 +230,12 @@ graphlib::Graph* run_pre_lowering_passes(graphlib::Graph *graph) fold_tile_broadcast_ops_into_inputs(graph); fold_tile_broadcast_ops_into_reduce(graph); + // + // Data formats + // + // Apply user overrides + passes::configure_output_data_formats(graph, default_df_override); + return graph; } diff --git a/pybuda/csrc/buda_passes.hpp b/pybuda/csrc/buda_passes.hpp index 819769d75..314de1f82 100644 --- a/pybuda/csrc/buda_passes.hpp +++ b/pybuda/csrc/buda_passes.hpp @@ -50,6 +50,7 @@ std::unique_ptr run_pre_placer_buda_passes( bool enable_device_tilize = false); // Pre-lowering passes, last-minute changes before going to MLIR -graphlib::Graph* run_pre_lowering_passes(graphlib::Graph *graph); - +graphlib::Graph* run_pre_lowering_passes( + graphlib::Graph *graph, + const std::optional default_df_override = {}); } diff --git a/pybuda/csrc/passes/dataformat.hpp b/pybuda/csrc/passes/dataformat.hpp index 606d31390..6cc1e728f 100644 --- a/pybuda/csrc/passes/dataformat.hpp +++ b/pybuda/csrc/passes/dataformat.hpp @@ -39,6 +39,8 @@ void configure_a_b_format_conversion( graphlib::Graph *graph, const DeviceConfig &device_config, const std::optional default_df_override); void validate_data_formats(const graphlib::Graph *graph, const DeviceConfig& device_config); void validate_post_placer_data_formats(const graphlib::Graph *graph, const DeviceConfig &device_config); +void configure_output_data_formats( + graphlib::Graph *graph, std::optional default_df_override); void run_dataformat_passes( graphlib::Graph *graph, diff --git a/pybuda/csrc/passes/lower_to_mlir.cpp b/pybuda/csrc/passes/lower_to_mlir.cpp index 72bc4171d..8bf6a4266 100644 --- a/pybuda/csrc/passes/lower_to_mlir.cpp +++ b/pybuda/csrc/passes/lower_to_mlir.cpp @@ -52,7 +52,6 @@ class MLIRGenerator /// Construct a new MLIRGenerator object. MLIRGenerator(mlir::MLIRContext &context) : builder_(&context) { - tt::log_info("MLIRGenerator"); init_lowering_handler_map(); } @@ -64,9 +63,6 @@ class MLIRGenerator mlir::tt::SystemDescAttr::getDefault(builder_.getContext())); builder_.setInsertionPointToStart(&graphModule_.getBodyRegion().front()); emit_mlir_function(graph); - mlir::OpPrintingFlags printFlags; - printFlags.enableDebugInfo(); - graphModule_.print(llvm::outs(), printFlags); /// Verify the module after we have finished constructing it, this will check /// the structural properties of the IR and invoke any specific verifiers we @@ -76,6 +72,22 @@ class MLIRGenerator graphModule_.emitError("module verification failed."); throw std::runtime_error("Generated MLIR module failed verification."); } + +#ifdef DEBUG + // Create a string to store the output + std::string moduleStr; + llvm::raw_string_ostream rso(moduleStr); + + // Print the MLIR module + mlir::OpPrintingFlags printFlags; + printFlags.enableDebugInfo(); + graphModule_.print(rso, printFlags); + + rso.flush(); + + log_trace(LogMLIRCompiler, "MLIR module after lowering TT-Forge graph:\n{}", moduleStr); +#endif + return graphModule_; } @@ -119,7 +131,7 @@ class MLIRGenerator } else if constexpr (std::is_same_v) { return builder_.getBoolAttr(arg); } else if constexpr (std::is_same_v) { - return builder_.getI32IntegerAttr(arg); + return builder_.getSI32IntegerAttr(arg); } else if constexpr (std::is_same_v) { return builder_.getF32FloatAttr(arg); } else { @@ -133,11 +145,22 @@ class MLIRGenerator /// A function represents a set of TTForge operations that are executed to produce output results. /// This function will generate the MLIR code for each TTForge operation in the graph and emit the return operation for the function. mlir::func::FuncOp emit_mlir_function(tt::graphlib::Graph *graph) { - // Assemble the function arguments (inputs) - llvm::SmallVector arguments; - for (auto *input : graph->nodes_by_type(tt::graphlib::kInput)) + // Assemble the function arguments (inputs and parameters) + llvm::SmallVector argument_types; + llvm::SmallVector argument_nodes; + + // Add the graph inputs to the argument list + for (auto *input: graph->ordered_module_inputs()) //for (auto *input : graph->nodes_by_type(tt::graphlib::kInput)) { - arguments.push_back(get_node_type(input)); + argument_nodes.push_back(input); + argument_types.push_back(get_node_type(input)); + } + + // Add the graph parameters to the argument list + for(auto *parameter: graph->get_parameter_nodes()) + { + argument_nodes.push_back(parameter); + argument_types.push_back(get_node_type(parameter)); } // Assemble the function return values (outputs) @@ -148,18 +171,18 @@ class MLIRGenerator } // Create the function and emit it in the MLIR module. - auto funcType = builder_.getType(mlir::TypeRange(arguments), mlir::TypeRange(returns)); + auto funcType = builder_.getType(mlir::TypeRange(argument_types), mlir::TypeRange(returns)); auto func = builder_.create(graphModule_.getLoc(), "main", funcType); // Start the body of the function by creating an entry block. mlir::Block *entryBlock = func.addEntryBlock(); // Declare function arguments in the symbol table - for(auto namedValue: llvm::zip(graph->nodes_by_type(tt::graphlib::kInput), entryBlock->getArguments())) + for(auto namedValue: llvm::zip(argument_nodes, entryBlock->getArguments())) { - auto node = std::get<0>(namedValue); - auto arg = std::get<1>(namedValue); - declare(node, arg); + graphlib::Node* argument_node = std::get<0>(namedValue); + mlir::BlockArgument arg = std::get<1>(namedValue); + declare(argument_node, arg); } // Set the insertion point in the builder to the beginning of the function @@ -177,12 +200,12 @@ class MLIRGenerator continue; } - log_trace(LogMLIRGenerator, "Emitting MLIR for node {}", node->name()); + log_trace(LogMLIRCompiler, "Emitting MLIR for node {}", node->name()); tt::graphlib::OpNode *op_node = dynamic_cast(node); // Emit MLIR for the TTForge operation node mlir::Value opValue = emit_mlir_tt_forge_operation(graph, op_node); - log_trace(LogMLIRGenerator, "Generated MLIR for node {} with value {}", + log_trace(LogMLIRCompiler, "Generated MLIR for node {} with value {}", node->name(), covnert_mlir_value_to_string(opValue)); } emit_mlir_return_op(graph); @@ -349,7 +372,7 @@ class MLIRGenerator case tt::DataFormat::Float32: return builder_.getF32Type(); case tt::DataFormat::Float16_b: - return builder_.getF16Type(); + return builder_.getBF16Type(); case tt::DataFormat::Float16: return builder_.getF16Type(); default: diff --git a/pybuda/csrc/passes/mlir_passes.cpp b/pybuda/csrc/passes/mlir_passes.cpp index 851ef3703..5ab32ee30 100644 --- a/pybuda/csrc/passes/mlir_passes.cpp +++ b/pybuda/csrc/passes/mlir_passes.cpp @@ -44,7 +44,7 @@ namespace tt::passes // It's supposed to be called when there's an error during parsing of the pipeline options. // However, I think it's wrongly implemented in the MLIR library, so it doesn't get called. mlir::function_ref err_handler = [](const mlir::Twine &location) { - log_error(LogMLIRGenerator, "Error during parsing pipeline options: {}", location.str()); + log_error(LogMLIRCompiler, "Error during parsing pipeline options: {}", location.str()); return mlir::failure(); }; @@ -63,6 +63,19 @@ namespace tt::passes throw std::runtime_error("Failed to run MLIR compiler pass pipeline."); } - mlir_module.get().dump(); +#ifdef DEBUG + // Create a string to store the output + std::string moduleStr; + llvm::raw_string_ostream rso(moduleStr); + + // Print the MLIR module + mlir::OpPrintingFlags printFlags; + printFlags.enableDebugInfo(); + mlir_module.get()->print(rso, printFlags); + + rso.flush(); + + log_trace(LogMLIRCompiler, "MLIR module after running passes:\n{}", moduleStr); +#endif } } diff --git a/pybuda/csrc/pybuda_bindings.cpp b/pybuda/csrc/pybuda_bindings.cpp index 1e5718e59..a900f3279 100644 --- a/pybuda/csrc/pybuda_bindings.cpp +++ b/pybuda/csrc/pybuda_bindings.cpp @@ -184,7 +184,11 @@ PYBIND11_MODULE(_C, m) { py::arg("op_intermediates_to_save") = std::vector{}, py::arg("use_interactive_placer") = true, py::arg("enable_device_tilize") = false); - m.def("run_pre_lowering_passes", &run_pre_lowering_passes); + m.def( + "run_pre_lowering_passes", + &run_pre_lowering_passes, + py::arg("graph"), + py::arg("default_df_override") = std::optional{}); m.def("run_mlir_compiler", &passes::run_mlir_compiler); m.def( diff --git a/pybuda/pybuda/_C/__init__.pyi b/pybuda/pybuda/_C/__init__.pyi index 20a2ef06a..77398771c 100644 --- a/pybuda/pybuda/_C/__init__.pyi +++ b/pybuda/pybuda/_C/__init__.pyi @@ -188,5 +188,5 @@ def run_optimization_graph_passes(arg0: graph.Graph) -> None: ... def run_post_autograd_graph_passes(arg0: graph.Graph, arg1: object) -> list[tuple[int, int]]: ... def run_post_initial_graph_passes(arg0: graph.Graph, arg1: object, arg2: list[tuple[list[tuple[str, list[int], list[int]]], dict[str, list[int]]]]) -> tuple[list[tuple[int, int]], dict[str, int]]: ... def run_post_optimize_decompose_graph_passes(arg0: graph.Graph, arg1: object) -> list[tuple[int, int]]: ... -def run_pre_lowering_passes(arg0: graph.Graph) -> graph.Graph: ... +def run_pre_lowering_passes(graph: graph.Graph, default_df_override: DataFormat | None = ...) -> graph.Graph: ... def run_pre_placer_buda_passes(graph: graph.Graph, device_config, chip_ids: list[int] = ..., op_names_dont_fuse: list[str] = ..., op_names_manual_fuse: list[str] = ..., fracture_chip_id_assignments: dict[str, int] = ..., default_df_override: DataFormat | None = ..., default_accumulate_df: DataFormat | None = ..., enable_broadcast_splitting: bool = ..., fp32_fallback: DataFormat = ..., default_math_fidelity: MathFidelity = ..., enable_auto_fusing: bool = ..., amp_level: int = ..., enable_recompute: bool = ..., output_queues_on_host: bool = ..., input_queues_on_host: bool = ..., insert_queues: list[tuple[str, str, int]] = ..., amp_properties=..., op_intermediates_to_save: list[str] = ..., use_interactive_placer: bool = ..., enable_device_tilize: bool = ...) -> graph.Graph: ... diff --git a/pybuda/pybuda/compile.py b/pybuda/pybuda/compile.py index 5d806f8b2..499796626 100644 --- a/pybuda/pybuda/compile.py +++ b/pybuda/pybuda/compile.py @@ -779,7 +779,9 @@ def run_pre_lowering_pass(context: CompileContext) -> CompileDepth: graph_name = context.graph_name graph = context.graph - graph = run_pre_lowering_passes(graph) + graph = run_pre_lowering_passes( + graph, + compiler_cfg.default_df_override) dump_graph(graph, graph_name, "pre_lowering") context.final_graph = graph diff --git a/pybuda/pybuda/compiled_graph_state.py b/pybuda/pybuda/compiled_graph_state.py index 18ef7f76c..5690bc50e 100644 --- a/pybuda/pybuda/compiled_graph_state.py +++ b/pybuda/pybuda/compiled_graph_state.py @@ -196,7 +196,8 @@ def from_compiled_graph(module: Module, compile_results: CompileResults) -> "Com constant_to_tensor, consteval_trace, parameter_to_tile_dims, - ordered_parameter_node_names + ordered_parameter_node_names, + False ) return CompiledGraphState( @@ -253,6 +254,9 @@ def get_constant_tensor(self, name): def get_parameter_tensor(self, name): return self.get_tensor(self.post_const_eval_parameters, name) + + def get_ordered_parameter_tensors(self): + return [self.get_parameter_tensor(name) for name in self.ordered_parameter_node_names] def get_ordered_input_names_for_subgraph(self, subgraph_idx): return [name for i, name in enumerate(self.ordered_input_names) if self.ordered_input_subgraph_indices[i] == subgraph_idx] @@ -301,5 +305,6 @@ def __call__(self, *inputs: torch.Tensor) -> List[torch.Tensor]: Output tensors """ logger.info(f"Running model {self.compiled_graph_state.graph_name} on device...") - return run_binary(self.binary, 0, [*inputs]) + inputs_and_parameters = [*inputs, *self.compiled_graph_state.get_ordered_parameter_tensors()] + return run_binary(self.binary, 0, inputs_and_parameters) diff --git a/pybuda/test/mlir/mnist/test_inference.py b/pybuda/test/mlir/mnist/test_inference.py index f63dfdfd0..b224e559a 100644 --- a/pybuda/test/mlir/mnist/test_inference.py +++ b/pybuda/test/mlir/mnist/test_inference.py @@ -2,19 +2,28 @@ # SPDX-License-Identifier: Apache-2.0 +from pybuda._C import DataFormat +from pybuda.config import _get_global_compiler_config import torch from torch import nn from .utils import * +import pybuda + def test_mnist_inference(): - inputs = [torch.rand(1, 784)] + compiler_cfg = _get_global_compiler_config() + df = DataFormat.Float16_b + compiler_cfg.default_df_override = df + compiler_cfg.default_accumulate_df = df + + inputs = [torch.rand(1, 784, dtype=torch.bfloat16)] framework_model = MNISTLinear() fw_out = framework_model(*inputs) - compiled_model = torch.compile(framework_model.to("tt"), backend="tt") + compiled_model = pybuda.compile(framework_model, sample_inputs=inputs) co_out = compiled_model(*[i.to("tt") for i in inputs]) co_out = [co.to("cpu") for co in co_out] diff --git a/pybuda/test/mlir/mnist/utils.py b/pybuda/test/mlir/mnist/utils.py index 724fbf5c2..260840e09 100644 --- a/pybuda/test/mlir/mnist/utils.py +++ b/pybuda/test/mlir/mnist/utils.py @@ -16,16 +16,21 @@ class MNISTLinear(nn.Module): def __init__(self, input_size=784, output_size=10, hidden_size=256): super(MNISTLinear, self).__init__() - self.l1 = nn.Linear(input_size, hidden_size) + self.l1 = nn.Linear(input_size, hidden_size, bias=False, dtype=torch.bfloat16) + self.b1 = nn.Parameter(torch.ones(1, hidden_size, dtype=torch.bfloat16)) self.relu = nn.ReLU() - self.l2 = nn.Linear(hidden_size, output_size) + self.b2 = nn.Parameter(torch.ones(1, output_size, dtype=torch.bfloat16)) + self.l2 = nn.Linear(hidden_size, output_size, bias=False, dtype=torch.bfloat16) def forward(self, x): x = self.l1(x) + x = x + self.b1 x = self.relu(x) x = self.l2(x) + x = x + self.b2 + + return nn.functional.softmax(x, dtype=torch.bfloat16) - return nn.functional.softmax(x) def load_tb_writer(): diff --git a/utils/logger.hpp b/utils/logger.hpp index e3819a345..0fcdc02b0 100644 --- a/utils/logger.hpp +++ b/utils/logger.hpp @@ -65,7 +65,7 @@ constexpr LoggerABI kLoggerABI = LoggerABI::CXX11; X(TMFusion) \ X(TTDevice) \ X(TorchDevice) \ - X(MLIRGenerator) + X(MLIRCompiler) enum LogType : uint32_t {