diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc index f9ef2575f1..0d9b259b91 100644 --- a/src/relax/transform/merge_composite_functions.cc +++ b/src/relax/transform/merge_composite_functions.cc @@ -44,10 +44,10 @@ * \ / \ / * O O * - * The algorithm proceeds by assigning a "label", consisting of a pointer to the representative - * group and the name of the target backend, to each subexpression in the function according to + * The algorithm proceeds by assigning a group to each subexpression in the function according to * its dataflow. On encountering a call node whose callee is a composite function, we check the - * two conditions above to see if we can merge this call node into one of its parent groups. + * two conditions above to see if we can merge this call node into one of its parent groups, and + * if we can merge some of its parent groups. * * To detect cyclic dependencies between groups, we propagate dependency relations, both direct * and indirect ones, as we flow through the function. The propagation of indirect dependencies @@ -71,43 +71,32 @@ using relay::GraphPartitioner; namespace { -/*! \brief A label for a group of composite functions, consisting of the representative group and - * the target backend name */ -struct CompositeGroup { - GraphPartitioner::Group* representative; - String target; -}; +using Group = GraphPartitioner::Group; -/*! \brief Assign a "CompositeGroup" label to each subexpression in a function according to its - * dataflow, and returns a mapping from a subexpression to its representative group. */ -class CompositeGroupsBuilder : public MemoizedExprTranslator { +/*! \brief Assign group to each subexpression in a function according to its + * dataflow, and returns a mapping from a subexpression to its group. 
*/ +class CompositeGroupsBuilder : public MemoizedExprTranslator { public: - using Group = GraphPartitioner::Group; using GroupMap = std::unordered_map; - using MemoizedExprTranslator::VisitExpr_; + using MemoizedExprTranslator::VisitExpr_; - CompositeGroupsBuilder(IRModule mod, support::Arena* arena) - : mod_(mod), arena_(arena), default_group_(CompositeGroup{nullptr, kDefaultTarget}) {} + CompositeGroupsBuilder(IRModule mod, support::Arena* arena) : mod_(mod), arena_(arena) {} GroupMap Run(Function func) { for (const auto& param : func->params) { - memo_[param] = CompositeGroup{nullptr, kDefaultTarget}; + memo_[param] = arena_->make(); } VisitExpr(func->body); GroupMap group_map; for (const auto& [expr, group] : memo_) { - if (group.representative) { - group_map[expr.get()] = group.representative; - } else { - group_map[expr.get()] = arena_->make(); - } + group_map[expr.get()] = group->FindRoot(); } return group_map; } - CompositeGroup VisitBinding(const Binding& binding) { + Group* VisitBinding(const Binding& binding) { if (const auto* node = binding.as()) { return VisitBinding_(node); } else { @@ -115,60 +104,57 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { } } - CompositeGroup VisitBindingBlock_(const BindingBlockNode* block) { + void VisitBindingBlock_(const BindingBlockNode* block) { for (Binding binding : block->bindings) { VisitBinding(binding); } - return default_group_; } - CompositeGroup VisitBindingBlock_(const DataflowBlockNode* block) { + void VisitBindingBlock_(const DataflowBlockNode* block) { for (Binding binding : block->bindings) { VisitBinding(binding); } - return CompositeGroup{nullptr, kDefaultTarget}; } - CompositeGroup VisitBindingBlock(const BindingBlock& block) { + void VisitBindingBlock(const BindingBlock& block) { if (const auto* node = block.as()) { - return VisitBindingBlock_(node); + VisitBindingBlock_(node); } else if (const auto* node = block.as()) { - return VisitBindingBlock_(node); + VisitBindingBlock_(node); } else { LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); } } - CompositeGroup VisitExpr_(const SeqExprNode* op) { + Group* VisitExpr_(const SeqExprNode* op) { for (BindingBlock block : op->blocks) { VisitBindingBlock(block); } return VisitExpr(op->body); } - CompositeGroup VisitExpr_(const CallNode* call) { - // Only a call to a composite function is relevant. - if (auto codegen_name = GetCodegenName(call->op)) { - // Designate one of the parent groups as the "representative" group. - auto rep_group = GetRepresentative(call->args, *codegen_name); - - if (rep_group->num_nodes != 0) { - // Merge other parent groups into the representative group. - for (const auto& arg : call->args) { - auto& arg_group = memo_[arg]; - if (arg_group.target == codegen_name && arg_group.representative != rep_group) { - rep_group->num_nodes += arg_group.representative->num_nodes; - arg_group.representative->num_nodes = 0; - arg_group.representative = rep_group; - } - } - } + Group* VisitExpr_(const CallNode* call) { + std::vector groups_to_merge = GetGroupsToMerge(call); + Group* group; - // Merge this call node into the representative group. 
- ++rep_group->num_nodes; - return CompositeGroup{rep_group, *codegen_name}; + if (groups_to_merge.size() == 0) { + // Create new group if there is nothing to merge with + group = CreateNewGroup(call); + } else { + auto it = groups_to_merge.cbegin(); + // Assign the first mergable group to current node + // to reduce the number of groups created + group = *it++; + group->num_nodes += 1; + + // Merge all groups + for (; it != groups_to_merge.cend(); ++it) { + MergeGroup(*it, group); + } } - return default_group_; + + UpdateGroupDependencies(group, call->args); + return group; } private: @@ -179,71 +165,116 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { return composite_name.substr(0, delim_pos); } - std::optional GetCodegenName(const Expr& callee) { + Optional GetCodegenName(const Expr& callee) { auto const* gvar = callee.as(); if (!gvar) { - return std::nullopt; + return NullOpt; } auto composite_name_opt = mod_->Lookup(GetRef(gvar))->GetAttr(attr::kComposite); if (!composite_name_opt) { - return std::nullopt; + return NullOpt; } return GetCodegenName(composite_name_opt.value()); } - Group* GetRepresentative(const Array& args, String codegen_name) { - Group* rep = nullptr; - std::unordered_set parent_deps; + Optional GetCodegenName(Group* group) { + return Downcast>(group->attrs.Get(attr::kCodegen)); + } + Group* CreateNewGroup(const CallNode* call) { + Group* group = arena_->make(); + if (Optional codegen_name = GetCodegenName(call->op)) { + group->attrs.Set(attr::kCodegen, codegen_name.value()); + } + return group; + } + + void MergeGroup(Group* from, Group* to) { + ICHECK_EQ(GetCodegenName(from), GetCodegenName(to)); + + Group* from_root = from->FindRoot(); + Group* to_root = to->FindRoot(); + if (from_root == to_root) { + return; + } + + from_root->parent = to_root; + to_root->num_nodes += from_root->num_nodes; + + // Update the group_deps_, maintaining the invariant that + // all groups in the map are root groups. + group_deps_[to_root].merge(group_deps_[from_root]); + group_deps_.erase(from_root); + for (auto& it : group_deps_) { + if (it.second.count(from_root)) { + it.second.erase(from_root); + it.second.insert(to_root); + } + } + } + + std::unordered_set GetParentGroupDependencies(const Array& args) { // Collect groups that parent groups depend on + std::unordered_set dependencies; + for (const auto& arg : args) { - for (auto parent_dep : group_deps_[memo_[arg].representative]) { - parent_deps.insert(parent_dep); + for (auto dep : group_deps_[memo_[arg]->FindRoot()]) { + dependencies.insert(dep); } } + return dependencies; + } + + void UpdateGroupDependencies(Group* group, const Array& args) { + Group* group_root = group->FindRoot(); + for (const auto& arg : args) { - auto arg_group = memo_[arg]; - if (arg_group.target == codegen_name && !parent_deps.count(arg_group.representative)) { - // If there is a parent group with the same target, which none of the parent dependency - // groups depends on, merging "this" call node into the parent group will not form a cyclic - // dependency. - rep = arg_group.representative; + auto arg_group_root = memo_[arg]->FindRoot(); + if (arg_group_root == group_root) { + // If arg and the current node are in the same group, + // there is nothing to update. 
+ continue; + } + // Add the group of arg as dependency + group_deps_[group_root].insert(arg_group_root); + // Propagate dependencies of arg + for (auto dep : group_deps_[arg_group_root]) { + group_deps_[group_root].insert(dep); } } + } - if (rep == nullptr) { - // If we do not find a valid representative parent group, make a new group. - // This can happen if all arguments are function parameters or belong to other targets. - rep = arena_->make(); - // Set num_nodes to 0 to signify that this representative groups has been newly created. - rep->num_nodes = 0; - rep->attrs.Set(attr::kCodegen, codegen_name); + std::vector GetGroupsToMerge(const CallNode* call) { + Optional codegen_name = GetCodegenName(call->op); + if (!codegen_name.defined()) { + return {}; } - // Record direct parent dependencies. - for (const auto& arg : args) { + std::vector groups_to_merge; + std::unordered_set parent_dependencies = GetParentGroupDependencies(call->args); + + for (const auto& arg : call->args) { auto arg_group = memo_[arg]; - if (arg_group.target != codegen_name) { - group_deps_[rep].insert(arg_group.representative); + Optional arg_codegen_name = GetCodegenName(arg_group); + if (arg_codegen_name == codegen_name && !parent_dependencies.count(arg_group->FindRoot())) { + // If there is a parent group with the same target, which none of the parent dependency + // groups depends on, merging "this" call node into the parent group will not form a cyclic + // dependency. + groups_to_merge.push_back(arg_group); } } - // Propagate parent dependencies. - for (auto parent_dep : parent_deps) { - group_deps_[rep].insert(parent_dep); - } - - return rep; + return groups_to_merge; } - const String kDefaultTarget = "default"; IRModule mod_; support::Arena* arena_; - CompositeGroup default_group_; + // Map from group to its dependencies. All groups in this map, whether it's + // the key or in value, should be root node (that is, group->parent == nullptr). std::unordered_map> group_deps_; }; @@ -257,6 +288,7 @@ class CompositeInliner : public ExprMutator { using ExprMutator::VisitExpr_; Function Run(Function func) { + inlined_functions_ = Map(); auto new_body = VisitExpr(func->body); auto new_func = Function(func->params, new_body, func->ret_struct_info, func->attrs, func->span); @@ -266,16 +298,22 @@ class CompositeInliner : public ExprMutator { Expr VisitExpr_(const CallNode* call) { if (call->op->IsInstance()) { auto gvar = Downcast(call->op); - auto func = CopyWithNewVars(Downcast(mod_->Lookup(gvar))); + auto func = Downcast(mod_->Lookup(gvar)); + if (func->GetAttr(attr::kComposite)) { - return Call(func, call->args); + if (!inlined_functions_.count(func)) { + inlined_functions_.Set(func, CopyWithNewVars(func)); + } + return Call(inlined_functions_[func], call->args); } } + return ExprMutator::VisitExpr_(call); } private: IRModule mod_; + Map inlined_functions_; }; } // namespace diff --git a/tests/python/relax/test_transform_merge_composite_functions.py b/tests/python/relax/test_transform_merge_composite_functions.py index b168cd224c..8577a4d93c 100644 --- a/tests/python/relax/test_transform_merge_composite_functions.py +++ b/tests/python/relax/test_transform_merge_composite_functions.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. 
import pytest -import tvm +import tvm from tvm import relax from tvm.script import relax as R @@ -143,6 +143,171 @@ def lv11( return gv3 +@tvm.script.ir_module +class Diamond: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + with R.dataflow(): + lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_conv2d(data, weight) + lv3: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_relu(lv2) + lv4: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_gelu(lv2) + gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_add(lv3, lv4) + R.output(gv2) + return gv2 + + @R.function + def fused_relax_nn_gelu( + lv: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.gelu"}) + with R.dataflow(): + gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_relu( + lv1: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"}) + with R.dataflow(): + gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv1) + R.output(gv1) + return gv1 + + @R.function + def fused_relax_add( + lv5: R.Tensor((1, 64, 54, 54), dtype="float32"), + gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"}) + with R.dataflow(): + gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(lv5, gelu1) + R.output(gv3) + return gv3 + + @R.function + def fused_relax_nn_conv2d( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.conv2d"}) + with R.dataflow(): + gv4: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d( + data1, + weight1, + padding=[0, 0, 0, 0], + ) + R.output(gv4) + return gv4 + + +@tvm.script.ir_module +class Diamond_merged: + @R.function + def fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + # function attr dict + R.func_attr( + { + "Codegen": "compiler_A", + "Primitive": 1, + "global_symbol": "fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add", + } + ) + # block 0 + with R.dataflow(): + + @R.function + def lv( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.conv2d", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv4: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d( + data1, + weight1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="", + ) + R.output(gv4) + return gv4 + + lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv(data, weight) + + @R.function + def lv1( + lv11: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1}) + # 
block 0 + with R.dataflow(): + gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv11) + R.output(gv1) + return gv1 + + lv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv1(lv2) + + @R.function + def lv21( + lv4: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv4) + R.output(gv) + return gv + + lv41: R.Tensor((1, 64, 54, 54), dtype="float32") = lv21(lv2) + + @R.function + def lv31( + lv5: R.Tensor((1, 64, 54, 54), dtype="float32"), + gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.add", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(lv5, gelu1) + R.output(gv3) + return gv3 + + gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv31(lv3, lv41) + R.output(gv2) + return gv2 + + @R.function + def main( + data2: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight2: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + # block 0 + with R.dataflow(): + gv5: R.Tensor( + (1, 64, 54, 54), dtype="float32" + ) = fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add(data2, weight2) + R.output(gv5) + return gv5 + + @tvm.script.ir_module class Diamond_cyclic_dep: @R.function @@ -327,7 +492,9 @@ def main( with R.dataflow(): lv1: R.Tensor((10,), dtype="float32") = fused_relax_nn_relu(x1) lv2: R.Tensor((10,), dtype="float32") = fused_relax_nn_gelu(x2) - gv1: R.Tensor((10,), dtype="float32") = fused_relax_add(lv1, lv2) + lv3: R.Tensor((10,), dtype="float32") = fused_relax_nn_relu(lv1) + lv4: R.Tensor((10,), dtype="float32") = fused_relax_nn_gelu(lv2) + gv1: R.Tensor((10,), dtype="float32") = fused_relax_add(lv3, lv4) R.output(gv1) return gv1 @@ -365,62 +532,197 @@ def fused_relax_add( @tvm.script.ir_module class MultipleProducers_merged: @R.function - def main( + def fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add( x1: R.Tensor((10,), dtype="float32"), x2: R.Tensor((10,), dtype="float32") ) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr( + { + "Codegen": "compiler_A", + "Primitive": 1, + "global_symbol": "fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add", + } + ) + # block 0 with R.dataflow(): - gv: R.Tensor((10,), dtype="float32") = fused_relax_nn_relu_relax_nn_gelu_relax_add( - x1, x2 - ) + + @R.function + def lv(x11: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x11) + R.output(gv2) + return gv2 + + lv1: R.Tensor((10,), dtype="float32") = lv(x1) + + @R.function + def lv11(x21: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21) + R.output(gv3) + return gv3 + + lv2: R.Tensor((10,), dtype="float32") = lv11(x2) + lv3: R.Tensor((10,), dtype="float32") = lv(lv1) + lv4: R.Tensor((10,), dtype="float32") = lv11(lv2) + + @R.function + def lv21( + lv5: R.Tensor((10,), dtype="float32"), gelu1: 
R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.add", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv: R.Tensor((10,), dtype="float32") = R.add(lv5, gelu1) + R.output(gv) + return gv + + gv1: R.Tensor((10,), dtype="float32") = lv21(lv3, lv4) + R.output(gv1) + return gv1 + + @R.function + def main( + x12: R.Tensor((10,), dtype="float32"), x22: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + # block 0 + with R.dataflow(): + gv4: R.Tensor( + (10,), dtype="float32" + ) = fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add(x12, x22) + R.output(gv4) + return gv4 + + +@tvm.script.ir_module +class MultipleProducersCyclic: + @R.function + def main(x1: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + with R.dataflow(): + lv1: R.Tensor((10,), dtype="float32") = fused_relax_nn_relu(x1) + lv2: R.Tensor((10,), dtype="float32") = R.nn.relu(lv1) + lv3: R.Tensor((10,), dtype="float32") = fused_relax_nn_gelu(lv2) + gv1: R.Tensor((10,), dtype="float32") = fused_relax_add(lv1, lv3) + R.output(gv1) + return gv1 + + @R.function + def fused_relax_nn_relu( + x11: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"}) + with R.dataflow(): + gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x11) + R.output(gv2) + return gv2 + + @R.function + def fused_relax_nn_gelu( + x21: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.gelu"}) + with R.dataflow(): + gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21) + R.output(gv3) + return gv3 + + @R.function + def fused_relax_add( + lv: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"}) + with R.dataflow(): + gv: R.Tensor((10,), dtype="float32") = R.add(lv, gelu1) R.output(gv) return gv + +@tvm.script.ir_module +class MultipleProducersCyclic_merged: @R.function - def fused_relax_nn_relu_relax_nn_gelu_relax_add( - x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,), dtype="float32") + def main(x1: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10,), dtype="float32") = fused_relax_nn_relu1(x1) + lv2: R.Tensor((10,), dtype="float32") = R.nn.relu(lv) + gv: R.Tensor((10,), dtype="float32") = fused_relax_nn_gelu_relax_add(lv2, lv) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_relu1( + x11: R.Tensor((10,), dtype="float32") ) -> R.Tensor((10,), dtype="float32"): + # function attr dict R.func_attr( - { - "Primitive": 1, - "Codegen": "compiler_A", - "global_symbol": "fused_relax_nn_relu_relax_nn_gelu_relax_add", - } + {"Codegen": "compiler_A", "Primitive": 1, "global_symbol": "fused_relax_nn_relu1"} ) + # block 0 with R.dataflow(): @R.function - def lv(x111: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + def lv1(x111: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + # function attr dict R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1}) + # block 0 with R.dataflow(): gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x111) R.output(gv2) return gv2 - lv1: R.Tensor((10,), dtype="float32") = lv(x11) + gv1: R.Tensor((10,), dtype="float32") = lv1(x11) + R.output(gv1) + return gv1 + + 
@R.function + def fused_relax_nn_gelu_relax_add( + lv21: R.Tensor((10,), dtype="float32"), lv11: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr( + { + "Codegen": "compiler_A", + "Primitive": 1, + "global_symbol": "fused_relax_nn_gelu_relax_add", + } + ) + # block 0 + with R.dataflow(): @R.function - def lv11(x211: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + def lv12(x21: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + # function attr dict R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1}) + # block 0 with R.dataflow(): - gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x211) + gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21) R.output(gv3) return gv3 - lv2: R.Tensor((10,), dtype="float32") = lv11(x21) + lv3: R.Tensor((10,), dtype="float32") = lv12(lv21) @R.function - def lv21( - lv3: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,), dtype="float32") + def lv22( + lv4: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,), dtype="float32") ) -> R.Tensor((10,), dtype="float32"): + # function attr dict R.func_attr({"Composite": "compiler_A.add", "Primitive": 1}) + # block 0 with R.dataflow(): - gv1: R.Tensor((10,), dtype="float32") = R.add(lv3, gelu1) - R.output(gv1) - return gv1 + gv4: R.Tensor((10,), dtype="float32") = R.add(lv4, gelu1) + R.output(gv4) + return gv4 - gv4: R.Tensor((10,), dtype="float32") = lv21(lv1, lv2) - R.output(gv4) - return gv4 + gv5: R.Tensor((10,), dtype="float32") = lv22(lv11, lv3) + R.output(gv5) + return gv5 @tvm.script.ir_module @@ -666,7 +968,7 @@ def lv1( def check(mod, expected): partitioned = relax.transform.MergeCompositeFunctions()(mod) - tvm.ir.structural_equal(partitioned, expected) + tvm.ir.assert_structural_equal(partitioned, expected) def test_conv2d_relu_x2(): @@ -689,9 +991,25 @@ def test_diamond_cyclic_dep(): check(Diamond_cyclic_dep, Diamond_cyclic_dep_merged) +def test_diamond(): + """ + O = Offloaded to A + + O O + / \\ / \\ + O O --> O O + \\ / \\ / + O O + + """ + check(Diamond, Diamond_merged) + + def test_merge_producers(): """ Test merging multiple producer groups into a single representative group. + O O + | | O O \\ / O @@ -699,6 +1017,21 @@ def test_merge_producers(): check(MultipleProducers, MultipleProducers_merged) +def test_merge_producers_cyclic_dep(): + """ + Test when multiple producer groups being blocked to merge due to circular dependency + in the result. + O + |\\ + | X + | | + | O + |/ + O + """ + check(MultipleProducersCyclic, MultipleProducersCyclic_merged) + + def test_merge_compiler_regions_example(): """ A tricky example from https://discuss.tvm.apache.org/t/relay-improved-graph-partitioning-algorithm/5830
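The grouping strategy implemented by CompositeGroupsBuilder above is easier to see without the C++ visitor boilerplate. The following is a minimal, self-contained Python sketch of the same idea, assuming a toy dataflow where each call is identified by an id and carries an optional codegen name: union-find groups, a per-root dependency map, and the rule that a call may only merge with a parent group of the same codegen name if no other parent group (transitively) depends on it. Names such as Group, GroupBuilder and find_root are illustrative stand-ins, not TVM APIs, and bookkeeping such as num_nodes is omitted.

class Group:
    """Union-find node standing in for relay::GraphPartitioner::Group."""

    def __init__(self, codegen=None):
        self.parent = None      # union-find parent; None means this is a root
        self.codegen = codegen  # backend name such as "compiler_A", or None

    def find_root(self):
        root = self
        while root.parent is not None:
            root = root.parent
        return root


class GroupBuilder:
    def __init__(self):
        self.memo = {}  # expr id -> Group assigned to that expression
        self.deps = {}  # root Group -> set of root Groups it depends on

    def visit_call(self, call_id, codegen, args):
        """args are ids of already-visited producer expressions."""
        to_merge = self._groups_to_merge(codegen, args)
        if not to_merge:
            # Nothing to merge with: start a new group (codegen may be None).
            group = Group(codegen)
        else:
            # Reuse the first mergeable parent group, then union the rest into it.
            group = to_merge[0]
            for other in to_merge[1:]:
                self._merge(other, group)
        self._update_deps(group, args)
        self.memo[call_id] = group
        return group

    def _groups_to_merge(self, codegen, args):
        if codegen is None:
            return []
        # Everything any parent group depends on: merging with one of these
        # groups would close a cycle through some other parent.
        parent_deps = set()
        for a in args:
            parent_deps |= self.deps.get(self.memo[a].find_root(), set())
        candidates = []
        for a in args:
            g = self.memo[a].find_root()
            if g.codegen == codegen and g not in parent_deps and g not in candidates:
                candidates.append(g)
        return candidates

    def _merge(self, src, dst):
        src_root, dst_root = src.find_root(), dst.find_root()
        if src_root is dst_root:
            return
        src_root.parent = dst_root
        # Keep the dependency map keyed by (and pointing at) root groups only.
        self.deps.setdefault(dst_root, set()).update(self.deps.pop(src_root, set()))
        for dep_set in self.deps.values():
            if src_root in dep_set:
                dep_set.discard(src_root)
                dep_set.add(dst_root)

    def _update_deps(self, group, args):
        root = group.find_root()
        for a in args:
            arg_root = self.memo[a].find_root()
            if arg_root is root:
                continue
            # Record the direct dependency and propagate indirect ones.
            self.deps.setdefault(root, set()).add(arg_root)
            self.deps[root] |= self.deps.get(arg_root, set())

Run against the Diamond pattern above, all four composite calls end up in one group; in the MultipleProducersCyclic pattern, the final add can merge with the gelu group but not with the relu group, because the gelu group transitively depends on the relu group through the plain R.nn.relu in between.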
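The CompositeInliner change follows the same spirit: instead of calling CopyWithNewVars once per call site, it caches the copied function per composite callee in inlined_functions_, so repeated calls to the same composite function share a single inlined copy, and the cache is reset at the start of each Run. A rough Python sketch of that caching pattern, where lookup, copy_with_new_vars and is_composite are hypothetical stand-ins for the IRModule lookup, relax's CopyWithNewVars and the kComposite attribute check:

def inline_composite_calls(call_sites, lookup, copy_with_new_vars, is_composite):
    """call_sites: iterable of (callee global var, args) pairs."""
    inlined = {}   # original composite function -> its single fresh copy
    rewritten = []
    for callee, args in call_sites:
        func = lookup(callee)
        if is_composite(func):
            if func not in inlined:
                # Copy once, with fresh variables, then reuse for later calls.
                inlined[func] = copy_with_new_vars(func)
            rewritten.append((inlined[func], args))
        else:
            rewritten.append((callee, args))  # leave non-composite calls alone
    return rewritten

This is why, in MultipleProducers_merged above, both relu call sites call the same local function lv and both gelu call sites call lv11, rather than producing four separate copies.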
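One small but consequential test fix near the end of the diff: the check helper used to call tvm.ir.structural_equal, which only returns a bool, so the result was silently discarded and a wrong partitioning could never fail the test. Switching to tvm.ir.assert_structural_equal makes a mismatch raise. A minimal illustration of the intended usage:

import tvm
from tvm import relax

def check(mod, expected):
    partitioned = relax.transform.MergeCompositeFunctions()(mod)
    # Raises an error if the two IRModules are not structurally equal;
    # structural_equal would merely return False here.
    tvm.ir.assert_structural_equal(partitioned, expected)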