Skip to content

Commit

Permalink
FIX: beginner/onnx/onnx_registry_tutorial.py fails against 2.4 RC bin…
Browse files Browse the repository at this point in the history
…aries (#2950)

This PR updates the ONNX graphs which are optimized by onnxscript/optimizer

---------

Co-authored-by: Svetlana Karslioglu <[email protected]>
  • Loading branch information
titaiwangms and svekars authored Jul 22, 2024
1 parent 0dee5c9 commit 123dd5c
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 21 deletions.
1 change: 0 additions & 1 deletion .jenkins/validate_tutorials_built.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
"intermediate_source/flask_rest_api_tutorial",
"intermediate_source/text_to_speech_with_torchaudio",
"intermediate_source/tensorboard_profiler_tutorial", # reenable after 2.0 release.
"beginner_source/onnx/onnx_registry_tutorial", # reenable after 2941 is fixed.
"intermediate_source/torch_export_tutorial" # reenable after 2940 is fixed.
]

Expand Down
Binary file modified _static/img/onnx/custom_aten_add_function.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed _static/img/onnx/custom_aten_gelu_function.png
Binary file not shown.
Binary file modified _static/img/onnx/custom_aten_gelu_model.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
31 changes: 11 additions & 20 deletions beginner_source/onnx/onnx_registry_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def forward(self, input_x, input_y):
# NOTE: All attributes must be annotated with type hints.
@onnxscript.script(custom_aten)
def custom_aten_add(input_x, input_y, alpha: float = 1.0):
alpha = opset18.CastLike(alpha, input_y)
input_y = opset18.Mul(input_y, alpha)
return opset18.Add(input_x, input_y)

Expand Down Expand Up @@ -130,9 +129,9 @@ def custom_aten_add(input_x, input_y, alpha: float = 1.0):
# graph node name is the function name
assert onnx_program.model_proto.graph.node[0].op_type == "custom_aten_add"
# function node domain is empty because we use standard ONNX operators
assert onnx_program.model_proto.functions[0].node[3].domain == ""
assert {node.domain for node in onnx_program.model_proto.functions[0].node} == {""}
# function node name is the standard ONNX operator name
assert onnx_program.model_proto.functions[0].node[3].op_type == "Add"
assert {node.op_type for node in onnx_program.model_proto.functions[0].node} == {"Add", "Mul", "Constant"}


######################################################################
Expand Down Expand Up @@ -231,33 +230,25 @@ def custom_aten_gelu(input_x, approximate: str = "none"):


######################################################################
# Let's inspect the model and verify the model uses :func:`custom_aten_gelu` instead of
# :class:`aten::gelu`. Note the graph has one graph nodes for
# ``custom_aten_gelu``, and inside ``custom_aten_gelu``, there is a function
# node for ``Gelu`` with namespace ``com.microsoft``.
# Let's inspect the model and verify the model uses op_type ``Gelu``
# from namespace ``com.microsoft``.
#
# .. note::
# :func:`custom_aten_gelu` does not exist in the graph because
# functions with fewer than three operators are inlined automatically.
#

# graph node domain is the custom domain we registered
assert onnx_program.model_proto.graph.node[0].domain == "com.microsoft"
# graph node name is the function name
assert onnx_program.model_proto.graph.node[0].op_type == "custom_aten_gelu"
# function node domain is the custom domain we registered
assert onnx_program.model_proto.functions[0].node[0].domain == "com.microsoft"
# function node name is the node name used in the function
assert onnx_program.model_proto.functions[0].node[0].op_type == "Gelu"
assert onnx_program.model_proto.graph.node[0].op_type == "Gelu"


######################################################################
# The following diagram shows ``custom_aten_gelu_model`` ONNX graph using Netron:
# The following diagram shows ``custom_aten_gelu_model`` ONNX graph using Netron,
# we can see the ``Gelu`` node from module ``com.microsoft`` used in the function:
#
# .. image:: /_static/img/onnx/custom_aten_gelu_model.png
# :width: 70%
# :align: center
#
# Inside the ``custom_aten_gelu`` function, we can see the ``Gelu`` node from module
# ``com.microsoft`` used in the function:
#
# .. image:: /_static/img/onnx/custom_aten_gelu_function.png
#
# That is all we need to do. As an additional step, we can use ONNX Runtime to run the model,
# and compare the results with PyTorch.
Expand Down

0 comments on commit 123dd5c

Please sign in to comment.