Removed z_dim_slice as it isn't used.
Improved tests, fixed bug in tests.

Cleaned up training test.
vladimirjovanovicTT committed Aug 23, 2024
1 parent b61a0ad commit e3e033f
Showing 7 changed files with 66 additions and 37 deletions.
11 changes: 8 additions & 3 deletions pybuda/csrc/buda_passes.cpp
@@ -148,8 +148,11 @@ void run_optimization_graph_passes(graphlib::Graph *graph)
if (not skip_erase_redundant) {
if (not attempt_update)
attempt_update = passes::erase_consecutive_reshape(graph, true);
if (not attempt_update)
attempt_update = passes::fuse_tm_sequences(graph);

// TODO: Figure out if this is needed. (Issue #152)
// if (not attempt_update)
// attempt_update = passes::fuse_tm_sequences(graph);

passes::bypass_nop_tms(graph);
}
}
@@ -167,7 +170,9 @@ void run_optimization_graph_passes(graphlib::Graph *graph)

passes::move_select_after_matmul_optional(graph);

passes::fuse_tm_sequences(graph);
// Issue #152
// passes::fuse_tm_sequences(graph);

reportify::dump_graph(graph->name(), "post_erase_inverse_ops", graph);
}

1 change: 1 addition & 0 deletions pybuda/csrc/passes/lower_to_mlir.cpp
@@ -505,6 +505,7 @@ class MLIRGenerator
lowering_handler_map["reduce_sum"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SumOp>;
lowering_handler_map["reduce_avg"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::MeanOp>;
// lowering_handler_map["sqrt"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SqrtOp>;
lowering_handler_map["transpose"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::TransposeOp>;
}
};
}
16 changes: 2 additions & 14 deletions pybuda/pybuda/op/eval/pybuda/transpose.py
@@ -10,11 +10,10 @@

class TransposeTM(PyTM):
@classmethod
def create(cls, dim0, dim1, z_dim_slice=-1):
def create(cls, dim0, dim1):
self = cls("transpose")
self.dim0 = dim0
self.dim1 = dim1
self.z_dim_slice = z_dim_slice
return self

def eval(self, tensors):
@@ -28,19 +27,8 @@ def shape(self, tensor_shapes):

def backward(self, ac, operand, inputs, output, grad):
assert operand == 0, "Invalid operand index"
z_dim_slice = self.z_dim_slice
if (self.dim0 == -3 and self.dim1 == -4) or (
self.dim0 == -4 and self.dim1 == -3
):
z_dim_slice = -1
elif self.dim0 == -3 or self.dim0 == -4:
z_dim_slice = grad.shape[self.dim1]
elif self.dim1 == -3 or self.dim1 == -4:
z_dim_slice = grad.shape[self.dim0]
else:
z_dim_slice = -1
return ac.op(
TransposeTM.create(self.dim0, self.dim1, z_dim_slice=z_dim_slice),
TransposeTM.create(self.dim0, self.dim1),
[grad],
)
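
The simplified backward above relies on a transpose being its own inverse permutation: the gradient of a transpose is just the same transpose applied to the upstream gradient, so no z-dimension bookkeeping is needed. An illustrative PyTorch check of that identity (not part of the commit; the shapes are arbitrary):

import torch

x = torch.rand(2, 3, 4, requires_grad=True)
y = x.transpose(-1, -3)
upstream = torch.rand_like(y)
y.backward(upstream)

# The input gradient is the upstream gradient transposed with the same pair of dims.
assert torch.equal(x.grad, upstream.transpose(-1, -3))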

2 changes: 1 addition & 1 deletion pybuda/pybuda/op/tm.py
@@ -161,7 +161,7 @@ def Transpose(
if dim0 > dim1:
dim0, dim1 = dim1, dim0

return op("transpose", name, operandA, dim0=dim0, dim1=dim1, z_dim_slice=z_dim_slice).get_tensor(out_df=pytorch_dtype_to_buda_dataformat(out_dtype))
return op("transpose", name, operandA, attrs=(dim0, dim1, z_dim_slice), dim0=dim0, dim1=dim1).get_tensor(out_df=pytorch_dtype_to_buda_dataformat(out_dtype))
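
The dim0/dim1 swap above only normalizes argument order; it cannot change the result, because transposing dims (a, b) and (b, a) is the same permutation. A quick illustrative check (not part of the commit):

import torch

x = torch.rand(2, 3, 4)
assert torch.equal(x.transpose(0, 2), x.transpose(2, 0))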

def Reshape(
name: str,
42 changes: 39 additions & 3 deletions pybuda/test/mlir/test_ops.py
@@ -20,18 +20,49 @@ def __init__(self):
def forward(self, a, b):
return a + b

inputs = [torch.rand(1, 32, 32), torch.rand(1, 32, 32)]
inputs = [torch.rand(2, 32, 32), torch.rand(2, 32, 32)]

framework_model = Add()
fw_out = framework_model(*inputs)

compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)
co_out = compiled_model(*inputs)

co_out = [co.to("cpu") for co in co_out]
fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
assert [compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)]


@pytest.mark.parametrize("params", [
((1, 32, 64), (-1, -2)),
((1, 64, 32), (1, 2)),
((1, 32, 64, 128), (3, 2)),
((32, 128), (0, 1)),
((18, 65), (1, 0)),
((6, 33, 34), (-1, 1))
])
def test_transpose(params):
class Transpose(nn.Module):
def __init__(self, dims):
super().__init__()
self.dims = dims

def forward(self, a):
return torch.transpose(a, *self.dims)

input_shape, dims = params
inputs = [torch.rand(input_shape)]

framework_model = Transpose(dims)
fw_out = framework_model(*inputs)

compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)
co_out = compiled_model(*inputs)

co_out = [co.to("cpu") for co in co_out]
fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
assert [compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)]

def test_subtract():
class Subtract(nn.Module):
def __init__(self):
@@ -49,6 +80,7 @@ def forward(self, a, b):
co_out = compiled_model(*inputs)

co_out = [co.to("cpu") for co in co_out]
fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
assert [compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)]


@@ -69,6 +101,7 @@ def forward(self, a, b):
co_out = compiled_model(*inputs)

co_out = [co.to("cpu") for co in co_out]
fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
assert [compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)]


@@ -90,6 +123,7 @@ def forward(self, a):
co_out = compiled_model(*inputs)

co_out = [co.to("cpu") for co in co_out]
fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
assert [compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)]

@pytest.mark.skip(reason="This is not ready yet")
@@ -111,6 +145,7 @@ def forward(self, a):
co_out = compiled_model(*inputs)

co_out = [co.to("cpu") for co in co_out]
fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
assert [compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)]


@@ -132,6 +167,7 @@ def forward(self, a):
co_out = compiled_model(*inputs)

co_out = [co.to("cpu") for co in co_out]
fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
assert [compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)]

@pytest.mark.parametrize("input_shape", [(1,32,32), (1,64,64), (1,128,128,128)], ids=["32","64","128"])
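
Every test above moves the compiled outputs back to the CPU and compares them against the framework outputs with a PCC threshold of 0.99. As a rough illustration, a Pearson-correlation check of this kind computes something like the snippet below; the actual helper is compare_with_golden_pcc from pybuda.op.eval.common, whose implementation may differ:

import torch

def pearson_cc(golden: torch.Tensor, calculated: torch.Tensor) -> float:
    # Correlation between the flattened golden and calculated tensors.
    g = golden.flatten().double()
    c = calculated.flatten().double()
    g = g - g.mean()
    c = c - c.mean()
    return float((g * c).sum() / (g.norm() * c.norm() + 1e-12))

# In the spirit of the assertions above:
# assert pearson_cc(fw_out, co_out) >= 0.99
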
29 changes: 14 additions & 15 deletions pybuda/test/mlir/test_training.py
@@ -7,51 +7,51 @@

import pybuda
import pybuda.config
from pybuda.op.eval.common import compare_with_golden_pcc

def test_torch_training():
class MultParam(nn.Module):
class MatmulParam(nn.Module):
def __init__(self):
super().__init__()
self.p = nn.Parameter(torch.rand(1, 1024))
self.p = nn.Parameter(torch.rand(1024, 1024))
nn.init.xavier_uniform_(self.p)

def forward(self, x1):
return torch.multiply(x1, self.p)
def forward(self, x):
return torch.matmul(x, self.p)

model = MultParam()
model = MatmulParam()
shape = (1, 1024)
inputs = torch.rand(shape)

# Fake targets
target = torch.zeros(shape)

loss_fn = torch.nn.L1Loss(reduction='sum')
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

tt_model = pybuda.compile(model, sample_inputs=[torch.rand(shape)], loss=loss_fn, optimizer=optimizer)

num_epochs = 100
num_epochs = 20

model.train()
for epoch in range(num_epochs):

print(f"parameter value: {model.p.data}")
golden = model(inputs)
output = tt_model(inputs)

if not torch.allclose(output[0], golden, rtol=1e-1):
raise ValueError("Output does not match the golden output")
output = [co.to("cpu") for co in output]
assert compare_with_golden_pcc(golden=golden, calculated=output[0], pcc=0.99)

optimizer.zero_grad()

loss = loss_fn(output[0], target)
loss.backward()

golden_loss = loss_fn(golden, target)
print(f"epoch: {epoch} loss: {loss}")
print(f"epoch: {epoch} golden_loss: {golden_loss}")
print(f"output.grad: {output[0].grad}")

loss_grad = output[0].grad
assert loss_grad is not None

print(f"loss grad: {loss_grad}")
grad = tt_model.backward(loss_grad)

# HACK to run the optimizer step
@@ -60,4 +60,3 @@ def forward(self, x1):
model.p.grad = grad[0]

optimizer.step()
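
The optimizer-step workaround above copies the gradient returned by tt_model.backward onto the torch parameter so that the stock optimizer can apply the update. With plain SGD (no momentum or weight decay) that step amounts to p <- p - lr * grad, as this illustrative snippet shows; the tensors here are stand-ins, not values from the test:

import torch

p = torch.nn.Parameter(torch.rand(4, 4))
opt = torch.optim.SGD([p], lr=0.1)

device_grad = torch.rand(4, 4)  # stand-in for grad[0] returned by the compiled model
before = p.detach().clone()
p.grad = device_grad
opt.step()

# Plain SGD applies p <- p - lr * grad.
assert torch.allclose(p.detach(), before - 0.1 * device_grad)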

2 changes: 1 addition & 1 deletion third_party/tt-mlir
Submodule tt-mlir updated 136 files
