From 502a51dd5f079efb4ea25abdcec3c98529ffb23c Mon Sep 17 00:00:00 2001
From: Vladimir Jovanovic
Date: Mon, 19 Aug 2024 15:04:18 +0000
Subject: [PATCH] Working on transpose/training tests.

Updated training test.
---
 pybuda/csrc/buda_passes.cpp          | 11 ++++++++---
 pybuda/csrc/passes/lower_to_mlir.cpp |  1 +
 pybuda/pybuda/op/tm.py               |  2 +-
 pybuda/test/mlir/test_ops.py         | 19 +++++++++++++++++++
 pybuda/test/mlir/test_training.py    | 28 ++++++++++++++--------------
 third_party/tt-mlir                  |  2 +-
 6 files changed, 44 insertions(+), 19 deletions(-)

diff --git a/pybuda/csrc/buda_passes.cpp b/pybuda/csrc/buda_passes.cpp
index 1de5270c5..7ec103f72 100644
--- a/pybuda/csrc/buda_passes.cpp
+++ b/pybuda/csrc/buda_passes.cpp
@@ -148,8 +148,11 @@ void run_optimization_graph_passes(graphlib::Graph *graph)
         if (not skip_erase_redundant)
         {
             if (not attempt_update)
                 attempt_update = passes::erase_consecutive_reshape(graph, true);
-            if (not attempt_update)
-                attempt_update = passes::fuse_tm_sequences(graph);
+
+            // TODO: Figure out if this is needed. (Issue #152)
+            // if (not attempt_update)
+            //     attempt_update = passes::fuse_tm_sequences(graph);
+            passes::bypass_nop_tms(graph);
         }
     }
@@ -167,7 +170,9 @@
 
     passes::move_select_after_matmul_optional(graph);
 
-    passes::fuse_tm_sequences(graph);
+    // Issue #152
+    // passes::fuse_tm_sequences(graph);
+
     reportify::dump_graph(graph->name(), "post_erase_inverse_ops", graph);
 }
 
diff --git a/pybuda/csrc/passes/lower_to_mlir.cpp b/pybuda/csrc/passes/lower_to_mlir.cpp
index aea51b2a1..e616d0c36 100644
--- a/pybuda/csrc/passes/lower_to_mlir.cpp
+++ b/pybuda/csrc/passes/lower_to_mlir.cpp
@@ -504,6 +504,7 @@ class MLIRGenerator
         lowering_handler_map["softmax"] = &MLIRGenerator::emit_mlir_ttforge_op;
         lowering_handler_map["reduce_sum"] = &MLIRGenerator::emit_mlir_ttforge_op;
         lowering_handler_map["reduce_avg"] = &MLIRGenerator::emit_mlir_ttforge_op;
+        lowering_handler_map["transpose"] = &MLIRGenerator::emit_mlir_ttforge_op;
     }
 };
 }
diff --git a/pybuda/pybuda/op/tm.py b/pybuda/pybuda/op/tm.py
index 5430fc759..27a04e9c0 100644
--- a/pybuda/pybuda/op/tm.py
+++ b/pybuda/pybuda/op/tm.py
@@ -161,7 +161,7 @@ def Transpose(
     if dim0 > dim1:
         dim0, dim1 = dim1, dim0
 
-    return op("transpose", name, operandA, dim0=dim0, dim1=dim1, z_dim_slice=z_dim_slice).get_tensor(out_df=pytorch_dtype_to_buda_dataformat(out_dtype))
+    return op("transpose", name, operandA, attrs=(dim0, dim1, z_dim_slice), dim0=dim0, dim1=dim1).get_tensor(out_df=pytorch_dtype_to_buda_dataformat(out_dtype))
 
 def Reshape(
     name: str,
diff --git a/pybuda/test/mlir/test_ops.py b/pybuda/test/mlir/test_ops.py
index bde2a3bc8..492d37279 100644
--- a/pybuda/test/mlir/test_ops.py
+++ b/pybuda/test/mlir/test_ops.py
@@ -31,6 +31,25 @@ def forward(self, a, b):
     co_out = [co.to("cpu") for co in co_out]
     assert [compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)]
 
+def test_transpose():
+    class Transpose(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, a, b):
+            c = a + b
+            return torch.transpose(c, 1, 2)
+
+    inputs = [torch.rand(1, 32, 64), torch.rand(1, 32, 64)]
+
+    framework_model = Transpose()
+    fw_out = framework_model(*inputs)
+
+    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)
+    co_out = compiled_model(*inputs)
+
+    co_out = [co.to("cpu") for co in co_out]
+    assert [compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)]
 
 def test_subtract():
     class Subtract(nn.Module):
diff --git a/pybuda/test/mlir/test_training.py b/pybuda/test/mlir/test_training.py
index 81398052b..9633f43b3 100644
--- a/pybuda/test/mlir/test_training.py
+++ b/pybuda/test/mlir/test_training.py
@@ -7,38 +7,39 @@
 import pybuda
 import pybuda.config
+from pybuda.op.eval.common import compare_with_golden_pcc
 
 
 def test_torch_training():
     class MultParam(nn.Module):
         def __init__(self):
             super().__init__()
-            self.p = nn.Parameter(torch.rand(1, 1024))
+            self.p = nn.Parameter(torch.ones(1024, 1024))
+            nn.init.xavier_uniform_(self.p)
 
         def forward(self, x1):
-            return torch.multiply(x1, self.p)
+            return torch.matmul(x1, self.p)
 
     model = MultParam()
 
-    shape = (1, 1024)
+    shape = (8, 1024)
     inputs = torch.rand(shape)
 
-    # Fake targets
     target = torch.zeros(shape)
 
-    loss_fn = torch.nn.L1Loss(reduction='sum')
-    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+    loss_fn = torch.nn.MSELoss()
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
 
     tt_model = pybuda.compile(model, sample_inputs=[torch.rand(shape)], loss=loss_fn, optimizer=optimizer)
 
-    num_epochs = 100
+    num_epochs = 20
     for epoch in range(num_epochs):
-
-        print(f"parameter value: {model.p.data}")
+        #print(f"parameter value: {model.p.data}")
         golden = model(inputs)
         output = tt_model(inputs)
 
-        if not torch.allclose(output[0], golden, rtol=1e-1):
-            raise ValueError("Output does not match the golden output")
+        #print(f"golden = {golden}, output = {output[0]}")
+        output = [co.to("cpu") for co in output]
+        assert compare_with_golden_pcc(golden=golden, calculated=output[0], pcc=0.99)
 
         optimizer.zero_grad()
 
@@ -46,12 +47,12 @@ def forward(self, x1):
         loss.backward()
 
         print(f"epoch: {epoch} loss: {loss}")
-        print(f"output.grad: {output[0].grad}")
+        #print(f"output.grad: {output[0].grad}")
 
         loss_grad = output[0].grad
         assert loss_grad is not None
 
-        print(f"loss grad: {loss_grad}")
+        #print(f"loss grad: {loss_grad}")
         grad = tt_model.backward(loss_grad)
 
         # HACK to run the optimizer step
@@ -60,4 +61,3 @@ def forward(self, x1):
 
         model.p.grad = grad[0]
         optimizer.step()
-
diff --git a/third_party/tt-mlir b/third_party/tt-mlir
index 83c705cb6..d33cd6ab1 160000
--- a/third_party/tt-mlir
+++ b/third_party/tt-mlir
@@ -1 +1 @@
-Subproject commit 83c705cb61729f9b23ad3e1c1839023eea259711
+Subproject commit d33cd6ab1e453334e4055875b7ac3669d41b7c23
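
Note: the updated tests above gate on compare_with_golden_pcc(..., pcc=0.99) from pybuda.op.eval.common instead of torch.allclose. As a rough sketch only (the pcc helper below is hypothetical and not the library's implementation), a PCC check of this kind amounts to a Pearson correlation over the flattened golden and calculated tensors:

import torch

def pcc(golden: torch.Tensor, calculated: torch.Tensor) -> float:
    # Hypothetical stand-in for a PCC-style comparison: flatten both tensors
    # and compute the Pearson correlation coefficient between them.
    g = golden.detach().flatten().double()
    c = calculated.detach().flatten().double()
    return torch.corrcoef(torch.stack([g, c]))[0, 1].item()

# Usage mirroring the tests: accept the compiled output when PCC >= 0.99.
golden = torch.rand(8, 1024)
calculated = golden + 1e-3 * torch.randn_like(golden)
assert pcc(golden, calculated) >= 0.99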