Skip to content

Commit

Permalink
Patched docs for torch_compile_tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
ignaciobartol committed Jun 16, 2024
1 parent 63f987d commit 204f9fc
Showing 1 changed file with 87 additions and 2 deletions.
89 changes: 87 additions & 2 deletions intermediate_source/torch_compile_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,35 @@ def foo(x, y):

######################################################################
# Alternatively, we can decorate the function.
t1 = torch.randn(10, 10)
t2 = torch.randn(10, 10)

@torch.compile
def opt_foo2(x, y):
a = torch.sin(x)
b = torch.cos(y)
return a + b
print(opt_foo2(torch.randn(10, 10), torch.randn(10, 10)))
print(opt_foo2(t1, t2))

# When using the decorator approach, nested function calls within the decorated
# function will also be compiled.

def nested_function(x):
return torch.sin(x)

@torch.compile
def outer_function(x, y):
a = nested_function(x)
b = torch.cos(y)
return a + b

print(outer_function(t1, t2))

######################################################################
# We can also optimize ``torch.nn.Module`` instances.

t = torch.randn(10, 100)

class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -94,7 +112,74 @@ def forward(self, x):

mod = MyModule()
opt_mod = torch.compile(mod)
print(opt_mod(torch.randn(10, 100)))
print(opt_mod(t))

# In the same fashion, when compiling a module all sub-modules and methods
# within it are also compiled.

class OuterModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.inner_module = MyModule()
self.outer_lin = torch.nn.Linear(10, 2)

def forward(self, x):
x = self.inner_module(x)
return torch.nn.functional.relu(self.outer_lin(x))

outer_mod = OuterModule()
opt_outer_mod = torch.compile(outer_mod)
print(opt_outer_mod(t))

######################################################################
# We can also disable some functions from being compiled by using
# `torch.compiler.disable`

@torch.compiler.disable
def complex_function(real, imag):
# Assuming this function cause problems in the compilation
return torch.complex(real, imag)

def outer_function():
real = torch.tensor([2, 3], dtype=torch.float32)
imag = torch.tensor([4, 5], dtype=torch.float32)
z = complex_function(real, imag)
return torch.abs(z)

# Try to compile the outer_function
try:
opt_outer_function = torch.compile(outer_function)
print(opt_outer_function())
except Exception as e:
print("Compilation of outer_function failed:", e)

######################################################################
# Best Practices and Recommendations
# ----------------------------------
#
# Behavior of ``torch.compile`` with Nested Modules and Function Calls
#
# When you use ``torch.compile``, the compiler will try to recursively inline
# and compile every function call inside the target function or module.
#
# This includes:
#
# - **Nested function calls:** All functions called within the decorated or compiled function will also be compiled.
#
# - **Nested modules:** If a ``torch.nn.Module`` is compiled, all sub-modules and functions within the module are also compiled.
#
# **Best Practices:**
#
# 1. **Modular Testing:** Test individual functions and modules with ``torch.compile``
# before integrating them into larger models to isolate potential issues.
#
# 2. **Disable Compilation Selectively:** If certain functions or sub-modules
# cannot be handled by `torch.compile`, use the `torch.compiler.disable` context
# managers to recursively exclude them from compilation.
#
# 3. **Compile Leaf Functions First:** In complex models with multiple nested
# functions and modules, start by compiling the leaf functions or modules first.
# For more information see `TorchDynamo APIs for fine-grained tracing <https://pytorch.org/docs/stable/torch.compiler_fine_grain_apis.html>`__.

######################################################################
# Demonstrating Speedups
Expand Down

0 comments on commit 204f9fc

Please sign in to comment.