Skip to content

Commit

Permalink
Bringup E2E conv2d compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
LPanosTT committed Aug 28, 2024
1 parent 64c8dd6 commit 8352ec8
Show file tree
Hide file tree
Showing 8 changed files with 1,123 additions and 822 deletions.
2 changes: 2 additions & 0 deletions pybuda/csrc/passes/lower_to_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,8 @@ class MLIRGenerator
lowering_handler_map["reduce_sum"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SumOp>;
lowering_handler_map["reduce_avg"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::MeanOp>;
// lowering_handler_map["sqrt"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SqrtOp>;
lowering_handler_map["transpose"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::TransposeOp>;
lowering_handler_map["conv2d"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::Conv2dOp>;
}
};
}
Expand Down
4 changes: 2 additions & 2 deletions pybuda/pybuda/compiled_graph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pybuda._C.graph import Graph, RuntimeTensorTransform
from pybuda._C.runtime import run_binary, Binary
from pybuda.utils import list_as_json
from pybuda.tensor import Tensor, get_post_const_eval_tensors
from pybuda.tensor import Tensor, get_post_const_eval_tensors, to_pt_tensors
from pybuda.module import Module


Expand Down Expand Up @@ -319,7 +319,7 @@ def __call__(self, *inputs: torch.Tensor) -> List[torch.Tensor]:

logger.info(f"Running model {self.compiled_graph_state.graph.get_name()} on device...")
inputs_and_parameters = [*inputs, *self.compiled_graph_state.get_ordered_parameter_tensors()]
outputs = run_binary(self.compiled_binary, int(ProgramId.FORWARD), inputs_and_parameters)
outputs = run_binary(self.compiled_binary, int(ProgramId.FORWARD), to_pt_tensors(inputs_and_parameters))

if self.compiled_graph_state.graph.training():
# For executing loss and its backward graph on CPU, we need to tell torch to compute gradients
Expand Down
8 changes: 6 additions & 2 deletions pybuda/pybuda/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,13 @@ def get_parameters(self) -> List[Parameter]:
for param in self.module.trainable_variables:
name = param.name
data = param.numpy()

if data.dtype.name == "bfloat16":
data = data.astype(np.float32)
data = torch.tensor(data).to(torch.bfloat16)
else:
data = torch.Tensor(data)
pybuda_param = Parameter(
torch.Tensor(data),
data,
requires_grad = True,
name=name)
params.append(pybuda_param)
Expand Down
15 changes: 13 additions & 2 deletions pybuda/pybuda/op/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def Conv2d(
bias: Optional[Union[Tensor, Parameter]] = None,
stride: int = 1,
padding: Union[int, str, List] = "same",
dilation: int = 1,
dilation: Union[int, List] = 1,
groups: int = 1,
channel_last: bool = False,
) -> Tensor:
Expand Down Expand Up @@ -45,6 +45,8 @@ def Conv2d(
"""
if isinstance(stride, int):
stride = [stride] * 2
if isinstance(dilation, int):
dilation = [dilation]*2

padding = conv2d_padding_to_canonical(padding, (weights.shape[2], weights.shape[3]))

Expand Down Expand Up @@ -74,7 +76,16 @@ def Conv2d(
"conv2d",
name,
*inputs,
attrs=attrs,
stride_height=stride[0],
stride_width=stride[1],
dilation_height=dilation[0],
dilation_width=dilation[1],
groups=groups,
padding_left=padding[0],
padding_right=padding[1],
padding_top=padding[2],
padding_bottom=padding[3],
channel_last=channel_last,
).get_tensor()


Expand Down
3 changes: 2 additions & 1 deletion pybuda/pybuda/op/eval/pybuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .clip import Clip
from .cumulativesum import CumulativeSum
from .argmax import Argmax
from .convolution import Conv2d

op_to_module_map = {
"add": "eltwise_binary",
Expand Down Expand Up @@ -111,7 +112,7 @@
"reduce_max": "reduce",
"grouped_reduce_avg": "reduce",

"conv2d" : "convolution",
"conv2d" : Conv2d,
"conv2d_transpose" : "convolution",
"conv3d" : "convolution",

Expand Down
Loading

0 comments on commit 8352ec8

Please sign in to comment.