Update aoti tutorial #3224

Open. Wants to merge 1 commit into base: 2.6-RC-TEST.
194 changes: 123 additions & 71 deletions recipes_source/torch_export_aoti_python.py
@@ -3,7 +3,7 @@
"""
.. meta::
:description: An end-to-end example of how to use AOTInductor for Python runtime.
:keywords: torch.export, AOTInductor, torch._inductor.aoti_compile_and_package, aot_compile, torch._export.aoti_load_package

``torch.export`` AOTInductor Tutorial for Python runtime (Beta)
===============================================================
@@ -14,19 +14,18 @@
#
# .. warning::
#
# ``torch._inductor.aoti_compile_and_package`` and
# ``torch._inductor.aoti_load_package`` are in Beta status and are subject
# to backwards compatibility breaking changes. This tutorial provides an
# example of how to use these APIs for model deployment using Python
# runtime.
#
# It has been shown `previously
# <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`__ how
# AOTInductor can be used to do Ahead-of-Time compilation of PyTorch exported
# models by creating an artifact that can be run in a non-Python environment.
# In this tutorial, you will work through an end-to-end example of how to use
# AOTInductor for Python runtime.
#
# **Contents**
#
@@ -36,115 +35,169 @@
######################################################################
# Prerequisites
# -------------
# * PyTorch 2.6 or later
# * Basic understanding of ``torch.export`` and AOTInductor
# * Complete the `AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`_ tutorial

######################################################################
# What you will learn
# ----------------------
# * How to use AOTInductor for Python runtime.
# * How to use :func:`torch._inductor.aoti_compile_and_package` along with :func:`torch.export.export` to generate a compiled artifact.
# * How to load and run the artifact in a Python runtime using :func:`torch._inductor.aoti_load_package`.
# * When to use AOTInductor with a Python runtime.

######################################################################
# Model Compilation
# -----------------
#
# We will use the TorchVision pretrained ``ResNet18`` model as an example.
#
# The first step is to export the model to a graph representation using
# :func:`torch.export.export`. To learn more about using this function, you can
# check out the `docs <https://pytorch.org/docs/main/export.html>`_ or the
# `tutorial <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html>`_.
#
# Once we have exported the PyTorch model and obtained an ``ExportedProgram``,
# we can apply :func:`torch._inductor.aoti_compile_and_package` to AOTInductor
# to compile the program to a specified device, and save the generated contents
# into a ".pt2" artifact.
#
# We also specify ``dynamic_shapes`` for the batch dimension. In this example, ``min=2`` is not a bug and is
# explained in `The 0/1 Specialization Problem <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk>`__

# .. note::
#
# This API supports the same available options that :func:`torch.compile`
# has, such as ``mode`` and ``max_autotune`` (for those who want to enable
# CUDA graphs and leverage Triton based matrix multiplications and
# convolutions)

import os
import torch
import torch._inductor
from torchvision.models import ResNet18_Weights, resnet18

model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()

with torch.inference_mode():
    inductor_configs = {}

    if torch.cuda.is_available():
        device = "cuda"
        inductor_configs["max_autotune"] = True
    else:
        device = "cpu"

    model = model.to(device=device)
    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

    # min=2 is not a bug and is explained in the 0/1 Specialization Problem
    batch_dim = torch.export.Dim("batch", min=2, max=32)
    exported_program = torch.export.export(
        model,
        example_inputs,
        # Specify the first dimension of the input x as dynamic
        dynamic_shapes={"x": {0: batch_dim}},
    )
    path = torch._inductor.aoti_compile_and_package(
        exported_program,
        package_path=os.path.join(os.getcwd(), "resnet18.pt2"),
        inductor_configs=inductor_configs,
    )
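
######################################################################
# ``aoti_compile_and_package`` returns the location of the saved package, so a
# quick sanity check is to confirm that it matches the ``package_path`` we
# passed above:

print(path)  # e.g. <current working directory>/resnet18.pt2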

######################################################################
# The result of :func:`aoti_compile_and_package` is an artifact "resnet18.pt2"
# which can be loaded and executed in Python and C++.
#
# The artifact itself contains a bunch of AOTInductor generated code, such as
# a generated C++ runner file, a shared library compiled from the C++ file, and
# CUDA binary files, aka cubin files, if optimizing for CUDA.
#
# Structure-wise, the artifact is a structured ``.zip`` file, with the following
# specification:
#
# .. code::
#
#    .
#    ├── archive_format
#    ├── version
#    ├── data
#    │   ├── aotinductor
#    │   │   └── model
#    │   │       ├── xxx.cpp            # AOTInductor generated cpp file
#    │   │       ├── xxx.so             # AOTInductor generated shared library
#    │   │       ├── xxx.cubin          # Cubin files (if running on CUDA)
#    │   │       └── xxx_metadata.json  # Additional metadata to save
#    │   ├── weights
#    │   │   └── TBD
#    │   └── constants
#    │       └── TBD
#    └── extra
#        └── metadata.json
#
# We can use the following command to inspect the artifact contents:
#
# .. code:: bash
#
# $ unzip -l resnet18.pt2
#
# .. code::
#
# Archive: resnet18.pt2
# Length Date Time Name
# --------- ---------- ----- ----
# 1 01-08-2025 16:40 version
# 3 01-08-2025 16:40 archive_format
# 10088 01-08-2025 16:40 data/aotinductor/model/cagzt6akdaczvxwtbvqe34otfe5jlorktbqlojbzqjqvbfsjlge4.cubin
# 17160 01-08-2025 16:40 data/aotinductor/model/c6oytfjmt5w4c7onvtm6fray7clirxt7q5xjbwx3hdydclmwoujz.cubin
# 16616 01-08-2025 16:40 data/aotinductor/model/c7ydp7nocyz323hij4tmlf2kcedmwlyg6r57gaqzcsy3huneamu6.cubin
# 17776 01-08-2025 16:40 data/aotinductor/model/cyqdf46ordevqhiddvpdpp3uzwatfbzdpl3auj2nx23uxvplnne2.cubin
# 10856 01-08-2025 16:40 data/aotinductor/model/cpzfebfgrusqslui7fxsuoo4tvwulmrxirc5tmrpa4mvrbdno7kn.cubin
# 14608 01-08-2025 16:40 data/aotinductor/model/c5ukeoz5wmaszd7vczdz2qhtt6n7tdbl3b6wuy4rb2se24fjwfoy.cubin
# 11376 01-08-2025 16:40 data/aotinductor/model/csu3nstcp56tsjfycygaqsewpu64l5s6zavvz7537cm4s4cv2k3r.cubin
# 10984 01-08-2025 16:40 data/aotinductor/model/cp76lez4glmgq7gedf2u25zvvv6rksv5lav4q22dibd2zicbgwj3.cubin
# 14736 01-08-2025 16:40 data/aotinductor/model/c2bb5p6tnwz4elgujqelsrp3unvkgsyiv7xqxmpvuxcm4jfl7pc2.cubin
# 11376 01-08-2025 16:40 data/aotinductor/model/c6eopmb2b4ngodwsayae4r5q6ni3jlfogfbdk3ypg56tgpzhubfy.cubin
# 11624 01-08-2025 16:40 data/aotinductor/model/chmwe6lvoekzfowdbiizitm3haiiuad5kdm6sd2m6mv6dkn2zk32.cubin
# 15632 01-08-2025 16:40 data/aotinductor/model/c3jop5g344hj3ztsu4qm6ibxyaaerlhkzh2e6emak23rxfje6jam.cubin
# 25472 01-08-2025 16:40 data/aotinductor/model/chaiixybeiuuitm2nmqnxzijzwgnn2n7uuss4qmsupgblfh3h5hk.cubin
# 139389 01-08-2025 16:40 data/aotinductor/model/cvk6qzuybruhwxtfblzxiov3rlrziv5fkqc4mdhbmantfu3lmd6t.cpp
# 27 01-08-2025 16:40 data/aotinductor/model/cvk6qzuybruhwxtfblzxiov3rlrziv5fkqc4mdhbmantfu3lmd6t_metadata.json
# 47195424 01-08-2025 16:40 data/aotinductor/model/cvk6qzuybruhwxtfblzxiov3rlrziv5fkqc4mdhbmantfu3lmd6t.so
# --------- -------
# 47523148 18 files
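
######################################################################
# Since the ".pt2" artifact is a standard zip archive, the same inspection can
# also be done from Python with the standard-library ``zipfile`` module; a
# minimal sketch:

import zipfile

with zipfile.ZipFile(os.path.join(os.getcwd(), "resnet18.pt2")) as archive:
    for entry in archive.infolist():
        print(f"{entry.file_size:>12} {entry.filename}")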


######################################################################
# Model Inference in Python
# -------------------------
#
# To load and run the artifact in Python, we can use :func:`torch._inductor.aoti_load_package`.
#


import os
import torch
import torch._inductor

device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = os.path.join(os.getcwd(), "resnet18.pt2")

compiled_model = torch._inductor.aoti_load_package(model_path)
example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

with torch.inference_mode():
    output = compiled_model(example_inputs)
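
######################################################################
# Because the model was exported with a dynamic batch dimension (``min=2``,
# ``max=32``), the same artifact can serve other batch sizes in that range.
# A minimal sketch, reusing the calling convention shown above:

with torch.inference_mode():
    for batch_size in (4, 16):
        batched_inputs = (torch.randn(batch_size, 3, 224, 224, device=device),)
        batched_output = compiled_model(batched_inputs)
        print(batched_output.shape)  # one row of logits per input image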


######################################################################
# When to use AOTInductor with a Python Runtime
# ---------------------------------------------
#
# One of the requirements for using AOTInductor is that the model shouldn't have any graph breaks.
# Once this requirement is met, the primary use case for using AOTInductor Python Runtime is for
# model deployment using Python.
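#
# A quick way to surface graph breaks ahead of deployment is to run the model
# through ``torch.export.export`` (as we did above) or through ``torch.compile``
# with ``fullgraph=True``; both raise an error instead of silently falling back
# when a graph break is hit. A minimal sketch of such a pre-flight check:
#
# .. code:: python
#
#    # Fails loudly if the model cannot be captured as a single graph.
#    check = torch.compile(model, fullgraph=True)
#    check(torch.randn(2, 3, 224, 224, device=device))
#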
# There are mainly two reasons why one would use AOTInductor with a Python Runtime:
Contributor comment: Duplicate of the last section?
#
# - ``torch._inductor.aoti_compile_and_package`` generates a singular
# serialized artifact. This is useful for model versioning for deployments
# and tracking model performance over time.
# - With :func:`torch.compile` being a JIT compiler, there is a warmup
# cost associated with the first compilation. Your deployment needs to
# account for the compilation time taken for the first inference. With
# AOTInductor, the compilation is done ahead of time using
# ``torch.export.export`` and ``torch._inductor.aoti_compile_and_package``.
# At deployment time, after loading the model, running inference does not
# have any additional cost.
#
#
# The section below shows the speedup achieved with AOTInductor for first inference
@@ -185,7 +238,7 @@ def timed(fn):

torch._dynamo.reset()

model = torch._inductor.aoti_load_package(model_path)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
@@ -217,8 +270,7 @@ def timed(fn):
# ----------
#
# In this recipe, we have learned how to effectively use AOTInductor for Python runtime by
# compiling and loading a pretrained ``ResNet18`` model. This process
# demonstrates the practical application of generating a compiled artifact and
# running it within a Python environment. We also looked at the advantage of using
# AOTInductor in model deployments, with regard to the speedup in first inference time.