Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
[BYOC] Fix problems in MergeCompositeFunctions (#411)
Browse files Browse the repository at this point in the history
* Fix problems in merge_composite_functions

* Update doc

* Unify private methods and update code
  • Loading branch information
yelite authored Feb 8, 2023
1 parent 1123675 commit cb1523b
Show file tree
Hide file tree
Showing 2 changed files with 485 additions and 114 deletions.
212 changes: 125 additions & 87 deletions src/relax/transform/merge_composite_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@
* \ / \ /
* O O
*
* The algorithm proceeds by assigning a "label", consisting of a pointer to the representative
* group and the name of the target backend, to each subexpression in the function according to
* The algorithm proceeds by assigning a group to each subexpression in the function according to
* its dataflow. On encountering a call node whose callee is a composite function, we check the
* two conditions above to see if we can merge this call node into one of its parent groups.
* two conditions above to see if we can merge this call node into one of its parent groups, and
* if we can merge some of its parent groups.
*
* To detect cyclic dependencies between groups, we propagate dependency relations, both direct
* and indirect ones, as we flow through the function. The propagation of indirect dependencies
Expand All @@ -71,104 +71,90 @@ using relay::GraphPartitioner;

namespace {

/*! \brief A label for a group of composite functions, consisting of the representative group and
* the target backend name */
struct CompositeGroup {
GraphPartitioner::Group* representative;
String target;
};
using Group = GraphPartitioner::Group;

/*! \brief Assign a "CompositeGroup" label to each subexpression in a function according to its
* dataflow, and returns a mapping from a subexpression to its representative group. */
class CompositeGroupsBuilder : public MemoizedExprTranslator<CompositeGroup> {
/*! \brief Assign group to each subexpression in a function according to its
* dataflow, and returns a mapping from a subexpression to its group. */
class CompositeGroupsBuilder : public MemoizedExprTranslator<Group*> {
public:
using Group = GraphPartitioner::Group;
using GroupMap = std::unordered_map<const Object*, Group*>;
using MemoizedExprTranslator<CompositeGroup>::VisitExpr_;
using MemoizedExprTranslator<Group*>::VisitExpr_;

CompositeGroupsBuilder(IRModule mod, support::Arena* arena)
: mod_(mod), arena_(arena), default_group_(CompositeGroup{nullptr, kDefaultTarget}) {}
CompositeGroupsBuilder(IRModule mod, support::Arena* arena) : mod_(mod), arena_(arena) {}

GroupMap Run(Function func) {
for (const auto& param : func->params) {
memo_[param] = CompositeGroup{nullptr, kDefaultTarget};
memo_[param] = arena_->make<Group>();
}
VisitExpr(func->body);

GroupMap group_map;
for (const auto& [expr, group] : memo_) {
if (group.representative) {
group_map[expr.get()] = group.representative;
} else {
group_map[expr.get()] = arena_->make<Group>();
}
group_map[expr.get()] = group->FindRoot();
}

return group_map;
}

CompositeGroup VisitBinding(const Binding& binding) {
Group* VisitBinding(const Binding& binding) {
if (const auto* node = binding.as<VarBindingNode>()) {
return VisitBinding_(node);
} else {
LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey();
}
}

CompositeGroup VisitBindingBlock_(const BindingBlockNode* block) {
void VisitBindingBlock_(const BindingBlockNode* block) {
for (Binding binding : block->bindings) {
VisitBinding(binding);
}
return default_group_;
}

CompositeGroup VisitBindingBlock_(const DataflowBlockNode* block) {
void VisitBindingBlock_(const DataflowBlockNode* block) {
for (Binding binding : block->bindings) {
VisitBinding(binding);
}
return CompositeGroup{nullptr, kDefaultTarget};
}

CompositeGroup VisitBindingBlock(const BindingBlock& block) {
void VisitBindingBlock(const BindingBlock& block) {
if (const auto* node = block.as<DataflowBlockNode>()) {
return VisitBindingBlock_(node);
VisitBindingBlock_(node);
} else if (const auto* node = block.as<BindingBlockNode>()) {
return VisitBindingBlock_(node);
VisitBindingBlock_(node);
} else {
LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey();
}
}

CompositeGroup VisitExpr_(const SeqExprNode* op) {
Group* VisitExpr_(const SeqExprNode* op) {
for (BindingBlock block : op->blocks) {
VisitBindingBlock(block);
}
return VisitExpr(op->body);
}

CompositeGroup VisitExpr_(const CallNode* call) {
// Only a call to a composite function is relevant.
if (auto codegen_name = GetCodegenName(call->op)) {
// Designate one of the parent groups as the "representative" group.
auto rep_group = GetRepresentative(call->args, *codegen_name);

if (rep_group->num_nodes != 0) {
// Merge other parent groups into the representative group.
for (const auto& arg : call->args) {
auto& arg_group = memo_[arg];
if (arg_group.target == codegen_name && arg_group.representative != rep_group) {
rep_group->num_nodes += arg_group.representative->num_nodes;
arg_group.representative->num_nodes = 0;
arg_group.representative = rep_group;
}
}
}
Group* VisitExpr_(const CallNode* call) {
std::vector<Group*> groups_to_merge = GetGroupsToMerge(call);
Group* group;

// Merge this call node into the representative group.
++rep_group->num_nodes;
return CompositeGroup{rep_group, *codegen_name};
if (groups_to_merge.size() == 0) {
// Create new group if there is nothing to merge with
group = CreateNewGroup(call);
} else {
auto it = groups_to_merge.cbegin();
// Assign the first mergable group to current node
// to reduce the number of groups created
group = *it++;
group->num_nodes += 1;

// Merge all groups
for (; it != groups_to_merge.cend(); ++it) {
MergeGroup(*it, group);
}
}
return default_group_;

UpdateGroupDependencies(group, call->args);
return group;
}

private:
Expand All @@ -179,71 +165,116 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator<CompositeGroup> {
return composite_name.substr(0, delim_pos);
}

std::optional<String> GetCodegenName(const Expr& callee) {
Optional<String> GetCodegenName(const Expr& callee) {
auto const* gvar = callee.as<GlobalVarNode>();
if (!gvar) {
return std::nullopt;
return NullOpt;
}

auto composite_name_opt =
mod_->Lookup(GetRef<GlobalVar>(gvar))->GetAttr<String>(attr::kComposite);
if (!composite_name_opt) {
return std::nullopt;
return NullOpt;
}

return GetCodegenName(composite_name_opt.value());
}

Group* GetRepresentative(const Array<Expr>& args, String codegen_name) {
Group* rep = nullptr;
std::unordered_set<Group*> parent_deps;
Optional<String> GetCodegenName(Group* group) {
return Downcast<Optional<String>>(group->attrs.Get(attr::kCodegen));
}

Group* CreateNewGroup(const CallNode* call) {
Group* group = arena_->make<Group>();
if (Optional<String> codegen_name = GetCodegenName(call->op)) {
group->attrs.Set(attr::kCodegen, codegen_name.value());
}
return group;
}

void MergeGroup(Group* from, Group* to) {
ICHECK_EQ(GetCodegenName(from), GetCodegenName(to));

Group* from_root = from->FindRoot();
Group* to_root = to->FindRoot();
if (from_root == to_root) {
return;
}

from_root->parent = to_root;
to_root->num_nodes += from_root->num_nodes;

// Update the group_deps_, maintaining the invariant that
// all groups in the map are root groups.
group_deps_[to_root].merge(group_deps_[from_root]);
group_deps_.erase(from_root);
for (auto& it : group_deps_) {
if (it.second.count(from_root)) {
it.second.erase(from_root);
it.second.insert(to_root);
}
}
}

std::unordered_set<Group*> GetParentGroupDependencies(const Array<Expr>& args) {
// Collect groups that parent groups depend on
std::unordered_set<Group*> dependencies;

for (const auto& arg : args) {
for (auto parent_dep : group_deps_[memo_[arg].representative]) {
parent_deps.insert(parent_dep);
for (auto dep : group_deps_[memo_[arg]->FindRoot()]) {
dependencies.insert(dep);
}
}

return dependencies;
}

void UpdateGroupDependencies(Group* group, const Array<Expr>& args) {
Group* group_root = group->FindRoot();

for (const auto& arg : args) {
auto arg_group = memo_[arg];
if (arg_group.target == codegen_name && !parent_deps.count(arg_group.representative)) {
// If there is a parent group with the same target, which none of the parent dependency
// groups depends on, merging "this" call node into the parent group will not form a cyclic
// dependency.
rep = arg_group.representative;
auto arg_group_root = memo_[arg]->FindRoot();
if (arg_group_root == group_root) {
// If arg and the current node are in the same group,
// there is nothing to update.
continue;
}
// Add the group of arg as dependency
group_deps_[group_root].insert(arg_group_root);
// Propagate dependencies of arg
for (auto dep : group_deps_[arg_group_root]) {
group_deps_[group_root].insert(dep);
}
}
}

if (rep == nullptr) {
// If we do not find a valid representative parent group, make a new group.
// This can happen if all arguments are function parameters or belong to other targets.
rep = arena_->make<Group>();
// Set num_nodes to 0 to signify that this representative groups has been newly created.
rep->num_nodes = 0;
rep->attrs.Set(attr::kCodegen, codegen_name);
std::vector<Group*> GetGroupsToMerge(const CallNode* call) {
Optional<String> codegen_name = GetCodegenName(call->op);
if (!codegen_name.defined()) {
return {};
}

// Record direct parent dependencies.
for (const auto& arg : args) {
std::vector<Group*> groups_to_merge;
std::unordered_set<Group*> parent_dependencies = GetParentGroupDependencies(call->args);

for (const auto& arg : call->args) {
auto arg_group = memo_[arg];
if (arg_group.target != codegen_name) {
group_deps_[rep].insert(arg_group.representative);
Optional<String> arg_codegen_name = GetCodegenName(arg_group);
if (arg_codegen_name == codegen_name && !parent_dependencies.count(arg_group->FindRoot())) {
// If there is a parent group with the same target, which none of the parent dependency
// groups depends on, merging "this" call node into the parent group will not form a cyclic
// dependency.
groups_to_merge.push_back(arg_group);
}
}

// Propagate parent dependencies.
for (auto parent_dep : parent_deps) {
group_deps_[rep].insert(parent_dep);
}

return rep;
return groups_to_merge;
}

const String kDefaultTarget = "default";
IRModule mod_;
support::Arena* arena_;
CompositeGroup default_group_;
// Map from group to its dependencies. All groups in this map, whether it's
// the key or in value, should be root node (that is, group->parent == nullptr).
std::unordered_map<Group*, std::unordered_set<Group*>> group_deps_;
};

Expand All @@ -257,6 +288,7 @@ class CompositeInliner : public ExprMutator {
using ExprMutator::VisitExpr_;

Function Run(Function func) {
inlined_functions_ = Map<Function, Function>();
auto new_body = VisitExpr(func->body);
auto new_func =
Function(func->params, new_body, func->ret_struct_info, func->attrs, func->span);
Expand All @@ -266,16 +298,22 @@ class CompositeInliner : public ExprMutator {
Expr VisitExpr_(const CallNode* call) {
if (call->op->IsInstance<GlobalVarNode>()) {
auto gvar = Downcast<GlobalVar>(call->op);
auto func = CopyWithNewVars(Downcast<Function>(mod_->Lookup(gvar)));
auto func = Downcast<Function>(mod_->Lookup(gvar));

if (func->GetAttr<String>(attr::kComposite)) {
return Call(func, call->args);
if (!inlined_functions_.count(func)) {
inlined_functions_.Set(func, CopyWithNewVars(func));
}
return Call(inlined_functions_[func], call->args);
}
}

return ExprMutator::VisitExpr_(call);
}

private:
IRModule mod_;
Map<Function, Function> inlined_functions_;
};

} // namespace
Expand Down
Loading

0 comments on commit cb1523b

Please sign in to comment.