Skip to content

Commit

Permalink
[compile] fix compile forge module (#162)
Browse files Browse the repository at this point in the history
Fixes a bug in compile, where non-forge (non-pybuda) modules
where compiled, but the forge modules where ignored.

Also, adding some type aliases and small refactor.

Issue #150
  • Loading branch information
pilkicTT authored Aug 26, 2024
1 parent 8d88845 commit f9b0bc2
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 20 deletions.
5 changes: 4 additions & 1 deletion pybuda/pybuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def set_home_paths():
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

from .module import Module, PyTorchModule, PyBudaModule, TFGraphDefModule, OnnxModule, JaxModule, TFLiteModule
from .compile import pybuda_compile_torch, compile_main as compile
from .torch_compile import compile_torch
from .compiled_graph_state import CompiledGraphState
from .config import CompilerConfig, CompileDepth, set_configuration_options, set_epoch_break, set_chip_break, override_op_size, PerfTraceLevel, insert_buffering_nop, insert_nop, _internal_insert_fj_buffering_nop, override_dram_queue_placement, configure_mixed_precision
Expand All @@ -47,6 +46,9 @@ def set_home_paths():
import pybuda.op as op
import pybuda.transformers

import pybuda.typing
from .compile import pybuda_compile_torch, compile_main as compile

# Torch backend registration
# TODO: move this in a separate file / module.
from torch._dynamo.backends.registry import _BACKENDS
Expand All @@ -58,3 +60,4 @@ def set_home_paths():
if "tt" in _BACKENDS:
del _BACKENDS["tt"]
register_backend(compile_torch, "tt")

55 changes: 36 additions & 19 deletions pybuda/pybuda/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from pybuda.pybudaglobal import state_changed, clear_state_changed
import pybuda.query as query
from pybuda.tensor import Tensor, to_pt_tensors
from pybuda.typing import *
from pybuda.verify import VerifyConfig, do_verify, _generate_random_losses, _run_pytorch_backward


Expand Down Expand Up @@ -163,7 +164,7 @@ def calculate_grads(
return losses

def compile_main(
module: torch.nn.Module | tf.keras.Model | PyBudaModule,
module: AnyModule,
sample_inputs: List[torch.Tensor],
module_name: Optional[str] = None,
loss: Optional[torch.nn.Module | PyBudaModule] = None,
Expand All @@ -174,7 +175,7 @@ def compile_main(
Parameters
----------
module: torch.nn.Module | tf.keras.Model | PyBudaModule
module: AnyModule
Torch, TensorFlow, or PyBuda module to compile
sample_inputs: List[torch.Tensor]
Expand All @@ -195,7 +196,7 @@ def compile_main(
"""

assert isinstance(module, torch.nn.Module) or isinstance(module, tf.keras.Model) or isinstance(module, PyBudaModule), "Only PyTorch, TensorFlow, and PyBuda modules are supported."
assert isinstance(module, AnyModule), "Only PyTorch, TensorFlow, and PyBuda modules are supported."

compiler_cfg = _get_global_compiler_config()
compiler_cfg.apply_env_config_overrides()
Expand Down Expand Up @@ -569,25 +570,14 @@ def generate_initial_graph(context: CompileContext) -> CompileDepth:
modules_ = []
if context.compiler_cfg.compile_tvm_to_python and context.graph is None:
module_inputs = context.inputs
for index, module in enumerate(context.modules):
for module in context.modules:
if not isinstance(module, PyBudaModule):
from .tvm_to_python import generate_pybuda_module
prev_state = state_changed()
if module_inputs is None:
logger.error("No inputs provided for module {}", module.name)
assert False
modules, dev_types, module_inputs = generate_pybuda_module(module, to_pt_tensors(module_inputs), context.compiler_cfg, module.name, context.verify_cfg,)
assert len(modules) == 1, "Attemping to load split model onto single devices"
module, module_inputs = convert_to_forge_module(module, module_inputs, context.compiler_cfg, context.verify_cfg)
assert isinstance(module, PyBudaModule)

modules_.append(modules[0])
if index == 0:
context.inputs = module_inputs
context.inputs = module_inputs

if not(prev_state):
clear_state_changed()

if isinstance(module_inputs, Tensor):
module_inputs = (module_inputs,) # Force a tuple
modules_.append(module)

if context.graph is None:
context.graph, context.outputs, context.intermediate_tensors, context.inputs, _ = generate_graph(modules_, *context.inputs, return_intermediate=context.verify_cfg.intermediates, graph_name=context.graph_name, compiler_cfg=context.compiler_cfg, target_tensors=context.targets)
Expand Down Expand Up @@ -857,6 +847,33 @@ def finish_compile(context: CompileContext) -> CompileDepth:

return CompileDepth.FULL

def convert_to_forge_module(module: AnyModule, module_inputs: Union[AnyTensor, List[AnyTensor]], compiler_cfg: CompilerConfig, verify_cfg: VerifyConfig) -> PyBudaModule:
"""
Converts given module to a Forge module, along with the module_inputs (which will be converted to Forge tensors).
Returns
-------
PyBudaModule, Tuple[Tensor, ...]
"""

from .tvm_to_python import generate_pybuda_module
prev_state = state_changed()

if module_inputs is None:
logger.error("No inputs provided for module {}", module.name)
assert False

forge_module, dev_types, module_inputs = generate_pybuda_module(module, to_pt_tensors(module_inputs), compiler_cfg, module.name, verify_cfg,)
assert len(forge_module) == 1, "Attemping to load split model onto single devices"

if not(prev_state):
clear_state_changed()

if isinstance(module_inputs, Tensor):
module_inputs = (module_inputs,) # Force a tuple

return forge_module[0], module_inputs

def generate_graph(
modules,
*inputs: Tensor,
Expand Down
15 changes: 15 additions & 0 deletions pybuda/pybuda/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

import torch
import tensorflow as tf

from .module import PyBudaModule
from .tensor import Tensor

FrameworkModule = torch.nn.Module | tf.keras.Model
FrameworkTensor = torch.Tensor | tf.Tensor
AnyModule = FrameworkModule | PyBudaModule
AnyTensor = FrameworkTensor | Tensor

27 changes: 27 additions & 0 deletions pybuda/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import pybuda
import pybuda.config
from pybuda.tensor import to_buda_tensors, to_pt_tensors

def test_torch():
class Add(nn.Module):
Expand Down Expand Up @@ -59,3 +60,29 @@ def call(self, x1, x2):
print(f"output: {output}")
if not torch.allclose(output[0], golden, rtol=1e-1):
raise ValueError("Output does not match the golden output")

def test_forge():
class ForgeAdd(pybuda.PyBudaModule):
def __init__(self):
super().__init__("PyBudaTest")

def forward(self, x, y):
return pybuda.op.Add("", x, y)

inputs = to_buda_tensors([torch.rand(1, 32, 32), torch.rand(1, 32, 32)])

model = ForgeAdd()
golden = model(*inputs)

compiled_model = pybuda.compile(model, sample_inputs=inputs)

# Issue #161 : currently, we expect inputs to be torch tensors
inputs = to_pt_tensors(inputs)
output = compiled_model(*inputs)

print(f"golden: {golden}")
print(f"output: {output}")

if not torch.allclose(output[0], golden.to_pytorch(), rtol=1e-1):
raise ValueError("Output does not match the golden output")

0 comments on commit f9b0bc2

Please sign in to comment.