Skip to content

Commit

Permalink
Adding changes to support MNIST e2e inference
Browse files Browse the repository at this point in the history
  • Loading branch information
sdjordjevicTT committed Aug 9, 2024
1 parent 557ed8b commit 7808c2c
Show file tree
Hide file tree
Showing 13 changed files with 110 additions and 33 deletions.
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ endif()

add_compile_options(-Wall -Wextra -Wpedantic -Werror -Wno-unused-parameter)

if (CMAKE_BUILD_TYPE STREQUAL "Debug" OR CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo")
add_compile_options(-DDEBUG)
endif()

set(TTFORGE_CSRC_WARNINGS -Wall -Wextra -Wno-pragmas -Wno-unused-parameter)
set(CFLAGS_NO_WARN -DFMT_HEADER_ONLY)
set(TTFORGE_CSRC_CFLAGS ${CFLAGS_NO_WARN} ${TTFORGE_CSRC_WARNINGS} -DUTILS_LOGGER_PYTHON_OSTREAM_REDIRECT=1)
Expand Down
11 changes: 10 additions & 1 deletion pybuda/csrc/buda_passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "passes/replace_incommutable_patterns.hpp"
#include "passes/set_tile_dim.hpp"
#include "passes/squeeze_to_reshape.hpp"
#include "passes/dataformat.hpp"
#include "python_bindings_common.hpp"
#include "reportify/reportify.hpp"
#include "utils/assert.hpp"
Expand Down Expand Up @@ -193,7 +194,9 @@ std::vector<std::pair<graphlib::NodeId, graphlib::NodeId>> run_post_autograd_gra
}

// ********** Run pre-lowering passes **********
graphlib::Graph* run_pre_lowering_passes(graphlib::Graph *graph)
graphlib::Graph* run_pre_lowering_passes(
graphlib::Graph *graph,
const std::optional<DataFormat> default_df_override)
{
passes::print_graph(graph, "PRE_MLIR");
// Recalculate shapes, and figure out implicit broadcasts that are missing
Expand Down Expand Up @@ -227,6 +230,12 @@ graphlib::Graph* run_pre_lowering_passes(graphlib::Graph *graph)
fold_tile_broadcast_ops_into_inputs(graph);
fold_tile_broadcast_ops_into_reduce(graph);

//
// Data formats
//
// Apply user overrides
passes::configure_output_data_formats(graph, default_df_override);

return graph;
}

Expand Down
5 changes: 3 additions & 2 deletions pybuda/csrc/buda_passes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ std::unique_ptr<graphlib::Graph> run_pre_placer_buda_passes(
bool enable_device_tilize = false);

// Pre-lowering passes, last-minute changes before going to MLIR
graphlib::Graph* run_pre_lowering_passes(graphlib::Graph *graph);

graphlib::Graph* run_pre_lowering_passes(
graphlib::Graph *graph,
const std::optional<DataFormat> default_df_override = {});
}
2 changes: 2 additions & 0 deletions pybuda/csrc/passes/dataformat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ void configure_a_b_format_conversion(
graphlib::Graph *graph, const DeviceConfig &device_config, const std::optional<DataFormat> default_df_override);
void validate_data_formats(const graphlib::Graph *graph, const DeviceConfig& device_config);
void validate_post_placer_data_formats(const graphlib::Graph *graph, const DeviceConfig &device_config);
void configure_output_data_formats(
graphlib::Graph *graph, std::optional<DataFormat> default_df_override);

void run_dataformat_passes(
graphlib::Graph *graph,
Expand Down
57 changes: 40 additions & 17 deletions pybuda/csrc/passes/lower_to_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ class MLIRGenerator
/// Construct a new MLIRGenerator object.
MLIRGenerator(mlir::MLIRContext &context) : builder_(&context)
{
tt::log_info("MLIRGenerator");
init_lowering_handler_map();
}

Expand All @@ -64,9 +63,6 @@ class MLIRGenerator
mlir::tt::SystemDescAttr::getDefault(builder_.getContext()));
builder_.setInsertionPointToStart(&graphModule_.getBodyRegion().front());
emit_mlir_function(graph);
mlir::OpPrintingFlags printFlags;
printFlags.enableDebugInfo();
graphModule_.print(llvm::outs(), printFlags);

/// Verify the module after we have finished constructing it, this will check
/// the structural properties of the IR and invoke any specific verifiers we
Expand All @@ -76,6 +72,22 @@ class MLIRGenerator
graphModule_.emitError("module verification failed.");
throw std::runtime_error("Generated MLIR module failed verification.");
}

#ifdef DEBUG
// Create a string to store the output
std::string moduleStr;
llvm::raw_string_ostream rso(moduleStr);

// Print the MLIR module
mlir::OpPrintingFlags printFlags;
printFlags.enableDebugInfo();
graphModule_.print(rso, printFlags);

rso.flush();

log_trace(LogMLIRCompiler, "MLIR module after lowering TT-Forge graph:\n{}", moduleStr);
#endif

return graphModule_;
}

Expand Down Expand Up @@ -119,7 +131,7 @@ class MLIRGenerator
} else if constexpr (std::is_same_v<T, bool>) {
return builder_.getBoolAttr(arg);
} else if constexpr (std::is_same_v<T, int>) {
return builder_.getI32IntegerAttr(arg);
return builder_.getSI32IntegerAttr(arg);
} else if constexpr (std::is_same_v<T, float>) {
return builder_.getF32FloatAttr(arg);
} else {
Expand All @@ -133,11 +145,22 @@ class MLIRGenerator
/// A function represents a set of TTForge operations that are executed to produce output results.
/// This function will generate the MLIR code for each TTForge operation in the graph and emit the return operation for the function.
mlir::func::FuncOp emit_mlir_function(tt::graphlib::Graph *graph) {
// Assemble the function arguments (inputs)
llvm::SmallVector<mlir::Type> arguments;
for (auto *input : graph->nodes_by_type(tt::graphlib::kInput))
// Assemble the function arguments (inputs and parameters)
llvm::SmallVector<mlir::Type> argument_types;
llvm::SmallVector<graphlib::Node *> argument_nodes;

// Add the graph inputs to the argument list
for (auto *input: graph->ordered_module_inputs()) //for (auto *input : graph->nodes_by_type(tt::graphlib::kInput))
{
arguments.push_back(get_node_type(input));
argument_nodes.push_back(input);
argument_types.push_back(get_node_type(input));
}

// Add the graph parameters to the argument list
for(auto *parameter: graph->get_parameter_nodes())
{
argument_nodes.push_back(parameter);
argument_types.push_back(get_node_type(parameter));
}

// Assemble the function return values (outputs)
Expand All @@ -148,18 +171,18 @@ class MLIRGenerator
}

// Create the function and emit it in the MLIR module.
auto funcType = builder_.getType<mlir::FunctionType>(mlir::TypeRange(arguments), mlir::TypeRange(returns));
auto funcType = builder_.getType<mlir::FunctionType>(mlir::TypeRange(argument_types), mlir::TypeRange(returns));
auto func = builder_.create<mlir::func::FuncOp>(graphModule_.getLoc(), "main", funcType);

// Start the body of the function by creating an entry block.
mlir::Block *entryBlock = func.addEntryBlock();

// Declare function arguments in the symbol table
for(auto namedValue: llvm::zip(graph->nodes_by_type(tt::graphlib::kInput), entryBlock->getArguments()))
for(auto namedValue: llvm::zip(argument_nodes, entryBlock->getArguments()))
{
auto node = std::get<0>(namedValue);
auto arg = std::get<1>(namedValue);
declare(node, arg);
graphlib::Node* argument_node = std::get<0>(namedValue);
mlir::BlockArgument arg = std::get<1>(namedValue);
declare(argument_node, arg);
}

// Set the insertion point in the builder to the beginning of the function
Expand All @@ -177,12 +200,12 @@ class MLIRGenerator
continue;
}

log_trace(LogMLIRGenerator, "Emitting MLIR for node {}", node->name());
log_trace(LogMLIRCompiler, "Emitting MLIR for node {}", node->name());
tt::graphlib::OpNode *op_node = dynamic_cast<tt::graphlib::OpNode*>(node);

// Emit MLIR for the TTForge operation node
mlir::Value opValue = emit_mlir_tt_forge_operation(graph, op_node);
log_trace(LogMLIRGenerator, "Generated MLIR for node {} with value {}",
log_trace(LogMLIRCompiler, "Generated MLIR for node {} with value {}",
node->name(), covnert_mlir_value_to_string(opValue));
}
emit_mlir_return_op(graph);
Expand Down Expand Up @@ -349,7 +372,7 @@ class MLIRGenerator
case tt::DataFormat::Float32:
return builder_.getF32Type();
case tt::DataFormat::Float16_b:
return builder_.getF16Type();
return builder_.getBF16Type();
case tt::DataFormat::Float16:
return builder_.getF16Type();
default:
Expand Down
17 changes: 15 additions & 2 deletions pybuda/csrc/passes/mlir_passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ namespace tt::passes
// It's supposed to be called when there's an error during parsing of the pipeline options.
// However, I think it's wrongly implemented in the MLIR library, so it doesn't get called.
mlir::function_ref<mlir::LogicalResult(const mlir::Twine &)> err_handler = [](const mlir::Twine &location) {
log_error(LogMLIRGenerator, "Error during parsing pipeline options: {}", location.str());
log_error(LogMLIRCompiler, "Error during parsing pipeline options: {}", location.str());
return mlir::failure();
};

Expand All @@ -63,6 +63,19 @@ namespace tt::passes
throw std::runtime_error("Failed to run MLIR compiler pass pipeline.");
}

mlir_module.get().dump();
#ifdef DEBUG
// Create a string to store the output
std::string moduleStr;
llvm::raw_string_ostream rso(moduleStr);

// Print the MLIR module
mlir::OpPrintingFlags printFlags;
printFlags.enableDebugInfo();
mlir_module.get()->print(rso, printFlags);

rso.flush();

log_trace(LogMLIRCompiler, "MLIR module after running passes:\n{}", moduleStr);
#endif
}
}
6 changes: 5 additions & 1 deletion pybuda/csrc/pybuda_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,11 @@ PYBIND11_MODULE(_C, m) {
py::arg("op_intermediates_to_save") = std::vector<std::string>{},
py::arg("use_interactive_placer") = true,
py::arg("enable_device_tilize") = false);
m.def("run_pre_lowering_passes", &run_pre_lowering_passes);
m.def(
"run_pre_lowering_passes",
&run_pre_lowering_passes,
py::arg("graph"),
py::arg("default_df_override") = std::optional<DataFormat>{});
m.def("run_mlir_compiler", &passes::run_mlir_compiler);

m.def(
Expand Down
2 changes: 1 addition & 1 deletion pybuda/pybuda/_C/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -188,5 +188,5 @@ def run_optimization_graph_passes(arg0: graph.Graph) -> None: ...
def run_post_autograd_graph_passes(arg0: graph.Graph, arg1: object) -> list[tuple[int, int]]: ...
def run_post_initial_graph_passes(arg0: graph.Graph, arg1: object, arg2: list[tuple[list[tuple[str, list[int], list[int]]], dict[str, list[int]]]]) -> tuple[list[tuple[int, int]], dict[str, int]]: ...
def run_post_optimize_decompose_graph_passes(arg0: graph.Graph, arg1: object) -> list[tuple[int, int]]: ...
def run_pre_lowering_passes(arg0: graph.Graph) -> graph.Graph: ...
def run_pre_lowering_passes(graph: graph.Graph, default_df_override: DataFormat | None = ...) -> graph.Graph: ...
def run_pre_placer_buda_passes(graph: graph.Graph, device_config, chip_ids: list[int] = ..., op_names_dont_fuse: list[str] = ..., op_names_manual_fuse: list[str] = ..., fracture_chip_id_assignments: dict[str, int] = ..., default_df_override: DataFormat | None = ..., default_accumulate_df: DataFormat | None = ..., enable_broadcast_splitting: bool = ..., fp32_fallback: DataFormat = ..., default_math_fidelity: MathFidelity = ..., enable_auto_fusing: bool = ..., amp_level: int = ..., enable_recompute: bool = ..., output_queues_on_host: bool = ..., input_queues_on_host: bool = ..., insert_queues: list[tuple[str, str, int]] = ..., amp_properties=..., op_intermediates_to_save: list[str] = ..., use_interactive_placer: bool = ..., enable_device_tilize: bool = ...) -> graph.Graph: ...
4 changes: 3 additions & 1 deletion pybuda/pybuda/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,9 @@ def run_pre_lowering_pass(context: CompileContext) -> CompileDepth:
graph_name = context.graph_name
graph = context.graph

graph = run_pre_lowering_passes(graph)
graph = run_pre_lowering_passes(
graph,
compiler_cfg.default_df_override)
dump_graph(graph, graph_name, "pre_lowering")

context.final_graph = graph
Expand Down
9 changes: 7 additions & 2 deletions pybuda/pybuda/compiled_graph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@ def from_compiled_graph(module: Module, compile_results: CompileResults) -> "Com
constant_to_tensor,
consteval_trace,
parameter_to_tile_dims,
ordered_parameter_node_names
ordered_parameter_node_names,
False
)

return CompiledGraphState(
Expand Down Expand Up @@ -253,6 +254,9 @@ def get_constant_tensor(self, name):

def get_parameter_tensor(self, name):
return self.get_tensor(self.post_const_eval_parameters, name)

def get_ordered_parameter_tensors(self):
return [self.get_parameter_tensor(name) for name in self.ordered_parameter_node_names]

def get_ordered_input_names_for_subgraph(self, subgraph_idx):
return [name for i, name in enumerate(self.ordered_input_names) if self.ordered_input_subgraph_indices[i] == subgraph_idx]
Expand Down Expand Up @@ -301,5 +305,6 @@ def __call__(self, *inputs: torch.Tensor) -> List[torch.Tensor]:
Output tensors
"""
logger.info(f"Running model {self.compiled_graph_state.graph_name} on device...")
return run_binary(self.binary, 0, [*inputs])
inputs_and_parameters = [*inputs, *self.compiled_graph_state.get_ordered_parameter_tensors()]
return run_binary(self.binary, 0, inputs_and_parameters)

13 changes: 11 additions & 2 deletions pybuda/test/mlir/mnist/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,28 @@

# SPDX-License-Identifier: Apache-2.0

from pybuda._C import DataFormat
from pybuda.config import _get_global_compiler_config
import torch
from torch import nn

from .utils import *

import pybuda


def test_mnist_inference():
inputs = [torch.rand(1, 784)]
compiler_cfg = _get_global_compiler_config()
df = DataFormat.Float16_b
compiler_cfg.default_df_override = df
compiler_cfg.default_accumulate_df = df

inputs = [torch.rand(1, 784, dtype=torch.bfloat16)]

framework_model = MNISTLinear()
fw_out = framework_model(*inputs)

compiled_model = torch.compile(framework_model.to("tt"), backend="tt")
compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)
co_out = compiled_model(*[i.to("tt") for i in inputs])

co_out = [co.to("cpu") for co in co_out]
Expand Down
11 changes: 8 additions & 3 deletions pybuda/test/mlir/mnist/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,21 @@
class MNISTLinear(nn.Module):
def __init__(self, input_size=784, output_size=10, hidden_size=256):
super(MNISTLinear, self).__init__()
self.l1 = nn.Linear(input_size, hidden_size)
self.l1 = nn.Linear(input_size, hidden_size, bias=False, dtype=torch.bfloat16)
self.b1 = nn.Parameter(torch.ones(1, hidden_size, dtype=torch.bfloat16))
self.relu = nn.ReLU()
self.l2 = nn.Linear(hidden_size, output_size)
self.b2 = nn.Parameter(torch.ones(1, output_size, dtype=torch.bfloat16))
self.l2 = nn.Linear(hidden_size, output_size, bias=False, dtype=torch.bfloat16)

def forward(self, x):
x = self.l1(x)
x = x + self.b1
x = self.relu(x)
x = self.l2(x)
x = x + self.b2

return nn.functional.softmax(x, dtype=torch.bfloat16)

return nn.functional.softmax(x)


def load_tb_writer():
Expand Down
2 changes: 1 addition & 1 deletion utils/logger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ constexpr LoggerABI kLoggerABI = LoggerABI::CXX11;
X(TMFusion) \
X(TTDevice) \
X(TorchDevice) \
X(MLIRGenerator)
X(MLIRCompiler)

enum LogType : uint32_t
{
Expand Down

0 comments on commit 7808c2c

Please sign in to comment.