Adding changes to support MNIST e2e inference #42

Merged (1 commit), Aug 9, 2024
CMakeLists.txt (4 changes: 4 additions & 0 deletions)

@@ -36,6 +36,10 @@ endif()
 
 add_compile_options(-Wall -Wextra -Wpedantic -Werror -Wno-unused-parameter)
 
+if (CMAKE_BUILD_TYPE STREQUAL "Debug" OR CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo")
+    add_compile_options(-DDEBUG)
+endif()
+
 set(TTFORGE_CSRC_WARNINGS -Wall -Wextra -Wno-pragmas -Wno-unused-parameter)
 set(CFLAGS_NO_WARN -DFMT_HEADER_ONLY)
 set(TTFORGE_CSRC_CFLAGS ${CFLAGS_NO_WARN} ${TTFORGE_CSRC_WARNINGS} -DUTILS_LOGGER_PYTHON_OSTREAM_REDIRECT=1)
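The -DDEBUG definition added here is what gates the verbose MLIR dumps introduced later in this PR (lower_to_mlir.cpp and mlir_passes.cpp). A minimal standalone sketch of the pattern, assuming nothing beyond the standard library (dump_if_debug is a hypothetical name, not part of this change):

    #include <iostream>
    #include <string>

    // Verbose output compiles away entirely unless the build defines DEBUG
    // (i.e. Debug or RelWithDebInfo, per the CMake change above).
    void dump_if_debug(const std::string &module_str)
    {
    #ifdef DEBUG
        std::cout << "module dump:\n" << module_str << '\n';
    #else
        (void)module_str;  // avoid an unused-parameter warning in release builds
    #endif
    }

    int main()
    {
        dump_if_debug("func.func @main() { ... }");
        return 0;
    }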
pybuda/csrc/buda_passes.cpp (11 changes: 10 additions & 1 deletion)

@@ -44,6 +44,7 @@
 #include "passes/replace_incommutable_patterns.hpp"
 #include "passes/set_tile_dim.hpp"
 #include "passes/squeeze_to_reshape.hpp"
+#include "passes/dataformat.hpp"
 #include "python_bindings_common.hpp"
 #include "reportify/reportify.hpp"
 #include "utils/assert.hpp"
@@ -193,7 +194,9 @@ std::vector<std::pair<graphlib::NodeId, graphlib::NodeId>> run_post_autograd_graph_passes(
 }
 
 // ********** Run pre-lowering passes **********
-graphlib::Graph* run_pre_lowering_passes(graphlib::Graph *graph)
+graphlib::Graph* run_pre_lowering_passes(
+    graphlib::Graph *graph,
+    const std::optional<DataFormat> default_df_override)
 {
     passes::print_graph(graph, "PRE_MLIR");
     // Recalculate shapes, and figure out implicit broadcasts that are missing
@@ -227,6 +230,12 @@ graphlib::Graph* run_pre_lowering_passes(graphlib::Graph *graph)
     fold_tile_broadcast_ops_into_inputs(graph);
     fold_tile_broadcast_ops_into_reduce(graph);
 
+    //
+    // Data formats
+    //
+    // Apply user overrides
+    passes::configure_output_data_formats(graph, default_df_override);
+
     return graph;
 }
pybuda/csrc/buda_passes.hpp (5 changes: 3 additions & 2 deletions)

@@ -50,6 +50,7 @@ std::unique_ptr<graphlib::Graph> run_pre_placer_buda_passes(
     bool enable_device_tilize = false);
 
 // Pre-lowering passes, last-minute changes before going to MLIR
-graphlib::Graph* run_pre_lowering_passes(graphlib::Graph *graph);
-
+graphlib::Graph* run_pre_lowering_passes(
+    graphlib::Graph *graph,
+    const std::optional<DataFormat> default_df_override = {});
 }
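Since default_df_override defaults to an empty std::optional, existing call sites of run_pre_lowering_passes keep compiling unchanged. A self-contained sketch of that idiom, with stand-in types rather than the real TT-Forge API:

    #include <iostream>
    #include <optional>

    enum class DataFormat { Float32, Float16, Float16_b };

    // Stand-in for run_pre_lowering_passes: the override defaults to empty,
    // so callers written before the parameter existed need no changes.
    void run_passes(const std::optional<DataFormat> default_df_override = {})
    {
        if (default_df_override.has_value())
            std::cout << "applying user data-format override\n";
        else
            std::cout << "keeping existing data formats\n";
    }

    int main()
    {
        run_passes();                       // pre-existing call style still works
        run_passes(DataFormat::Float16_b);  // new: force bfloat16 output formats
    }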
pybuda/csrc/passes/dataformat.hpp (2 changes: 2 additions & 0 deletions)

@@ -39,6 +39,8 @@ void configure_a_b_format_conversion(
     graphlib::Graph *graph, const DeviceConfig &device_config, const std::optional<DataFormat> default_df_override);
 void validate_data_formats(const graphlib::Graph *graph, const DeviceConfig& device_config);
 void validate_post_placer_data_formats(const graphlib::Graph *graph, const DeviceConfig &device_config);
+void configure_output_data_formats(
+    graphlib::Graph *graph, std::optional<DataFormat> default_df_override);
 
 void run_dataformat_passes(
     graphlib::Graph *graph,
pybuda/csrc/passes/lower_to_mlir.cpp (57 changes: 40 additions & 17 deletions)

@@ -52,7 +52,6 @@ class MLIRGenerator
     /// Construct a new MLIRGenerator object.
     MLIRGenerator(mlir::MLIRContext &context) : builder_(&context)
     {
-        tt::log_info("MLIRGenerator");
         init_lowering_handler_map();
     }
 
@@ -64,9 +63,6 @@
             mlir::tt::SystemDescAttr::getDefault(builder_.getContext()));
         builder_.setInsertionPointToStart(&graphModule_.getBodyRegion().front());
         emit_mlir_function(graph);
-        mlir::OpPrintingFlags printFlags;
-        printFlags.enableDebugInfo();
-        graphModule_.print(llvm::outs(), printFlags);
 
         /// Verify the module after we have finished constructing it, this will check
         /// the structural properties of the IR and invoke any specific verifiers we
@@ -76,6 +72,22 @@
             graphModule_.emitError("module verification failed.");
             throw std::runtime_error("Generated MLIR module failed verification.");
         }
+
+#ifdef DEBUG
+        // Create a string to store the output
+        std::string moduleStr;
+        llvm::raw_string_ostream rso(moduleStr);
+
+        // Print the MLIR module
+        mlir::OpPrintingFlags printFlags;
+        printFlags.enableDebugInfo();
+        graphModule_.print(rso, printFlags);
+
+        rso.flush();
+
+        log_trace(LogMLIRCompiler, "MLIR module after lowering TT-Forge graph:\n{}", moduleStr);
+#endif
+
         return graphModule_;
     }
 
@@ -119,7 +131,7 @@
             } else if constexpr (std::is_same_v<T, bool>) {
                 return builder_.getBoolAttr(arg);
             } else if constexpr (std::is_same_v<T, int>) {
-                return builder_.getI32IntegerAttr(arg);
+                return builder_.getSI32IntegerAttr(arg);
             } else if constexpr (std::is_same_v<T, float>) {
                 return builder_.getF32FloatAttr(arg);
             } else {
@@ -133,11 +145,22 @@
     /// A function represents a set of TTForge operations that are executed to produce output results.
     /// This function will generate the MLIR code for each TTForge operation in the graph and emit the return operation for the function.
     mlir::func::FuncOp emit_mlir_function(tt::graphlib::Graph *graph) {
-        // Assemble the function arguments (inputs)
-        llvm::SmallVector<mlir::Type> arguments;
-        for (auto *input : graph->nodes_by_type(tt::graphlib::kInput))
+        // Assemble the function arguments (inputs and parameters)
+        llvm::SmallVector<mlir::Type> argument_types;
+        llvm::SmallVector<graphlib::Node *> argument_nodes;
+
+        // Add the graph inputs to the argument list
+        for (auto *input: graph->ordered_module_inputs()) //for (auto *input : graph->nodes_by_type(tt::graphlib::kInput))
         {
-            arguments.push_back(get_node_type(input));
+            argument_nodes.push_back(input);
+            argument_types.push_back(get_node_type(input));
+        }
+
+        // Add the graph parameters to the argument list
+        for(auto *parameter: graph->get_parameter_nodes())
+        {
+            argument_nodes.push_back(parameter);
+            argument_types.push_back(get_node_type(parameter));
         }
 
         // Assemble the function return values (outputs)
@@ -148,18 +171,18 @@
         }
 
         // Create the function and emit it in the MLIR module.
-        auto funcType = builder_.getType<mlir::FunctionType>(mlir::TypeRange(arguments), mlir::TypeRange(returns));
+        auto funcType = builder_.getType<mlir::FunctionType>(mlir::TypeRange(argument_types), mlir::TypeRange(returns));
         auto func = builder_.create<mlir::func::FuncOp>(graphModule_.getLoc(), "main", funcType);
 
         // Start the body of the function by creating an entry block.
         mlir::Block *entryBlock = func.addEntryBlock();
 
         // Declare function arguments in the symbol table
-        for(auto namedValue: llvm::zip(graph->nodes_by_type(tt::graphlib::kInput), entryBlock->getArguments()))
+        for(auto namedValue: llvm::zip(argument_nodes, entryBlock->getArguments()))
         {
-            auto node = std::get<0>(namedValue);
-            auto arg = std::get<1>(namedValue);
-            declare(node, arg);
+            graphlib::Node* argument_node = std::get<0>(namedValue);
+            mlir::BlockArgument arg = std::get<1>(namedValue);
+            declare(argument_node, arg);
         }
 
         // Set the insertion point in the builder to the beginning of the function
@@ -177,12 +200,12 @@
                 continue;
             }
 
-            log_trace(LogMLIRGenerator, "Emitting MLIR for node {}", node->name());
+            log_trace(LogMLIRCompiler, "Emitting MLIR for node {}", node->name());
             tt::graphlib::OpNode *op_node = dynamic_cast<tt::graphlib::OpNode*>(node);
 
             // Emit MLIR for the TTForge operation node
            mlir::Value opValue = emit_mlir_tt_forge_operation(graph, op_node);
-            log_trace(LogMLIRGenerator, "Generated MLIR for node {} with value {}",
+            log_trace(LogMLIRCompiler, "Generated MLIR for node {} with value {}",
                 node->name(), covnert_mlir_value_to_string(opValue));
         }
         emit_mlir_return_op(graph);
@@ -349,7 +372,7 @@
             case tt::DataFormat::Float32:
                 return builder_.getF32Type();
             case tt::DataFormat::Float16_b:
-                return builder_.getF16Type();
+                return builder_.getBF16Type();
             case tt::DataFormat::Float16:
                 return builder_.getF16Type();
             default:
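Both debug dumps print the module into an llvm::raw_string_ostream so the text can be routed through log_trace instead of stdout. A minimal illustration of that capture pattern, assuming only an LLVM development install (no MLIR context required):

    #include "llvm/Support/raw_ostream.h"
    #include <string>

    int main()
    {
        std::string buffer;
        llvm::raw_string_ostream rso(buffer);

        // Everything streamed into rso lands in `buffer`, not on stdout.
        rso << "module {\n  func.func @main() ...\n}";
        rso.flush();  // harmless even on LLVM versions where this stream is unbuffered

        llvm::outs() << "captured " << buffer.size() << " bytes\n";
        return 0;
    }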
pybuda/csrc/passes/mlir_passes.cpp (17 changes: 15 additions & 2 deletions)

@@ -44,7 +44,7 @@ namespace tt::passes
     // It's supposed to be called when there's an error during parsing of the pipeline options.
    // However, I think it's wrongly implemented in the MLIR library, so it doesn't get called.
     mlir::function_ref<mlir::LogicalResult(const mlir::Twine &)> err_handler = [](const mlir::Twine &location) {
-        log_error(LogMLIRGenerator, "Error during parsing pipeline options: {}", location.str());
+        log_error(LogMLIRCompiler, "Error during parsing pipeline options: {}", location.str());
         return mlir::failure();
     };
 
@@ -63,6 +63,19 @@
             throw std::runtime_error("Failed to run MLIR compiler pass pipeline.");
         }
 
-        mlir_module.get().dump();
+#ifdef DEBUG
+        // Create a string to store the output
+        std::string moduleStr;
+        llvm::raw_string_ostream rso(moduleStr);
+
+        // Print the MLIR module
+        mlir::OpPrintingFlags printFlags;
+        printFlags.enableDebugInfo();
+        mlir_module.get()->print(rso, printFlags);
+
+        rso.flush();
+
+        log_trace(LogMLIRCompiler, "MLIR module after running passes:\n{}", moduleStr);
+#endif
     }
 }
pybuda/csrc/pybuda_bindings.cpp (6 changes: 5 additions & 1 deletion)

@@ -184,7 +184,11 @@ PYBIND11_MODULE(_C, m) {
         py::arg("op_intermediates_to_save") = std::vector<std::string>{},
         py::arg("use_interactive_placer") = true,
         py::arg("enable_device_tilize") = false);
-    m.def("run_pre_lowering_passes", &run_pre_lowering_passes);
+    m.def(
+        "run_pre_lowering_passes",
+        &run_pre_lowering_passes,
+        py::arg("graph"),
+        py::arg("default_df_override") = std::optional<DataFormat>{});
     m.def("run_mlir_compiler", &passes::run_mlir_compiler);
 
     m.def(
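The default value is restated at the binding site because pybind11 cannot introspect C++ default arguments; they must be repeated on py::arg. A self-contained sketch of the same idiom (the module and function here are illustrative only; pybind11/stl.h supplies the std::optional <-> None conversion):

    #include <pybind11/pybind11.h>
    #include <pybind11/stl.h>  // std::optional <-> Python None conversion
    #include <optional>

    namespace py = pybind11;

    int scale(int value, std::optional<int> factor = {})
    {
        // With no override, behave as identity, mirroring how an absent
        // data-format override leaves the graph's formats untouched.
        return value * factor.value_or(1);
    }

    PYBIND11_MODULE(example, m)
    {
        // pybind11 cannot see the C++ default, so restate it on py::arg,
        // exactly as the run_pre_lowering_passes binding does above.
        m.def("scale", &scale, py::arg("value"), py::arg("factor") = std::optional<int>{});
    }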
pybuda/pybuda/_C/__init__.pyi (2 changes: 1 addition & 1 deletion)

@@ -188,5 +188,5 @@
 def run_optimization_graph_passes(arg0: graph.Graph) -> None: ...
 def run_post_autograd_graph_passes(arg0: graph.Graph, arg1: object) -> list[tuple[int, int]]: ...
 def run_post_initial_graph_passes(arg0: graph.Graph, arg1: object, arg2: list[tuple[list[tuple[str, list[int], list[int]]], dict[str, list[int]]]]) -> tuple[list[tuple[int, int]], dict[str, int]]: ...
 def run_post_optimize_decompose_graph_passes(arg0: graph.Graph, arg1: object) -> list[tuple[int, int]]: ...
-def run_pre_lowering_passes(arg0: graph.Graph) -> graph.Graph: ...
+def run_pre_lowering_passes(graph: graph.Graph, default_df_override: DataFormat | None = ...) -> graph.Graph: ...
 def run_pre_placer_buda_passes(graph: graph.Graph, device_config, chip_ids: list[int] = ..., op_names_dont_fuse: list[str] = ..., op_names_manual_fuse: list[str] = ..., fracture_chip_id_assignments: dict[str, int] = ..., default_df_override: DataFormat | None = ..., default_accumulate_df: DataFormat | None = ..., enable_broadcast_splitting: bool = ..., fp32_fallback: DataFormat = ..., default_math_fidelity: MathFidelity = ..., enable_auto_fusing: bool = ..., amp_level: int = ..., enable_recompute: bool = ..., output_queues_on_host: bool = ..., input_queues_on_host: bool = ..., insert_queues: list[tuple[str, str, int]] = ..., amp_properties=..., op_intermediates_to_save: list[str] = ..., use_interactive_placer: bool = ..., enable_device_tilize: bool = ...) -> graph.Graph: ...
pybuda/pybuda/compile.py (4 changes: 3 additions & 1 deletion)

@@ -779,7 +779,9 @@ def run_pre_lowering_pass(context: CompileContext) -> CompileDepth:
     graph_name = context.graph_name
     graph = context.graph
 
-    graph = run_pre_lowering_passes(graph)
+    graph = run_pre_lowering_passes(
+        graph,
+        compiler_cfg.default_df_override)
     dump_graph(graph, graph_name, "pre_lowering")
 
     context.final_graph = graph
pybuda/pybuda/compiled_graph_state.py (9 changes: 7 additions & 2 deletions)

@@ -196,7 +196,8 @@ def from_compiled_graph(module: Module, compile_results: CompileResults) -> "CompiledGraphState":
             constant_to_tensor,
             consteval_trace,
             parameter_to_tile_dims,
-            ordered_parameter_node_names
+            ordered_parameter_node_names,
+            False
         )
 
         return CompiledGraphState(
@@ -253,6 +254,9 @@ def get_constant_tensor(self, name):
     def get_parameter_tensor(self, name):
         return self.get_tensor(self.post_const_eval_parameters, name)
 
+    def get_ordered_parameter_tensors(self):
+        return [self.get_parameter_tensor(name) for name in self.ordered_parameter_node_names]
+
     def get_ordered_input_names_for_subgraph(self, subgraph_idx):
         return [name for i, name in enumerate(self.ordered_input_names) if self.ordered_input_subgraph_indices[i] == subgraph_idx]
 
@@ -301,5 +305,6 @@ def __call__(self, *inputs: torch.Tensor) -> List[torch.Tensor]:
             Output tensors
         """
         logger.info(f"Running model {self.compiled_graph_state.graph_name} on device...")
-        return run_binary(self.binary, 0, [*inputs])
+        inputs_and_parameters = [*inputs, *self.compiled_graph_state.get_ordered_parameter_tensors()]
+        return run_binary(self.binary, 0, inputs_and_parameters)
 
pybuda/test/mlir/mnist/test_inference.py (13 changes: 11 additions & 2 deletions)

@@ -2,19 +2,28 @@
 
 # SPDX-License-Identifier: Apache-2.0
 
+from pybuda._C import DataFormat
+from pybuda.config import _get_global_compiler_config
 import torch
 from torch import nn
 
 from .utils import *
 
+import pybuda
+
 
 def test_mnist_inference():
-    inputs = [torch.rand(1, 784)]
+    compiler_cfg = _get_global_compiler_config()
+    df = DataFormat.Float16_b
+    compiler_cfg.default_df_override = df
+    compiler_cfg.default_accumulate_df = df
+
+    inputs = [torch.rand(1, 784, dtype=torch.bfloat16)]
 
     framework_model = MNISTLinear()
     fw_out = framework_model(*inputs)
 
-    compiled_model = torch.compile(framework_model.to("tt"), backend="tt")
+    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)
     co_out = compiled_model(*[i.to("tt") for i in inputs])
 
     co_out = [co.to("cpu") for co in co_out]
pybuda/test/mlir/mnist/utils.py (11 changes: 8 additions & 3 deletions)

@@ -16,16 +16,21 @@
 class MNISTLinear(nn.Module):
     def __init__(self, input_size=784, output_size=10, hidden_size=256):
         super(MNISTLinear, self).__init__()
-        self.l1 = nn.Linear(input_size, hidden_size)
+        self.l1 = nn.Linear(input_size, hidden_size, bias=False, dtype=torch.bfloat16)
+        self.b1 = nn.Parameter(torch.ones(1, hidden_size, dtype=torch.bfloat16))
         self.relu = nn.ReLU()
-        self.l2 = nn.Linear(hidden_size, output_size)
+        self.b2 = nn.Parameter(torch.ones(1, output_size, dtype=torch.bfloat16))
+        self.l2 = nn.Linear(hidden_size, output_size, bias=False, dtype=torch.bfloat16)
 
     def forward(self, x):
         x = self.l1(x)
+        x = x + self.b1
         x = self.relu(x)
         x = self.l2(x)
+        x = x + self.b2
 
-        return nn.functional.softmax(x)
+        return nn.functional.softmax(x, dtype=torch.bfloat16)
+
 
 
 def load_tb_writer():
utils/logger.hpp (2 changes: 1 addition & 1 deletion)

@@ -65,7 +65,7 @@ constexpr LoggerABI kLoggerABI = LoggerABI::CXX11;
     X(TMFusion) \
     X(TTDevice) \
     X(TorchDevice) \
-    X(MLIRGenerator)
+    X(MLIRCompiler)
 
 enum LogType : uint32_t
 {
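The log type list is an X-macro: each X(name) entry expands into an enum value (and, elsewhere, the matching name string), so renaming MLIRGenerator to MLIRCompiler here updates every derived definition at once. A standalone sketch of the technique; the two-entry list and the Log##name spelling are illustrative, not the real logger:

    #include <cstdint>
    #include <iostream>

    #define LOG_TYPES \
        X(TTDevice)   \
        X(MLIRCompiler)

    enum LogType : uint32_t
    {
    #define X(name) Log##name,
        LOG_TYPES
    #undef X
        LogTypeCount  // the trailing commas above make this the entry count
    };

    int main()
    {
        std::cout << "number of log types: " << LogTypeCount << '\n';
        return 0;
    }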