From 89cae61b2c24f28f0007c29462e136c388a7e38f Mon Sep 17 00:00:00 2001
From: Predrag Ilkic
Date: Fri, 16 Aug 2024 11:57:06 +0200
Subject: [PATCH] python typing and comments

---
 pybuda/pybuda/compile.py              | 34 ++++++++++++++++++++++++++++------
 pybuda/pybuda/compiled_graph_state.py |  4 ++--
 pybuda/test/mlir/test_training.py     |  8 ++------
 3 files changed, 32 insertions(+), 14 deletions(-)

diff --git a/pybuda/pybuda/compile.py b/pybuda/pybuda/compile.py
index f814c3ecf..2d18c6232 100644
--- a/pybuda/pybuda/compile.py
+++ b/pybuda/pybuda/compile.py
@@ -32,7 +32,7 @@
 from pybuda._C.graph import Graph
 from pybuda._C.runtime import Binary
 import pybuda.ci as ci
-from pybuda.module import PyBudaModule, wrap_module
+from pybuda.module import Module, PyBudaModule, wrap_module
 from pybuda.parameter import Parameter
 from pybuda.pybudaglobal import state_changed, clear_state_changed
 import pybuda.query as query
@@ -95,14 +95,14 @@ def generate_override_config(graph, balancer_solution, placer_solution, nop_inst
 
 @dataclass
 class CompileContext:
-    modules: List[PyBudaModule]
+    modules: List[Module]
     graph_name: str
     compiler_cfg: CompilerConfig
     verify_cfg: VerifyConfig
     microbatch_size: int
     microbatch_count: int
-    inputs: Optional[Tuple[Union[Tensor, List[Any], Dict[str, Any]],...]] = None
-    loss_module: Optional[PyBudaModule] = None
+    inputs: Union[torch.Tensor, List[torch.Tensor]]
+    loss_module: Optional[Module] = None
     optimizer: Optional[torch.optim.Optimizer] = None
     training: bool = False
     graph: Optional[Graph] = None
@@ -164,13 +164,35 @@ def calculate_grads(
 
 def compile_main(
     module: torch.nn.Module | tf.keras.Model | PyBudaModule,
-    sample_inputs: Optional[Tuple[Union[Tensor, List[Any], Dict[str, Any]],...]] = None,
+    sample_inputs: List[torch.Tensor],
     module_name: Optional[str] = None,
     loss: Optional[torch.nn.Module | PyBudaModule] = None,
     optimizer: Optional[torch.optim.Optimizer] = None,
-):
+) -> CompiledModel:
     """
     Main entry point for compiling modules from different frameworks for Tenstorrent devices.
+
+    Parameters
+    ----------
+    module: torch.nn.Module | tf.keras.Model | PyBudaModule
+        Torch, TensorFlow, or PyBuda module to compile
+
+    sample_inputs: List[torch.Tensor]
+        List of sample inputs for the module (used to infer shapes)
+
+    module_name: Optional[str]
+        Name of the module. If not provided, the class name of the provided module will be used.
+
+    loss: Optional[torch.nn.Module | PyBudaModule]
+        Loss module for training.
+
+    optimizer: Optional[torch.optim.Optimizer]
+        Optimizer for training.
+
+    Returns
+    -------
+    CompiledModel
+        Callable object that can be used to run the compiled module on device.
     """
 
     assert isinstance(module, torch.nn.Module) or isinstance(module, tf.keras.Model) or isinstance(module, PyBudaModule), "Only PyTorch, TensorFlow, and PyBuda modules are supported."
diff --git a/pybuda/pybuda/compiled_graph_state.py b/pybuda/pybuda/compiled_graph_state.py
index 3dbf0d32f..babb2adf0 100644
--- a/pybuda/pybuda/compiled_graph_state.py
+++ b/pybuda/pybuda/compiled_graph_state.py
@@ -307,7 +307,7 @@ def __call__(self, *inputs: torch.Tensor) -> List[torch.Tensor]:
         Parameters
         ----------
-        inputs: [Tensor, ...]
+        inputs: Tuple[Tensor, ...]
             Input tensors
 
         Returns
         -------
@@ -329,7 +329,7 @@ def __call__(self, *inputs: torch.Tensor) -> List[torch.Tensor]:
         return outputs
 
     def forward(self, *inputs: torch.Tensor) -> List[torch.Tensor]:
-        return self(inputs)
+        return self(*inputs)
 
     def backward(self, loss_grad: torch.Tensor) -> List[torch.Tensor]:
         assert self.compiled_graph_state.graph.training(), "Model not compiled for training."
diff --git a/pybuda/test/mlir/test_training.py b/pybuda/test/mlir/test_training.py
index c1781f198..81398052b 100644
--- a/pybuda/test/mlir/test_training.py
+++ b/pybuda/test/mlir/test_training.py
@@ -2,13 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import os
-import pytest
 
 import torch
 import torch.nn as nn
 
-import tensorflow as tf
-
 import pybuda
 import pybuda.config
 
@@ -52,9 +48,9 @@ def forward(self, x1):
         print(f"epoch: {epoch} loss: {loss}")
         print(f"output.grad: {output[0].grad}")
 
-        golden_loss = loss_fn(golden, target)
-
         loss_grad = output[0].grad
+        assert loss_grad is not None
+        print(f"loss grad: {loss_grad}")
 
         grad = tt_model.backward(loss_grad)
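
A note on the forward delegation in compiled_graph_state.py: because __call__ is declared with a variadic *inputs, the tuple it collects has to be re-unpacked when forwarding, i.e. self(*inputs); writing self(inputs) would hand the whole tuple over as a single positional argument. A minimal, pybuda-free illustration of the difference (the Echo class is hypothetical, purely for demonstration):

    class Echo:
        def __call__(self, *inputs):
            return inputs

        def forward(self, *inputs):
            # Correct: re-unpack the tuple that *inputs collected.
            return self(*inputs)

    e = Echo()
    assert e.forward(1, 2) == (1, 2)
    # With `return self(inputs)` the result would instead be ((1, 2),):
    # the tuple itself arrives as one argument.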
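To see the new compile_main surface end to end, here is a minimal training-step sketch assembled from the signature and docstring above and the flow in test_training.py. It is a sketch under assumptions, not a verified example: the TinyLinear module, shapes, and hyperparameters are made up, and it assumes the runtime populates output[0].grad after loss.backward(), as the test relies on. Only compile_main, CompiledModel.__call__, and CompiledModel.backward come from the patch itself.

    import torch
    import torch.nn as nn

    from pybuda.compile import compile_main

    # Hypothetical toy module; any torch.nn.Module should do.
    class TinyLinear(nn.Module):
        def __init__(self):
            super().__init__()
            self.l1 = nn.Linear(32, 32, bias=False)

        def forward(self, x1):
            return self.l1(x1)

    model = TinyLinear()
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    # sample_inputs is now a List[torch.Tensor] used to infer shapes,
    # and compile_main returns a CompiledModel.
    tt_model = compile_main(
        model,
        sample_inputs=[torch.rand(1, 32)],
        loss=loss_fn,
        optimizer=optimizer,
    )

    inputs = torch.rand(1, 32)
    target = torch.rand(1, 32)

    # CompiledModel.__call__ takes unpacked tensors and returns a list
    # of output tensors.
    output = tt_model(inputs)

    loss = loss_fn(output[0], target)
    loss.backward()

    # Seed the device-side backward pass with the output gradient,
    # mirroring test_training.py.
    loss_grad = output[0].grad
    assert loss_grad is not None
    grad = tt_model.backward(loss_grad)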