[BYOC] Update TensorRT backend for the new BYOC flow and offloading with constants (#400)

* Update TensorRT backend for the new BYOC flow and offloading with constants

* update cutlass codegen for the new signature

* add comment

* fix pattern partitioning for residual block

* update cutlass rev and enable disabled test
masahi authored Feb 7, 2023
1 parent c09886d commit 83adb87
Showing 11 changed files with 224 additions and 182 deletions.
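The common thread across these files is that every external codegen entry point now receives a Map<Constant, String> naming the constants bound in the module, so offloaded functions can refer to constants by name instead of inventing per-symbol names. A minimal sketch of the shared entry-point shape (the backend name, includes, and body comments are illustrative assumptions, not code from this commit):

    #include <tvm/ir/module.h>
    #include <tvm/relax/expr.h>
    #include <tvm/runtime/module.h>

    namespace tvm {
    namespace relax {

    // Hypothetical backend, shown only to illustrate the contract that
    // DNNLCompiler, TensorRTCompiler, and CUTLASSCompiler follow below.
    Array<runtime::Module> MyBackendCompiler(Array<Function> functions,
                                             Map<String, ObjectRef> /*options*/,
                                             Map<Constant, String> constant_names) {
      Array<runtime::Module> compiled_functions;
      for (const auto& func : functions) {
        // 1. Serialize `func`, resolving each relax::Constant through
        //    constant_names so the artifact references constants by name.
        // 2. Wrap the result in a runtime::Module that asks for those names
        //    (and their NDArrays) when it is loaded.
      }
      return compiled_functions;
    }

    }  // namespace relax
    }  // namespace tvm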
12 changes: 9 additions & 3 deletions python/tvm/relax/vm.py
@@ -592,9 +592,15 @@ def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")):
     new_mod = seq(mod)
 
     # Extract external runtime modules if exist.
-    ext_libs = []
-    if mod.attrs and "external_mods" in mod.attrs:
-        ext_libs = mod.attrs["external_mods"]
+    attrs = dict(mod.attrs) if mod.attrs else {}
+
+    ext_libs = attrs.get("external_mods", [])
+    constants = attrs.get("const_name_to_constant", {})
+
+    if params is not None:
+        params.update(dict(constants))
+    else:
+        params = constants
 
     # builder collects the executable
     builder = relax.ExecBuilder()
27 changes: 14 additions & 13 deletions src/relax/backend/contrib/codegen_json/codegen_json.h
@@ -153,11 +153,10 @@ class JSONSerializer : public relax::MemoizedExprTranslator<NodeEntries> {
 
   /*!
    * \brief Constructor
-   *
-   * \param symbol The symbol that represents the graph being converted.
-   * \param expr The Relax expression to be converted to the JSON form.
+   * \param constant_names The names of all constants in the original module.
    */
-  explicit JSONSerializer(const std::string& symbol) : symbol_(symbol) {}
+  explicit JSONSerializer(const Map<Constant, String>& constant_names)
+      : constant_names_(constant_names) {}
 
   void serialize(Function func) {
     // First we convert all the parameters into input nodes.
@@ -168,8 +167,8 @@ class JSONSerializer : public relax::MemoizedExprTranslator<NodeEntries> {
     heads_ = VisitExpr(func->body);
   }
 
-  /*!\brief Return the required params. */
-  Array<String> GetParams() const { return params_; }
+  /*!\brief Return the required constants. */
+  Array<String> GetConstantNames() const { return constants_used_; }
 
   /*!\brief Return the generated json. */
   std::string GetJSON() {
@@ -320,9 +319,11 @@ class JSONSerializer : public relax::MemoizedExprTranslator<NodeEntries> {
   }
 
   NodeEntries VisitExpr_(const ConstantNode* cn) {
-    std::string name = symbol_ + "_const_" + std::to_string(params_.size());
-    params_.push_back(name);
-    auto node = std::make_shared<JSONGraphNode>(name, "const" /* op_type_ */);
+    auto name = constant_names_.find(GetRef<Constant>(cn));
+    ICHECK(name != constant_names_.end())
+        << "Cannot find the name of the constant: " << GetRef<Constant>(cn);
+    constants_used_.push_back((*name).second);
+    auto node = std::make_shared<JSONGraphNode>((*name).second, "const" /* op_type_ */);
     return AddNode(node, GetRef<Expr>(cn));
   }
 
@@ -405,14 +406,14 @@ class JSONSerializer : public relax::MemoizedExprTranslator<NodeEntries> {
   }
 
  private:
-  /*! \brief The symbol that represents the json graph. */
-  std::string symbol_;
   /*! \brief JSON graph nodes. */
   std::vector<JSONGraphObjectPtr> nodes_;
   /*! \brief Output of the JSON graph. */
   NodeEntries heads_;
-  /*! \brief The list of required constants. */
-  Array<String> params_;
+  /*! \brief The list of required constants, ordered. */
+  Array<String> constants_used_;
+  /*! \brief The names of all constants in the original module. */
+  const Map<Constant, String>& constant_names_;
 };
 
 } // namespace contrib
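Taken together, the codegen_json.h changes drop the old on-the-fly naming scheme ("<symbol>_const_<i>") in favor of a lookup in the module-wide constant map, and GetParams() becomes GetConstantNames(). A usage sketch, assuming this file's includes and namespaces and a hypothetical wrapper function:

    // Hypothetical helper showing how a backend drives the refactored
    // serializer; compare DNNLCompiler and TensorRTCompiler below.
    std::pair<std::string, Array<String>> SerializeToJSON(
        const Function& func, const Map<Constant, String>& constant_names) {
      JSONSerializer serializer(constant_names);  // was: JSONSerializer(symbol)
      serializer.serialize(func);
      // Names are recorded in visit order, so the JSON runtime can later look
      // up the matching NDArrays by name.
      return {serializer.GetJSON(), serializer.GetConstantNames()};
    }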
3 changes: 2 additions & 1 deletion src/relax/backend/contrib/cutlass/codegen.cc
@@ -249,7 +249,8 @@ class CutlassModuleCodegen {
   Array<String> func_names_;
 };
 
-Array<runtime::Module> CUTLASSCompiler(Array<Function> functions, Map<String, ObjectRef> options) {
+Array<runtime::Module> CUTLASSCompiler(Array<Function> functions, Map<String, ObjectRef> options,
+                                       Map<Constant, String> /*unused*/) {
   const auto* tune_func = runtime::Registry::Get("contrib.cutlass.tune_relax_function");
   ICHECK(tune_func != nullptr)
       << "The packed function contrib.cutlass.tune_relax_function not found, "
15 changes: 8 additions & 7 deletions src/relax/backend/contrib/dnnl/codegen.cc
@@ -39,8 +39,8 @@ using backend::contrib::NodeEntries;
 
 class DNNLJSONSerializer : public JSONSerializer {
  public:
-  DNNLJSONSerializer(const std::string& symbol, const Map<Var, Expr>& bindings)
-      : JSONSerializer(symbol), bindings_(bindings) {}
+  DNNLJSONSerializer(Map<Constant, String> constant_names, Map<Var, Expr> bindings)
+      : JSONSerializer(constant_names), bindings_(bindings) {}
 
   using JSONSerializer::VisitExpr_;
 
@@ -80,18 +80,19 @@ class DNNLJSONSerializer : public JSONSerializer {
   Map<Var, Expr> bindings_;
 };
 
-Array<runtime::Module> DNNLCompiler(Array<Function> functions, Map<String, ObjectRef> /*unused*/) {
+Array<runtime::Module> DNNLCompiler(Array<Function> functions, Map<String, ObjectRef> /*unused*/,
+                                    Map<Constant, String> constant_names) {
   Array<runtime::Module> compiled_functions;
 
   for (const auto& func : functions) {
-    auto func_name = GetExtSymbol(func);
-    DNNLJSONSerializer serializer(func_name, AnalyzeVar2Value(func));
+    DNNLJSONSerializer serializer(constant_names, AnalyzeVar2Value(func));
     serializer.serialize(func);
     auto graph_json = serializer.GetJSON();
-    auto param_names = serializer.GetParams();
+    auto constant_names = serializer.GetConstantNames();
     const auto* pf = runtime::Registry::Get("runtime.DNNLJSONRuntimeCreate");
     ICHECK(pf != nullptr) << "Cannot find DNNL runtime module create function.";
-    compiled_functions.push_back((*pf)(func_name, graph_json, param_names));
+    auto func_name = GetExtSymbol(func);
+    compiled_functions.push_back((*pf)(func_name, graph_json, constant_names));
   }
 
   return compiled_functions;
152 changes: 24 additions & 128 deletions src/relax/backend/contrib/tensorrt/codegen.cc
@@ -91,107 +91,13 @@ class CollectFromCompositeFunctionBody : public ExprVisitor {
 
   void VisitExpr_(const ConstantNode* constant_node) final;
   void VisitExpr_(const CallNode* call_node) final;
-  /*
-  void SetPadNodeAttribute(const CallNode* call_node) {
-    const auto* pad_attr = call_node->attrs.as<PadAttrs>();
-    ICHECK(pad_attr);
-    auto p = pad_attr->pad_width;
-    const int dim_h = (p.size() == 5) ? 3 : 2;
-    const int dim_w = (p.size() == 5) ? 4 : 3;
-    std::vector<std::string> padding = {std::to_string(p[dim_h][0].as<IntImmNode>()->value),
-                                        std::to_string(p[dim_w][0].as<IntImmNode>()->value),
-                                        std::to_string(p[dim_h][1].as<IntImmNode>()->value),
-                                        std::to_string(p[dim_w][1].as<IntImmNode>()->value)};
-    std::vector<dmlc::any> padding_attr;
-    padding_attr.emplace_back(padding);
-    node_->SetAttr("padding", padding_attr);
-  }
-  void SetStridedSliceNodeAttribute(const CallNode* call_node) {
-    const auto* attrs = call_node->attrs.as<StridedSliceAttrs>();
-    ICHECK(attrs && attrs->begin && attrs->end && attrs->strides)
-        << "StridedSlice must have static begin, end, and strides.";
-    const bool default_strides =
-        !attrs->strides.value().defined() || attrs->strides.value().size() == 0;
-    auto ishape = backend::GetShape(call_node->args[0]->checked_type());
-    auto process_slice_index = [](Integer x, int default_value, int dim_value) {
-      if (!x.defined()) return default_value;
-      int value = x.as<IntImmNode>()->value;
-      if (value < 0) value += dim_value;
-      return value;
-    };
-    std::vector<std::string> start, size, strides;
-    for (size_t i = 0; i < attrs->begin.value().size(); ++i) {
-      const int begin_value = process_slice_index(attrs->begin.value()[i], 0, ishape[i]);
-      ICHECK_GE(begin_value, 0);
-      start.push_back(std::to_string(begin_value));
-      const int stride_value = (default_strides || i >= attrs->strides.value().size() ||
-                                !attrs->strides.value()[i].defined())
-                                   ? 1
-                                   : attrs->strides.value()[i].as<IntImmNode>()->value;
-      ICHECK_GT(stride_value, 0);
-      strides.push_back(std::to_string(stride_value));
-      int size_value;
-      if (attrs->slice_mode == "end") {
-        const int end_value = process_slice_index(attrs->end.value()[i], ishape[i], ishape[i]);
-        size_value = (end_value - begin_value + stride_value - 1) / stride_value;
-      } else if (attrs->slice_mode == "size") {
-        // with slice_mode = "size", attrs->end_value mean the size of the slice
-        int end_value = attrs->end.value()[i].as<IntImmNode>()->value;
-        size_value = (end_value == -1) ? ishape[i] - begin_value : end_value;
-      } else {
-        LOG(FATAL) << "Unexpected slice_mode " << attrs->slice_mode << ", expected end or size";
-        throw;
-      }
-      ICHECK_GT(size_value, 0);
-      size.push_back(std::to_string(size_value));
-    }
-    std::vector<dmlc::any> start_attr, size_attr, strides_attr;
-    start_attr.emplace_back(start);
-    size_attr.emplace_back(size);
-    strides_attr.emplace_back(strides);
-    node_->SetAttr("start", start_attr);
-    node_->SetAttr("size", size_attr);
-    node_->SetAttr("strides", strides_attr);
-  }
-
-  void SetSplitNodeAttribute(const CallNode* call_node) {
-    const auto* split_attr = call_node->attrs.as<SplitAttrs>();
-    ICHECK(split_attr);
-    std::vector<std::string> indices_or_sections;
-    std::vector<std::string> mode;
-    std::vector<std::string> axis = {std::to_string(split_attr->axis)};
-    if (const auto* sections = split_attr->indices_or_sections.as<IntImmNode>()) {
-      mode.emplace_back("sections");
-      indices_or_sections.emplace_back(std::to_string(sections->value));
-    } else {
-      mode.emplace_back("indices");
-      auto indices = Downcast<tvm::Array<Integer>>(split_attr->indices_or_sections);
-      for (const auto& i : indices) {
-        indices_or_sections.emplace_back(std::to_string(i->value));
-      }
-    }
-    std::vector<dmlc::any> indices_or_sections_attr;
-    std::vector<dmlc::any> mode_attr;
-    std::vector<dmlc::any> axis_attr;
-    indices_or_sections_attr.emplace_back(indices_or_sections);
-    mode_attr.emplace_back(mode);
-    axis_attr.emplace_back(axis);
-    node_->SetAttr("indices_or_sections", indices_or_sections_attr);
-    node_->SetAttr("mode", mode_attr);
-    node_->SetAttr("axis", axis_attr);
-  }
-  void SetGenericAttributes(const CallNode* call_node) {
-    OpAttrExtractor extractor(node_);
-    const Object* attr_obj = call_node->attrs.get();
-    extractor.Extract(const_cast<Object*>(attr_obj));
-  }
 
   void SetGenericAttributes(const CallNode* call_node) {
     OpAttrExtractor extractor(node_);
     const Object* attr_obj = call_node->attrs.get();
     extractor.Extract(const_cast<Object*>(attr_obj));
   }
-  */
+
   TensorRTJSONSerializer* serializer_;
   /*! \brief Accumulated translated arguments. */
   std::vector<JSONGraphNodeEntry> args_;
@@ -209,23 +115,24 @@ class CollectFromCompositeFunctionBody : public ExprVisitor {
  */
 class TensorRTJSONSerializer : public JSONSerializer {
  public:
-  explicit TensorRTJSONSerializer(const std::string& symbol) : JSONSerializer(symbol) {}
+  explicit TensorRTJSONSerializer(Map<Constant, String> constant_names, Map<Var, Expr> bindings)
+      : JSONSerializer(constant_names), bindings_(bindings) {}
 
   using JSONSerializer::VisitExpr_;
 
   std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* call_node) final {
     // The call must be to an inline "Composite" function
-    const auto* function_node = call_node->op.as<FunctionNode>();
-    // ICHECK(function_node != nullptr);
-    if (!function_node) return JSONSerializer::VisitExpr_(call_node);
+    const auto* fn_var = call_node->op.as<VarNode>();
+    ICHECK(fn_var);
+    const auto fn = Downcast<Function>(bindings_[GetRef<Var>(fn_var)]);
 
-    auto opt_composite = function_node->GetAttr<String>(attr::kComposite);
+    auto opt_composite = fn->GetAttr<String>(attr::kComposite);
     ICHECK(opt_composite.defined());
     std::string name = opt_composite.value();
 
     // Collect the constants and attributes of all operator calls inside the composite body.
     CollectFromCompositeFunctionBody collector(this);
-    collector.VisitExpr(function_node->body);
+    collector.VisitExpr(fn->body);
 
     // Capture the args to the "Composite" function as inputs for this node.
     std::vector<JSONGraphNodeEntry> inputs;
@@ -238,9 +145,7 @@ class TensorRTJSONSerializer : public JSONSerializer {
     for (const auto& node : collector.args_) {
       inputs.emplace_back(node);
     }
-    // TODO(@sunggg): Revisit when we have op naming convention.
-    // Currently, simply remove "relax." prefix to make it work.
-    name = std::string("tensorrt.") + name.substr(6);
+
     // Create the final node.
     auto node = std::make_shared<JSONGraphNode>(name,
                                                 /*op_type=*/"kernel", inputs,
@@ -285,6 +190,10 @@ class TensorRTJSONSerializer : public JSONSerializer {
     node->SetAttr("use_fp16", use_fp16_attr);
     node->SetAttr("use_uint8", use_uint8_attr);
   }
+
+ private:
+  /*! \brief The bindings to look up composite functions. */
+  Map<Var, Expr> bindings_;
 };
 
 void CollectFromCompositeFunctionBody::VisitExpr_(const ConstantNode* constant_node) {
@@ -294,21 +203,7 @@ void CollectFromCompositeFunctionBody::VisitExpr_(const ConstantNode* constant_node) {
 }
 
 void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) {
-  const auto* op_node = call_node->op.as<OpNode>();
-  ICHECK(op_node != nullptr);
-  std::string name = op_node->name;
-  /*
-  // TODO(@sunggg): revisit when relax supports these ops.
-  if (name == "nn.pad") {
-    SetPadNodeAttribute(call_node);
-  } else if (name == "strided_slice") {
-    SetStridedSliceNodeAttribute(call_node);
-  } else if (name == "split") {
-    SetSplitNodeAttribute(call_node);
-  } else {
-    SetGenericAttributes(call_node);
-  }
-  */
+  SetGenericAttributes(call_node);
   ExprVisitor::VisitExpr_(call_node);
 }
 
@@ -318,20 +213,21 @@ void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) {
  * \return Runtime modules.
  */
 Array<runtime::Module> TensorRTCompiler(Array<Function> functions,
-                                        Map<String, ObjectRef> /*unused*/) {
+                                        Map<String, ObjectRef> /*unused*/,
+                                        Map<Constant, String> constant_names) {
   Array<runtime::Module> compiled_functions;
   for (const auto& func : functions) {
-    std::string func_name = GetExtSymbol(func);
     VLOG(1) << "TensorRT partition:" << std::endl << PrettyPrint(func);
-    TensorRTJSONSerializer serializer(func_name);
+    TensorRTJSONSerializer serializer(constant_names, AnalyzeVar2Value(func));
     serializer.serialize(func);
     std::string graph_json = serializer.GetJSON();
     VLOG(1) << "TensorRT JSON:" << std::endl << graph_json;
-    auto param_names = serializer.GetParams();
+    auto constant_names = serializer.GetConstantNames();
     const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create");
     ICHECK(pf != nullptr) << "Cannot find TensorRT runtime module create function.";
+    std::string func_name = GetExtSymbol(func);
     VLOG(1) << "Creating tensorrt runtime::Module for '" << func_name << "'";
-    compiled_functions.push_back((*pf)(func_name, graph_json, param_names));
+    compiled_functions.push_back((*pf)(func_name, graph_json, constant_names));
   }
   return compiled_functions;
 }
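Not visible in these hunks: a compiler with this signature is exposed to the BYOC flow through the packed-function registry. The registration presumably follows the usual relax.ext.<backend> convention, along the lines of:

    // Assumed registration (standard TVM pattern, not shown in this diff).
    TVM_REGISTER_GLOBAL("relax.ext.tensorrt").set_body_typed(TensorRTCompiler);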
4 changes: 2 additions & 2 deletions src/relax/backend/vm/codegen_vm.cc
@@ -473,8 +473,8 @@ Module VMLink(ExecBuilder builder, Target target, Optional<Module> lib, Array<Mo
     lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
   }
   std::unordered_map<std::string, runtime::NDArray> conv_params;
-  for (const auto& kv : params) {
-    conv_params[kv.first] = kv.second;
+  for (const auto& [name, param] : params) {
+    conv_params[name] = param;
   }
   Module combined_lib = codegen::CreateMetadataModule(
       conv_params, lib.value(), ext_libs, target,
15 changes: 4 additions & 11 deletions src/relax/transform/fuse_ops.cc
@@ -949,18 +949,11 @@ class PatternBasedPartitioner : ExprVisitor {
 
     for (const auto& [pat, match] : matches_opt.value()) {
       ICHECK(group_map_.count(match.get()));
-      // The op node itself is also a part of the matched expressions, but it can be ignored.
-      if (!match->IsInstance<OpNode>()) {
-        // Put all matching expressions into the parent group.
+      // Put all matching call nodes into the parent group.
+      if (pat->IsInstance<CallPatternNode>() && match != GetRef<Call>(call)) {
         AddToGroup(match, parent_group);
-        if (match != GetRef<Call>(call) && !pat->IsInstance<WildcardPatternNode>()) {
-          // In the example above, we hit this code path when "match" is the conv2d call node
-          // on the RHS.
-          // After we put the conv2d into the parent group, "conv1", we also need to put "lv"
-          // on the LHS into the same parent group. We need this additional handling because
-          // "lv" does not appear as part of the matched expressions.
-          AddToGroup(value_to_bound_var_[match], parent_group);
-        }
+        // Put the bound variable on the LHS into the same parent group.
+        AddToGroup(value_to_bound_var_[match], parent_group);
       }
     }
   }
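To make the fuse_ops.cc change concrete, here is a worked example of why grouping only CallPattern matches fixes partitioning for residual blocks (the IR snippet is illustrative, not taken from the test suite):

    // Suppose the registered pattern is add(conv2d(x, w), wildcard), matched
    // against a residual block in A-normal form:
    //
    //   lv = relax.nn.conv2d(x, w)   // CallPattern match on the RHS
    //   gv = relax.add(lv, skip)     // root call; `skip` matches the wildcard
    //
    // The loop now puts only the conv2d call (a CallPattern match) and its
    // bound variable `lv` into the parent group of `gv`. Previously, every
    // non-Op match, including the wildcard match `skip`, was added to the
    // group, which pulled the skip connection into the offloaded function and
    // broke the partitioning of residual blocks.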