Skip to content

Commit

Permalink
python typing and comments
Browse files Browse the repository at this point in the history
  • Loading branch information
pilkicTT committed Aug 16, 2024
1 parent 3a66759 commit 89cae61
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 14 deletions.
34 changes: 28 additions & 6 deletions pybuda/pybuda/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from pybuda._C.graph import Graph
from pybuda._C.runtime import Binary
import pybuda.ci as ci
from pybuda.module import PyBudaModule, wrap_module
from pybuda.module import Module, PyBudaModule, wrap_module
from pybuda.parameter import Parameter
from pybuda.pybudaglobal import state_changed, clear_state_changed
import pybuda.query as query
Expand Down Expand Up @@ -95,14 +95,14 @@ def generate_override_config(graph, balancer_solution, placer_solution, nop_inst

@dataclass
class CompileContext:
modules: List[PyBudaModule]
modules: List[Module]
graph_name: str
compiler_cfg: CompilerConfig
verify_cfg: VerifyConfig
microbatch_size: int
microbatch_count: int
inputs: Optional[Tuple[Union[Tensor, List[Any], Dict[str, Any]],...]] = None
loss_module: Optional[PyBudaModule] = None
inputs: Union[torch.Tensor, List[torch.Tensor]]
loss_module: Optional[Module] = None
optimizer: Optional[torch.optim.Optimizer] = None
training: bool = False
graph: Optional[Graph] = None
Expand Down Expand Up @@ -164,13 +164,35 @@ def calculate_grads(

def compile_main(
module: torch.nn.Module | tf.keras.Model | PyBudaModule,
sample_inputs: Optional[Tuple[Union[Tensor, List[Any], Dict[str, Any]],...]] = None,
sample_inputs: List[torch.Tensor],
module_name: Optional[str] = None,
loss: Optional[torch.nn.Module | PyBudaModule] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
):
) -> CompiledModel:
"""
Main entry point for compiling modules from different frameworks for Tenstorrent devices.
Parameters
----------
module: torch.nn.Module | tf.keras.Model | PyBudaModule
Torch, TensorFlow, or PyBuda module to compile
sample_inputs: List[torch.Tensor]
List of sample inputs for the module (used to infer shapes)
module_name: Optional[str]
Name of the module. If not provided, the class name of the provided module will be used.
loss: Optional[torch.nn.Module | PyBudaModule]
Loss module for training.
optimizer: Optional[torch.optim.Optimizer]
Optimizer for training.
Returns
-------
CompiledModel - Callable object that can be used to run the compiled module on device.
"""

assert isinstance(module, torch.nn.Module) or isinstance(module, tf.keras.Model) or isinstance(module, PyBudaModule), "Only PyTorch, TensorFlow, and PyBuda modules are supported."
Expand Down
4 changes: 2 additions & 2 deletions pybuda/pybuda/compiled_graph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def __call__(self, *inputs: torch.Tensor) -> List[torch.Tensor]:
Parameters
----------
inputs: Tuple[Tensor, ...]
inputs: [Tensor, ...]
Input tensors
Returns
Expand All @@ -329,7 +329,7 @@ def __call__(self, *inputs: torch.Tensor) -> List[torch.Tensor]:
return outputs

def forward(self, *inputs: torch.Tensor) -> List[torch.Tensor]:
return self(*inputs)
return self(inputs)

def backward(self, loss_grad: torch.Tensor) -> List[torch.Tensor]:
assert self.compiled_graph_state.graph.training(), "Model not compiled for training."
Expand Down
8 changes: 2 additions & 6 deletions pybuda/test/mlir/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,9 @@

# SPDX-License-Identifier: Apache-2.0

import os
import pytest
import torch
import torch.nn as nn

import tensorflow as tf

import pybuda
import pybuda.config

Expand Down Expand Up @@ -52,9 +48,9 @@ def forward(self, x1):
print(f"epoch: {epoch} loss: {loss}")
print(f"output.grad: {output[0].grad}")

golden_loss = loss_fn(golden, target)

loss_grad = output[0].grad
assert loss_grad is not None

print(f"loss grad: {loss_grad}")
grad = tt_model.backward(loss_grad)

Expand Down

0 comments on commit 89cae61

Please sign in to comment.