Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
Address comment and update doc
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterJH5574 committed Feb 8, 2023
1 parent 5c041c5 commit 1519bb9
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
8 changes: 5 additions & 3 deletions include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,11 @@ TVM_DLL Pass ToNonDataflow();
TVM_DLL Pass CallTIRRewrite();

/*!
* \brief Convert all reshape-like call_tir to VM reshape operator call.
* The VM reshape operator calls will be further lowered to a CreateView
* operation at runtime, instead of doing real data copy.
* \brief Convert all reshape-like call_tir whose corresponding binding
* vars are DataflowVars to relax.reshape operator calls. The relax.reshape
* calls will be lowered an external builtin function call in a subsequent
* pass, where the external builtin function does a CreateView operation
* at runtime, instead of doing real data copy.
* Here "reshape-like" includes reshape, expand_dims, flatten, etc.
*
* \return The Pass.
Expand Down
9 changes: 6 additions & 3 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,12 @@ def CallTIRRewrite() -> tvm.ir.transform.Pass:


def RewriteDataflowReshape() -> tvm.ir.transform.Pass:
"""Convert all reshape-like call_tir to VM reshape operator call.
The VM reshape operator calls will be further lowered to a CreateView
operation at runtime, instead of doing real data copy.
"""Convert all reshape-like call_tir whose corresponding binding
vars are DataflowVars to relax.reshape operator calls. The relax.reshape
calls will be lowered an external builtin function call in a subsequent
pass, where the external builtin function does a CreateView operation
at runtime, instead of doing real data copy.
Here "reshape-like" includes reshape, expand_dims, flatten, etc.
Returns
Expand Down
9 changes: 6 additions & 3 deletions src/relax/transform/rewrite_dataflow_reshape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/
/*!
* \file src/relax/transform/rewrite_dataflow_reshape.cc
* \brief Transform all reshape within dataflow block to a specialized reshape operator
* \brief Transform all reshape within dataflow block to a relax.reshape operator
*/
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
Expand Down Expand Up @@ -75,8 +75,11 @@ class DataflowReshapeRewriter : public ExprMutator {
if (call->op != call_tir_op) {
return false;
}
GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
const auto* func = mod_->functions.Get(gv).as<tir::PrimFuncNode>();
const auto* gv = call->args[0].as<GlobalVarNode>();
if (gv == nullptr) {
return false;
}
const auto* func = mod_->functions.Get(GetRef<GlobalVar>(gv)).as<tir::PrimFuncNode>();
ICHECK_NOTNULL(func);
return HasReshapePattern(GetRef<tir::PrimFunc>(func));
}
Expand Down

0 comments on commit 1519bb9

Please sign in to comment.