Commit

Adding changes to support MNIST e2e inference

sdjordjevicTT committed Aug 7, 2024
1 parent 7b37a74 commit 1573bd1
Showing 10 changed files with 71 additions and 24 deletions.
11 changes: 10 additions & 1 deletion pybuda/csrc/buda_passes.cpp
@@ -44,6 +44,7 @@
#include "passes/replace_incommutable_patterns.hpp"
#include "passes/set_tile_dim.hpp"
#include "passes/squeeze_to_reshape.hpp"
#include "passes/dataformat.hpp"
#include "python_bindings_common.hpp"
#include "reportify/reportify.hpp"
#include "utils/assert.hpp"
@@ -193,7 +194,9 @@ std::vector<std::pair<graphlib::NodeId, graphlib::NodeId>> run_post_autograd_gra
}

// ********** Run pre-lowering passes **********
graphlib::Graph* run_pre_lowering_passes(graphlib::Graph *graph)
graphlib::Graph* run_pre_lowering_passes(
graphlib::Graph *graph,
const std::optional<DataFormat> default_df_override)
{
passes::print_graph(graph, "PRE_MLIR");
// Recalculate shapes, and figure out implicit broadcasts that are missing
@@ -227,6 +230,12 @@ graphlib::Graph* run_pre_lowering_passes(graphlib::Graph *graph)
fold_tile_broadcast_ops_into_inputs(graph);
fold_tile_broadcast_ops_into_reduce(graph);

//
// Data formats
//
// Apply user overrides
passes::configure_output_data_formats(graph, default_df_override);

return graph;
}

5 changes: 3 additions & 2 deletions pybuda/csrc/buda_passes.hpp
@@ -50,6 +50,7 @@ std::unique_ptr<graphlib::Graph> run_pre_placer_buda_passes(
bool enable_device_tilize = false);

// Pre-lowering passes, last-minute changes before going to MLIR
graphlib::Graph* run_pre_lowering_passes(graphlib::Graph *graph);

graphlib::Graph* run_pre_lowering_passes(
graphlib::Graph *graph,
const std::optional<DataFormat> default_df_override = {});
}
2 changes: 2 additions & 0 deletions pybuda/csrc/passes/dataformat.hpp
@@ -39,6 +39,8 @@ void configure_a_b_format_conversion(
graphlib::Graph *graph, const DeviceConfig &device_config, const std::optional<DataFormat> default_df_override);
void validate_data_formats(const graphlib::Graph *graph, const DeviceConfig& device_config);
void validate_post_placer_data_formats(const graphlib::Graph *graph, const DeviceConfig &device_config);
void configure_output_data_formats(
graphlib::Graph *graph, std::optional<DataFormat> default_df_override);

void run_dataformat_passes(
graphlib::Graph *graph,
32 changes: 21 additions & 11 deletions pybuda/csrc/passes/lower_to_mlir.cpp
@@ -113,11 +113,22 @@ class MLIRGenerator
/// A function represents a set of TTForge operations that are executed to produce output results.
/// This function will generate the MLIR code for each TTForge operation in the graph and emit the return operation for the function.
mlir::func::FuncOp emit_mlir_function(tt::graphlib::Graph *graph) {
// Assemble the function arguments (inputs)
llvm::SmallVector<mlir::Type> arguments;
for (auto *input : graph->nodes_by_type(tt::graphlib::kInput))
// Assemble the function arguments (inputs and parameters)
llvm::SmallVector<mlir::Type> argument_types;
llvm::SmallVector<graphlib::Node *> argument_nodes;

// Add the graph inputs to the argument list
for (auto *input: graph->ordered_module_inputs()) //for (auto *input : graph->nodes_by_type(tt::graphlib::kInput))
{
arguments.push_back(get_node_type(input));
argument_nodes.push_back(input);
argument_types.push_back(get_node_type(input));
}

// Add the graph parameters to the argument list
for(auto *parameter: graph->get_parameter_nodes())
{
argument_nodes.push_back(parameter);
argument_types.push_back(get_node_type(parameter));
}

// Assemble the function return values (outputs)
@@ -128,18 +139,18 @@
}

// Create the function and emit it in the MLIR module.
auto funcType = builder_.getType<mlir::FunctionType>(mlir::TypeRange(arguments), mlir::TypeRange(returns));
auto funcType = builder_.getType<mlir::FunctionType>(mlir::TypeRange(argument_types), mlir::TypeRange(returns));
auto func = builder_.create<mlir::func::FuncOp>(graphModule_.getLoc(), "main", funcType);

// Start the body of the function by creating an entry block.
mlir::Block *entryBlock = func.addEntryBlock();

// Declare function arguments in the symbol table
for(auto namedValue: llvm::zip(graph->nodes_by_type(tt::graphlib::kInput), entryBlock->getArguments()))
for(auto namedValue: llvm::zip(argument_nodes, entryBlock->getArguments()))
{
auto node = std::get<0>(namedValue);
auto arg = std::get<1>(namedValue);
declare(node, arg);
graphlib::Node* argument_node = std::get<0>(namedValue);
mlir::BlockArgument arg = std::get<1>(namedValue);
declare(argument_node, arg);
}

// Set the insertion point in the builder to the beginning of the function
@@ -228,7 +239,6 @@ class MLIRGenerator
// Workaround for now, need to figure out how to handle this properly
if(op_node->op_name() == "softmax")
{
log_info("Softmax");
int32_t dimension = std::get<int>(op_node->op_attrs()[0]);
mlir::NamedAttribute dimension_attribute = builder_.getNamedAttr(
"dimension",
@@ -331,7 +341,7 @@ class MLIRGenerator
case tt::DataFormat::Float32:
return builder_.getF32Type();
case tt::DataFormat::Float16_b:
return builder_.getF16Type();
return builder_.getBF16Type();
case tt::DataFormat::Float16:
return builder_.getF16Type();
default:
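To make the data-format fix above concrete, here is a minimal, hedged illustration (not part of the commit) of the element-type mapping the switch statement now implements: the "_b" suffix in Tenstorrent data formats denotes bfloat16, so Float16_b must lower to bf16 rather than IEEE fp16. The torch dtypes below are assumptions used purely for illustration.

```python
import torch

# Keys mirror the tt::DataFormat names in the switch above; values are the
# element types they are expected to lower to after this change.
TT_FORMAT_TO_ELEMENT_TYPE = {
    "Float32": torch.float32,
    "Float16_b": torch.bfloat16,  # previously mislowered to f16 (getF16Type)
    "Float16": torch.float16,
}
```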
6 changes: 5 additions & 1 deletion pybuda/csrc/pybuda_bindings.cpp
@@ -184,7 +184,11 @@ PYBIND11_MODULE(_C, m) {
py::arg("op_intermediates_to_save") = std::vector<std::string>{},
py::arg("use_interactive_placer") = true,
py::arg("enable_device_tilize") = false);
m.def("run_pre_lowering_passes", &run_pre_lowering_passes);
m.def(
"run_pre_lowering_passes",
&run_pre_lowering_passes,
py::arg("graph"),
py::arg("default_df_override") = std::optional<DataFormat>{});
m.def("run_mlir_compiler", &passes::run_mlir_compiler);

m.def(
2 changes: 1 addition & 1 deletion pybuda/pybuda/_C/__init__.pyi
@@ -188,5 +188,5 @@ def run_optimization_graph_passes(arg0: graph.Graph) -> None: ...
def run_post_autograd_graph_passes(arg0: graph.Graph, arg1: object) -> list[tuple[int, int]]: ...
def run_post_initial_graph_passes(arg0: graph.Graph, arg1: object, arg2: list[tuple[list[tuple[str, list[int], list[int]]], dict[str, list[int]]]]) -> tuple[list[tuple[int, int]], dict[str, int]]: ...
def run_post_optimize_decompose_graph_passes(arg0: graph.Graph, arg1: object) -> list[tuple[int, int]]: ...
def run_pre_lowering_passes(arg0: graph.Graph) -> graph.Graph: ...
def run_pre_lowering_passes(graph: graph.Graph, default_df_override: DataFormat | None = ...) -> graph.Graph: ...
def run_pre_placer_buda_passes(graph: graph.Graph, device_config, chip_ids: list[int] = ..., op_names_dont_fuse: list[str] = ..., op_names_manual_fuse: list[str] = ..., fracture_chip_id_assignments: dict[str, int] = ..., default_df_override: DataFormat | None = ..., default_accumulate_df: DataFormat | None = ..., enable_broadcast_splitting: bool = ..., fp32_fallback: DataFormat = ..., default_math_fidelity: MathFidelity = ..., enable_auto_fusing: bool = ..., amp_level: int = ..., enable_recompute: bool = ..., output_queues_on_host: bool = ..., input_queues_on_host: bool = ..., insert_queues: list[tuple[str, str, int]] = ..., amp_properties=..., op_intermediates_to_save: list[str] = ..., use_interactive_placer: bool = ..., enable_device_tilize: bool = ...) -> graph.Graph: ...
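Based on the updated stub above and the py::arg names declared in pybuda_bindings.cpp, a minimal usage sketch might look as follows; the helper name is illustrative, and the Graph instance is assumed to come from the earlier compile passes.

```python
from typing import Optional

from pybuda._C import DataFormat, run_pre_lowering_passes


def lower_with_df_override(graph, df: Optional[DataFormat] = DataFormat.Float16_b):
    """Run the pre-lowering passes with an explicit default data-format override.

    Passing None (or omitting the argument) keeps the previous behaviour;
    `graph` is a graph.Graph produced by the earlier compile passes.
    """
    return run_pre_lowering_passes(graph, default_df_override=df)
```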
4 changes: 3 additions & 1 deletion pybuda/pybuda/compile.py
@@ -779,7 +779,9 @@ def run_pre_lowering_pass(context: CompileContext) -> CompileDepth:
graph_name = context.graph_name
graph = context.graph

graph = run_pre_lowering_passes(graph)
graph = run_pre_lowering_passes(
graph,
compiler_cfg.default_df_override)
dump_graph(graph, graph_name, "pre_lowering")

context.final_graph = graph
9 changes: 7 additions & 2 deletions pybuda/pybuda/compiled_graph_state.py
@@ -196,7 +196,8 @@ def from_compiled_graph(module: Module, compile_results: CompileResults) -> "Com
constant_to_tensor,
consteval_trace,
parameter_to_tile_dims,
ordered_parameter_node_names
ordered_parameter_node_names,
False
)

return CompiledGraphState(
@@ -253,6 +254,9 @@ def get_constant_tensor(self, name):

def get_parameter_tensor(self, name):
return self.get_tensor(self.post_const_eval_parameters, name)

def get_ordered_parameter_tensors(self):
return [self.get_parameter_tensor(name) for name in self.ordered_parameter_node_names]

def get_ordered_input_names_for_subgraph(self, subgraph_idx):
return [name for i, name in enumerate(self.ordered_input_names) if self.ordered_input_subgraph_indices[i] == subgraph_idx]
@@ -301,5 +305,6 @@ def __call__(self, *inputs: torch.Tensor) -> List[torch.Tensor]:
Output tensors
"""
logger.info(f"Running model {self.compiled_graph_state.graph_name} on device...")
return run_binary(self.binary, 0, [*inputs])
inputs_and_parameters = [*inputs, *self.compiled_graph_state.get_ordered_parameter_tensors()]
return run_binary(self.binary, 0, inputs_and_parameters)
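The change above establishes the runtime calling convention that matches the MLIR function emitted in lower_to_mlir.cpp: graph inputs first, then parameter tensors in graph order. A small sketch of that convention, using a hypothetical helper and made-up shapes:

```python
from typing import List

import torch


def assemble_runtime_args(inputs: List[torch.Tensor],
                          ordered_parameters: List[torch.Tensor]) -> List[torch.Tensor]:
    # Mirrors CompiledModel.__call__: activations first, followed by the
    # ordered parameter tensors appended as trailing arguments.
    return [*inputs, *ordered_parameters]


args = assemble_runtime_args(
    [torch.rand(1, 784, dtype=torch.bfloat16)],    # activations
    [torch.ones(256, 784, dtype=torch.bfloat16),   # l1 weight (illustrative)
     torch.ones(10, 256, dtype=torch.bfloat16)],   # l2 weight (illustrative)
)
assert len(args) == 3
```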

13 changes: 11 additions & 2 deletions pybuda/test/mlir/mnist/test_inference.py
@@ -2,19 +2,28 @@

# SPDX-License-Identifier: Apache-2.0

from pybuda._C import DataFormat
from pybuda.config import _get_global_compiler_config
import torch
from torch import nn

from .utils import *

import pybuda


def test_mnist_inference():
inputs = [torch.rand(1, 784)]
compiler_cfg = _get_global_compiler_config()
df = DataFormat.Float16_b
compiler_cfg.default_df_override = df
compiler_cfg.default_accumulate_df = df

inputs = [torch.rand(1, 784, dtype=torch.bfloat16)]

framework_model = MNISTLinear()
fw_out = framework_model(*inputs)

compiled_model = torch.compile(framework_model.to("tt"), backend="tt")
compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)
co_out = compiled_model(*[i.to("tt") for i in inputs])

co_out = [co.to("cpu") for co in co_out]
11 changes: 8 additions & 3 deletions pybuda/test/mlir/mnist/utils.py
@@ -16,16 +16,21 @@
class MNISTLinear(nn.Module):
def __init__(self, input_size=784, output_size=10, hidden_size=256):
super(MNISTLinear, self).__init__()
self.l1 = nn.Linear(input_size, hidden_size)
self.l1 = nn.Linear(input_size, hidden_size, bias=False, dtype=torch.bfloat16)
self.b1 = nn.Parameter(torch.ones(1, hidden_size, dtype=torch.bfloat16))
self.relu = nn.ReLU()
self.l2 = nn.Linear(hidden_size, output_size)
self.b2 = nn.Parameter(torch.ones(1, output_size, dtype=torch.bfloat16))
self.l2 = nn.Linear(hidden_size, output_size, bias=False, dtype=torch.bfloat16)

def forward(self, x):
x = self.l1(x)
x = x + self.b1
x = self.relu(x)
x = self.l2(x)
x = x + self.b2

return nn.functional.softmax(x, dtype=torch.bfloat16)

return nn.functional.softmax(x)


def load_tb_writer():
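As a sanity sketch (not part of the commit), splitting the bias out of nn.Linear into an explicit add, as MNISTLinear now does, is numerically equivalent when the weight and bias values match; allclose is used rather than exact equality because bfloat16 rounding can differ slightly between the fused and split formulations.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.rand(1, 784, dtype=torch.bfloat16)

# Reference: a single fused Linear layer with a built-in bias.
fused = nn.Linear(784, 256, bias=True, dtype=torch.bfloat16)

# Split form, mirroring MNISTLinear: bias-free Linear plus an explicit add.
split = nn.Linear(784, 256, bias=False, dtype=torch.bfloat16)
split.weight = fused.weight                                   # reuse the same weights
b = nn.Parameter(fused.bias.detach().clone().reshape(1, 256))

assert torch.allclose(fused(x), split(x) + b, rtol=1e-2, atol=1e-2)
```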
