diff --git a/pybuda/csrc/buda_passes.cpp b/pybuda/csrc/buda_passes.cpp index 7d71cf614..1de5270c5 100644 --- a/pybuda/csrc/buda_passes.cpp +++ b/pybuda/csrc/buda_passes.cpp @@ -44,6 +44,7 @@ #include "passes/replace_incommutable_patterns.hpp" #include "passes/set_tile_dim.hpp" #include "passes/squeeze_to_reshape.hpp" +#include "passes/dataformat.hpp" #include "python_bindings_common.hpp" #include "reportify/reportify.hpp" #include "utils/assert.hpp" @@ -193,7 +194,9 @@ std::vector> run_post_autograd_gra } // ********** Run pre-lowering passes ********** -graphlib::Graph* run_pre_lowering_passes(graphlib::Graph *graph) +graphlib::Graph* run_pre_lowering_passes( + graphlib::Graph *graph, + const std::optional default_df_override) { passes::print_graph(graph, "PRE_MLIR"); // Recalculate shapes, and figure out implicit broadcasts that are missing @@ -227,6 +230,12 @@ graphlib::Graph* run_pre_lowering_passes(graphlib::Graph *graph) fold_tile_broadcast_ops_into_inputs(graph); fold_tile_broadcast_ops_into_reduce(graph); + // + // Data formats + // + // Apply user overrides + passes::configure_output_data_formats(graph, default_df_override); + return graph; } diff --git a/pybuda/csrc/buda_passes.hpp b/pybuda/csrc/buda_passes.hpp index 819769d75..314de1f82 100644 --- a/pybuda/csrc/buda_passes.hpp +++ b/pybuda/csrc/buda_passes.hpp @@ -50,6 +50,7 @@ std::unique_ptr run_pre_placer_buda_passes( bool enable_device_tilize = false); // Pre-lowering passes, last-minute changes before going to MLIR -graphlib::Graph* run_pre_lowering_passes(graphlib::Graph *graph); - +graphlib::Graph* run_pre_lowering_passes( + graphlib::Graph *graph, + const std::optional default_df_override = {}); } diff --git a/pybuda/csrc/passes/dataformat.hpp b/pybuda/csrc/passes/dataformat.hpp index 606d31390..6cc1e728f 100644 --- a/pybuda/csrc/passes/dataformat.hpp +++ b/pybuda/csrc/passes/dataformat.hpp @@ -39,6 +39,8 @@ void configure_a_b_format_conversion( graphlib::Graph *graph, const DeviceConfig &device_config, const std::optional default_df_override); void validate_data_formats(const graphlib::Graph *graph, const DeviceConfig& device_config); void validate_post_placer_data_formats(const graphlib::Graph *graph, const DeviceConfig &device_config); +void configure_output_data_formats( + graphlib::Graph *graph, std::optional default_df_override); void run_dataformat_passes( graphlib::Graph *graph, diff --git a/pybuda/csrc/passes/lower_to_mlir.cpp b/pybuda/csrc/passes/lower_to_mlir.cpp index 10bb809ec..91a6518be 100644 --- a/pybuda/csrc/passes/lower_to_mlir.cpp +++ b/pybuda/csrc/passes/lower_to_mlir.cpp @@ -113,11 +113,22 @@ class MLIRGenerator /// A function represents a set of TTForge operations that are executed to produce output results. /// This function will generate the MLIR code for each TTForge operation in the graph and emit the return operation for the function. mlir::func::FuncOp emit_mlir_function(tt::graphlib::Graph *graph) { - // Assemble the function arguments (inputs) - llvm::SmallVector arguments; - for (auto *input : graph->nodes_by_type(tt::graphlib::kInput)) + // Assemble the function arguments (inputs and parameters) + llvm::SmallVector argument_types; + llvm::SmallVector argument_nodes; + + // Add the graph inputs to the argument list + for (auto *input: graph->ordered_module_inputs()) //for (auto *input : graph->nodes_by_type(tt::graphlib::kInput)) { - arguments.push_back(get_node_type(input)); + argument_nodes.push_back(input); + argument_types.push_back(get_node_type(input)); + } + + // Add the graph parameters to the argument list + for(auto *parameter: graph->get_parameter_nodes()) + { + argument_nodes.push_back(parameter); + argument_types.push_back(get_node_type(parameter)); } // Assemble the function return values (outputs) @@ -128,18 +139,18 @@ class MLIRGenerator } // Create the function and emit it in the MLIR module. - auto funcType = builder_.getType(mlir::TypeRange(arguments), mlir::TypeRange(returns)); + auto funcType = builder_.getType(mlir::TypeRange(argument_types), mlir::TypeRange(returns)); auto func = builder_.create(graphModule_.getLoc(), "main", funcType); // Start the body of the function by creating an entry block. mlir::Block *entryBlock = func.addEntryBlock(); // Declare function arguments in the symbol table - for(auto namedValue: llvm::zip(graph->nodes_by_type(tt::graphlib::kInput), entryBlock->getArguments())) + for(auto namedValue: llvm::zip(argument_nodes, entryBlock->getArguments())) { - auto node = std::get<0>(namedValue); - auto arg = std::get<1>(namedValue); - declare(node, arg); + graphlib::Node* argument_node = std::get<0>(namedValue); + mlir::BlockArgument arg = std::get<1>(namedValue); + declare(argument_node, arg); } // Set the insertion point in the builder to the beginning of the function @@ -228,7 +239,6 @@ class MLIRGenerator // Workaround for now, need to figure out how to handle this properly if(op_node->op_name() == "softmax") { - log_info("Softmax"); int32_t dimension = std::get(op_node->op_attrs()[0]); mlir::NamedAttribute dimension_attribute = builder_.getNamedAttr( "dimension", @@ -331,7 +341,7 @@ class MLIRGenerator case tt::DataFormat::Float32: return builder_.getF32Type(); case tt::DataFormat::Float16_b: - return builder_.getF16Type(); + return builder_.getBF16Type(); case tt::DataFormat::Float16: return builder_.getF16Type(); default: diff --git a/pybuda/csrc/pybuda_bindings.cpp b/pybuda/csrc/pybuda_bindings.cpp index 1e5718e59..a900f3279 100644 --- a/pybuda/csrc/pybuda_bindings.cpp +++ b/pybuda/csrc/pybuda_bindings.cpp @@ -184,7 +184,11 @@ PYBIND11_MODULE(_C, m) { py::arg("op_intermediates_to_save") = std::vector{}, py::arg("use_interactive_placer") = true, py::arg("enable_device_tilize") = false); - m.def("run_pre_lowering_passes", &run_pre_lowering_passes); + m.def( + "run_pre_lowering_passes", + &run_pre_lowering_passes, + py::arg("graph"), + py::arg("default_df_override") = std::optional{}); m.def("run_mlir_compiler", &passes::run_mlir_compiler); m.def( diff --git a/pybuda/pybuda/_C/__init__.pyi b/pybuda/pybuda/_C/__init__.pyi index 20a2ef06a..77398771c 100644 --- a/pybuda/pybuda/_C/__init__.pyi +++ b/pybuda/pybuda/_C/__init__.pyi @@ -188,5 +188,5 @@ def run_optimization_graph_passes(arg0: graph.Graph) -> None: ... def run_post_autograd_graph_passes(arg0: graph.Graph, arg1: object) -> list[tuple[int, int]]: ... def run_post_initial_graph_passes(arg0: graph.Graph, arg1: object, arg2: list[tuple[list[tuple[str, list[int], list[int]]], dict[str, list[int]]]]) -> tuple[list[tuple[int, int]], dict[str, int]]: ... def run_post_optimize_decompose_graph_passes(arg0: graph.Graph, arg1: object) -> list[tuple[int, int]]: ... -def run_pre_lowering_passes(arg0: graph.Graph) -> graph.Graph: ... +def run_pre_lowering_passes(graph: graph.Graph, default_df_override: DataFormat | None = ...) -> graph.Graph: ... def run_pre_placer_buda_passes(graph: graph.Graph, device_config, chip_ids: list[int] = ..., op_names_dont_fuse: list[str] = ..., op_names_manual_fuse: list[str] = ..., fracture_chip_id_assignments: dict[str, int] = ..., default_df_override: DataFormat | None = ..., default_accumulate_df: DataFormat | None = ..., enable_broadcast_splitting: bool = ..., fp32_fallback: DataFormat = ..., default_math_fidelity: MathFidelity = ..., enable_auto_fusing: bool = ..., amp_level: int = ..., enable_recompute: bool = ..., output_queues_on_host: bool = ..., input_queues_on_host: bool = ..., insert_queues: list[tuple[str, str, int]] = ..., amp_properties=..., op_intermediates_to_save: list[str] = ..., use_interactive_placer: bool = ..., enable_device_tilize: bool = ...) -> graph.Graph: ... diff --git a/pybuda/pybuda/compile.py b/pybuda/pybuda/compile.py index 5d806f8b2..499796626 100644 --- a/pybuda/pybuda/compile.py +++ b/pybuda/pybuda/compile.py @@ -779,7 +779,9 @@ def run_pre_lowering_pass(context: CompileContext) -> CompileDepth: graph_name = context.graph_name graph = context.graph - graph = run_pre_lowering_passes(graph) + graph = run_pre_lowering_passes( + graph, + compiler_cfg.default_df_override) dump_graph(graph, graph_name, "pre_lowering") context.final_graph = graph diff --git a/pybuda/pybuda/compiled_graph_state.py b/pybuda/pybuda/compiled_graph_state.py index 18ef7f76c..5690bc50e 100644 --- a/pybuda/pybuda/compiled_graph_state.py +++ b/pybuda/pybuda/compiled_graph_state.py @@ -196,7 +196,8 @@ def from_compiled_graph(module: Module, compile_results: CompileResults) -> "Com constant_to_tensor, consteval_trace, parameter_to_tile_dims, - ordered_parameter_node_names + ordered_parameter_node_names, + False ) return CompiledGraphState( @@ -253,6 +254,9 @@ def get_constant_tensor(self, name): def get_parameter_tensor(self, name): return self.get_tensor(self.post_const_eval_parameters, name) + + def get_ordered_parameter_tensors(self): + return [self.get_parameter_tensor(name) for name in self.ordered_parameter_node_names] def get_ordered_input_names_for_subgraph(self, subgraph_idx): return [name for i, name in enumerate(self.ordered_input_names) if self.ordered_input_subgraph_indices[i] == subgraph_idx] @@ -301,5 +305,6 @@ def __call__(self, *inputs: torch.Tensor) -> List[torch.Tensor]: Output tensors """ logger.info(f"Running model {self.compiled_graph_state.graph_name} on device...") - return run_binary(self.binary, 0, [*inputs]) + inputs_and_parameters = [*inputs, *self.compiled_graph_state.get_ordered_parameter_tensors()] + return run_binary(self.binary, 0, inputs_and_parameters) diff --git a/pybuda/test/mlir/mnist/test_inference.py b/pybuda/test/mlir/mnist/test_inference.py index f63dfdfd0..b224e559a 100644 --- a/pybuda/test/mlir/mnist/test_inference.py +++ b/pybuda/test/mlir/mnist/test_inference.py @@ -2,19 +2,28 @@ # SPDX-License-Identifier: Apache-2.0 +from pybuda._C import DataFormat +from pybuda.config import _get_global_compiler_config import torch from torch import nn from .utils import * +import pybuda + def test_mnist_inference(): - inputs = [torch.rand(1, 784)] + compiler_cfg = _get_global_compiler_config() + df = DataFormat.Float16_b + compiler_cfg.default_df_override = df + compiler_cfg.default_accumulate_df = df + + inputs = [torch.rand(1, 784, dtype=torch.bfloat16)] framework_model = MNISTLinear() fw_out = framework_model(*inputs) - compiled_model = torch.compile(framework_model.to("tt"), backend="tt") + compiled_model = pybuda.compile(framework_model, sample_inputs=inputs) co_out = compiled_model(*[i.to("tt") for i in inputs]) co_out = [co.to("cpu") for co in co_out] diff --git a/pybuda/test/mlir/mnist/utils.py b/pybuda/test/mlir/mnist/utils.py index 724fbf5c2..260840e09 100644 --- a/pybuda/test/mlir/mnist/utils.py +++ b/pybuda/test/mlir/mnist/utils.py @@ -16,16 +16,21 @@ class MNISTLinear(nn.Module): def __init__(self, input_size=784, output_size=10, hidden_size=256): super(MNISTLinear, self).__init__() - self.l1 = nn.Linear(input_size, hidden_size) + self.l1 = nn.Linear(input_size, hidden_size, bias=False, dtype=torch.bfloat16) + self.b1 = nn.Parameter(torch.ones(1, hidden_size, dtype=torch.bfloat16)) self.relu = nn.ReLU() - self.l2 = nn.Linear(hidden_size, output_size) + self.b2 = nn.Parameter(torch.ones(1, output_size, dtype=torch.bfloat16)) + self.l2 = nn.Linear(hidden_size, output_size, bias=False, dtype=torch.bfloat16) def forward(self, x): x = self.l1(x) + x = x + self.b1 x = self.relu(x) x = self.l2(x) + x = x + self.b2 + + return nn.functional.softmax(x, dtype=torch.bfloat16) - return nn.functional.softmax(x) def load_tb_writer():