From 65c2d2b533151b4cf4479b6955791594adc90dcf Mon Sep 17 00:00:00 2001 From: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Date: Thu, 8 Dec 2022 15:08:35 +0800 Subject: [PATCH] fp16 (#33) --- python/tvm/relax/cutlass/pattern.py | 163 +++++++++++++++--- python/tvm/relax/transform/mixed_precision.py | 89 ++++++---- python/tvm/relax/transform/op_legalizer.py | 12 +- python/tvm/relax/transform/transform.py | 25 +-- src/relax/transform/split_cutlass.cc | 2 + src/relax/transform/to_mixed_precision.cc | 46 ++--- 6 files changed, 245 insertions(+), 92 deletions(-) diff --git a/python/tvm/relax/cutlass/pattern.py b/python/tvm/relax/cutlass/pattern.py index 25214b9082..c5260681fe 100644 --- a/python/tvm/relax/cutlass/pattern.py +++ b/python/tvm/relax/cutlass/pattern.py @@ -30,6 +30,7 @@ def op_pattern_stitch(evaluated_symbols, evaluated_buffers, matched_pattern_name and matched_pattern_names[1] == "bias_row" and matched_pattern_names[2] == "relu" ): + # dense_row_row_row + bias_row + relu m_dense, n_dense, k_dense = evaluated_symbols[0] m_bias, n_bias = evaluated_symbols[1] m_relu, n_relu = evaluated_symbols[2] @@ -45,12 +46,27 @@ def op_pattern_stitch(evaluated_symbols, evaluated_buffers, matched_pattern_name and C_bias == A_relu ): return matched_pattern_names[:3] + if len(matched_pattern_names) == 2: + assert len(evaluated_symbols) == 2 + assert len(evaluated_buffers) == 2 + if ( + matched_pattern_names[0] == "dense_row_row_row" + and matched_pattern_names[1] == "bias_row" + ): + # dense_row_row_row + bias_row + m_dense, n_dense, k_dense = evaluated_symbols[0] + m_bias, n_bias = evaluated_symbols[1] + A_dense, B_dense, C_dense = evaluated_buffers[0] + A_bias, B_bias, C_bias = evaluated_buffers[1] + if m_dense == m_bias and n_dense == n_bias and C_dense == A_bias: + return matched_pattern_names[:2] if len(matched_pattern_names) == 1: assert len(evaluated_symbols) == 1 assert len(evaluated_buffers) == 1 if matched_pattern_names[0] == "dense_row_row_row": + # dense_row_row_row return matched_pattern_names[:1] - return 0 + return [] A_TYPE = "float16" @@ -77,12 +93,9 @@ def dense_row_row_row(): with I.ir_module() as frame: with T.prim_func(): T.func_name("dense_row_row_row") - A = T.arg("A", T.buffer_decl((m, k), A_TYPE) - ) # pylint: disable=invalid-name - B = T.arg("B", T.buffer_decl((k, n), B_TYPE) - ) # pylint: disable=invalid-name - C = T.arg("C", T.buffer_decl((m, n), C_TYPE) - ) # pylint: disable=invalid-name + A = T.arg("A", T.buffer_decl((m, k), A_TYPE)) # pylint: disable=invalid-name + B = T.arg("B", T.buffer_decl((k, n), B_TYPE)) # pylint: disable=invalid-name + C = T.arg("C", T.buffer_decl((m, n), C_TYPE)) # pylint: disable=invalid-name with T.grid(m, n, k) as (l0, l1, l2): with T.block("dense_row_row_row"): vi, vj, vk = T.axis.remap("SSR", [l0, l1, l2]) @@ -90,8 +103,7 @@ def dense_row_row_row(): T.writes(C[vi, vj]) with T.init(): T.buffer_store(C, T.cast(0.0, C_TYPE), [vi, vj]) - T.buffer_store( - C, C[vi, vj] + A[vi, vk] * B[vk, vj], [vi, vj]) + T.buffer_store(C, C[vi, vj] + A[vi, vk] * B[vk, vj], [vi, vj]) return ib.get()["dense_row_row_row"] @@ -103,12 +115,9 @@ def bias_row(): with I.ir_module() as frame: with T.prim_func(): T.func_name("bias_row") - A = T.arg("A", T.buffer_decl((m, n), A_TYPE) - ) # pylint: disable=invalid-name - B = T.arg("B", T.buffer_decl((0, n), B_TYPE) - ) # pylint: disable=invalid-name - C = T.arg("C", T.buffer_decl((m, n), C_TYPE) - ) # pylint: disable=invalid-name + A = T.arg("A", T.buffer_decl((m, n), A_TYPE)) # pylint: disable=invalid-name 
+ B = T.arg("B", T.buffer_decl((0, n), B_TYPE)) # pylint: disable=invalid-name + C = T.arg("C", T.buffer_decl((m, n), C_TYPE)) # pylint: disable=invalid-name with T.grid(m, n) as (l0, l1): with T.block("bias_row"): i, j = T.axis.remap("SS", [l0, l1]) @@ -126,27 +135,29 @@ def relu(): with I.ir_module() as frame: with T.prim_func(): T.func_name("relu") - A = T.arg("A", T.buffer_decl((m, n), A_TYPE) - ) # pylint: disable=invalid-name - C = T.arg("C", T.buffer_decl((m, n), C_TYPE) - ) # pylint: disable=invalid-name + A = T.arg("A", T.buffer_decl((m, n), A_TYPE)) # pylint: disable=invalid-name + C = T.arg("C", T.buffer_decl((m, n), C_TYPE)) # pylint: disable=invalid-name with T.grid(m, n) as (l0, l1): with T.block("relu"): i, j = T.axis.remap("SS", [l0, l1]) T.reads(A[i, j]) T.writes(C[i, j]) - T.buffer_store( - C, T.max(A[i, j], T.cast(0, A_TYPE)), [i, j]) + T.buffer_store(C, T.max(A[i, j], T.cast(0, A_TYPE)), [i, j]) return ib.get()["relu"] @register_func("tvm.relax.cutlass.get_graph_pattern_code") def get_graph_pattern_code(cutlass_op): cutlass_op = [str(st) for st in cutlass_op] + pattern = "/".join(cutlass_op) + if pattern not in GRAPH_PATTERN_CODE_LIST: + raise tvm.TVMError("Cannot find graph pattern code for cutlass op: {}".format(cutlass_op)) return GRAPH_PATTERN_CODE_LIST["/".join(cutlass_op)] -GRAPH_PATTERN_CODE_LIST["dense_row_row_row"] = """ +GRAPH_PATTERN_CODE_LIST[ + "dense_row_row_row" +] = """ #define CUTLASS_ENABLE_CUBLAS 1 #define CUTLASS_NAMESPACE cutlass #define CUTLASS_ENABLE_TENSOR_CORE_MMA 1 @@ -219,8 +230,114 @@ def get_graph_pattern_code(cutlass_op): TVM_DLL_EXPORT_TYPED_FUNC({global_symbol}, _GEMM); """ +GRAPH_PATTERN_CODE_LIST[ + "dense_row_row_row/bias_row" +] = """ + #define CUTLASS_ENABLE_CUBLAS 1 + #define CUTLASS_NAMESPACE cutlass + #define CUTLASS_ENABLE_TENSOR_CORE_MMA 1 + #define NDEBUG + + #include + #include + #include + #include + + #include + #include + #include + #include + + #define DMLC_USE_LOGGING_LIBRARY + + #include + #include + #include + + namespace { + + using namespace tvm; + using namespace tvm::runtime; + + // simple specialized impl, can be replaced by + // call into libraries. 
+    void _HGEMM_BIAS(NDArray A, NDArray B, NDArray Bias, NDArray C) {
+      // A: [M, K], B: [K, N], BIAS: [1, N], C: [M, N]
+      CHECK_EQ(A->ndim, 2);
+      CHECK_EQ(B->ndim, 2);
+      CHECK_EQ(Bias->ndim, 2);
+      CHECK_EQ(C->ndim, 2);
+      CHECK_EQ(A->shape[1], B->shape[0]);
+      int M = A->shape[0];
+      int K = A->shape[1];
+      int N = B->shape[1];
+      CHECK_EQ(C->shape[0], M);
+      CHECK_EQ(C->shape[1], N);
+      CHECK_EQ(Bias->shape[0], 1);
+      CHECK_EQ(Bias->shape[1], N);
+      CHECK_EQ(A.DataType(), DataType::Float(16));
+      CHECK_EQ(B.DataType(), DataType::Float(16));
+      CHECK_EQ(Bias.DataType(), DataType::Float(16));
+      CHECK_EQ(C.DataType(), DataType::Float(16));
+
+      // Define the GEMM operation
+      using Gemm = cutlass::gemm::device::Gemm<
+          cutlass::half_t, cutlass::layout::RowMajor,
+          cutlass::half_t, cutlass::layout::RowMajor,
+          cutlass::half_t, cutlass::layout::RowMajor,
+          cutlass::half_t,
+          cutlass::arch::OpClassTensorOp,
+          cutlass::arch::Sm75,
+          cutlass::gemm::GemmShape<64, 64, 32>,
+          cutlass::gemm::GemmShape<32, 32, 32>,
+          cutlass::gemm::GemmShape<16, 8, 8>,
+          cutlass::epilogue::thread::LinearCombinationBias<
+              cutlass::half_t,
+              8,
+              cutlass::half_t,
+              cutlass::half_t,
+              cutlass::epilogue::thread::ScaleType::NoBetaScaling
+          >,
+          cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
+          2,
+          8,
+          8,
+          false,
+          cutlass::arch::OpMultiplyAdd
+      >;
+
+      Gemm gemm_op;
+
+      cutlass::half_t alpha(1.0);
+      cutlass::half_t beta(0.0);
+      cutlass::layout::ColumnMajor::Stride::Index lda(K);
+      cutlass::layout::ColumnMajor::Stride::Index ldb(N);
+      cutlass::layout::ColumnMajor::Stride::Index ld_bias(N);
+      cutlass::layout::ColumnMajor::Stride::Index ldc(N);
+      cutlass::half_t* a = reinterpret_cast<cutlass::half_t*>(A->data);
+      cutlass::half_t* b = reinterpret_cast<cutlass::half_t*>(B->data);
+      cutlass::half_t* c = reinterpret_cast<cutlass::half_t*>(C->data);
+      cutlass::half_t* bias = reinterpret_cast<cutlass::half_t*>(Bias->data);
+
+      cutlass::Status status = gemm_op({
+          {M, N, K},          //  GemmCoord problem_size_
+          {a, lda},           //  TensorRef ref_A_
+          {b, ldb},           //  TensorRef ref_B_
+          {bias, ld_bias},    //  TensorRef ref_C_
+          {c, ldc},           //  TensorRef ref_D_
+          {alpha, beta}       //  typename EpilogueOutputOp::Params epilogue_
+      });
+      CHECK(status == cutlass::Status::kSuccess);
+    }
+
+    }  // namespace
+    TVM_DLL_EXPORT_TYPED_FUNC({global_symbol}, _HGEMM_BIAS);
+"""

-GRAPH_PATTERN_CODE_LIST["dense_row_row_row/bias_row/relu"] = """
+GRAPH_PATTERN_CODE_LIST[
+    "dense_row_row_row/bias_row/relu"
+] = """
     #define CUTLASS_ENABLE_CUBLAS 1
     #define CUTLASS_NAMESPACE cutlass
     #define CUTLASS_ENABLE_TENSOR_CORE_MMA 1
diff --git a/python/tvm/relax/transform/mixed_precision.py b/python/tvm/relax/transform/mixed_precision.py
index 00c9f32a7b..33af1b450f 100644
--- a/python/tvm/relax/transform/mixed_precision.py
+++ b/python/tvm/relax/transform/mixed_precision.py
@@ -16,8 +16,11 @@
 # under the License.
 # pylint: disable=line-too-long,unused-argument
 """Default behavior for ops in mixed_precision pass. Import this file to use."""
 from typing import List

+from tvm.relax.op import dense, matmul, conv2d
 from tvm.relay.op import register_mixed_precision_conversion

 # MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
@@ -31,11 +34,7 @@
 # Default lists inspired from TF's classifications:
 # github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
 # They have a bias toward Nvidia Tensor Cores so modify lists per your hardware choice.
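To make the three categories concrete, here is a small sketch of how the default lists below are meant to be read (`classify_op` is a hypothetical helper written for illustration, not part of this patch):

    def classify_op(op_name):
        # ALWAYS: compute in fp16 for speed and memory, e.g. "relax.nn.dense".
        if op_name in DEFAULT_ALWAYS_LIST:
            return MIXED_PRECISION_ALWAYS
        # FOLLOW: use fp16 only when the inputs are already fp16, e.g. "relax.nn.flatten".
        if op_name in DEFAULT_FOLLOW_LIST:
            return MIXED_PRECISION_FOLLOW
        # NEVER: numerically sensitive, keep fp32, e.g. "relax.nn.softmax".
        # (Treating unlisted ops as NEVER is an assumption of this sketch.)
        return MIXED_PRECISION_NEVER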
-DEFAULT_ALWAYS_LIST = [
-    "relax.nn.dense",
-    "relax.nn.conv2d",
-    "relax.nn.matmul"
-]
+DEFAULT_ALWAYS_LIST = ["relax.nn.dense", "relax.nn.conv2d", "relax.nn.matmul"]
 DEFAULT_FOLLOW_LIST = [
     "relax.nn.flatten",
     "relax.nn.batch_norm",
@@ -63,12 +62,7 @@
     "relax.cast",
     "relax.broadcast_to",
 ]
-DEFAULT_NEVER_LIST = [
-    "relax.nn.softmax",
-    "relax.nn.layer_norm",
-    "relax.sum",
-    "relax.mean"
-]
+DEFAULT_NEVER_LIST = ["relax.nn.softmax", "relax.nn.layer_norm", "relax.sum", "relax.mean"]


 # Returns a decorator which registers for every given op, the function under FTVMMixedPrecisionConversionType
@@ -81,49 +75,74 @@ def decorator(func):
     return decorator


-def get_generic_out_dtypes(call_node: "relay.Call", mixed_precision_type: str) -> List[str]:
+def get_generic_out_dtypes(call_node: "relay.Call", expected_out_dtype: str) -> List:
     """A function which returns output dtypes in a way which works for most ops.

     Parameters
     ---------
     call_node: relay.Call
         The call node containing the op.
-    mixed_precision_type: str
-        The target type to run the operation in.
+
+    expected_out_dtype: str
+        The output dtype to use.
+
     Returns
     -------
-    output_dtypes : [str, str]
-        A list of two strings. The first represents the datatype used for accumulation
-        in the operation. The second represents the actual output datatype.
+    result : [bool, relay.Call]
+        A two-element list. The first element is True when the op carries an
+        out_dtype attribute and the call was rebuilt with expected_out_dtype;
+        the second element is the (possibly rebuilt) call node.
     """
-    # Assume support accumulation dtypes <---> has out_dtype attr.
-    # This is because there is no better way right now to tell which ops support accumulating
-    # at different data types.
-    # Some discussion here about making this better is here:
-    # https://discuss.tvm.apache.org/t/rfc-relay-fp32-fp16-model-support/9994/4?u=andrewzhaoluo
     if hasattr(call_node.attrs, "out_dtype"):
-        # TODO (AndrewZhaoLuo): evaluate consistent support for mixed_type accumulators
-        # return ["float32", mixed_precision_type]
-        out_dtype = "float32" if call_node.attrs.out_dtype == "" else call_node.attrs.out_dtype
-        return [out_dtype, mixed_precision_type]
-
-    # [accumulation_dtype, output_dtype] for the operations
-    return [mixed_precision_type, mixed_precision_type]
+        if call_node.op.name == "relax.nn.dense":
+            adjust_call = dense(
+                call_node.args[0],
+                call_node.args[1],
+                units=call_node.attrs.units,
+                out_dtype=expected_out_dtype,
+            )
+        elif call_node.op.name == "relax.nn.conv2d":
+            adjust_call = conv2d(
+                call_node.args[0],
+                call_node.args[1],
+                channels=call_node.attrs.channels,
+                kernel_size=call_node.attrs.kernel_size,
+                strides=call_node.attrs.strides,
+                padding=call_node.attrs.padding,
+                dilation=call_node.attrs.dilation,
+                groups=call_node.attrs.groups,
+                data_layout=call_node.attrs.data_layout,
+                kernel_layout=call_node.attrs.kernel_layout,
+                out_layout=call_node.attrs.out_layout,
+                out_dtype=expected_out_dtype,
+            )
+        elif call_node.op.name == "relax.nn.matmul":
+            adjust_call = matmul(
+                call_node.args[0],
+                call_node.args[1],
+                out_dtype=expected_out_dtype,
+            )
+        else:
+            raise ValueError(f"Unsupported op for get_generic_out_dtypes: {call_node.op.name}")
+        return [True, adjust_call]
+    else:
+        return [False, call_node]
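A minimal sketch of how the `[category, adjustable, call]` convention returned above is consumed (illustrative only: `apply_always_rule` is a hypothetical helper, `call_node` is assumed to be a relax.Call to an op in DEFAULT_ALWAYS_LIST, and the real consumer is the C++ ToMixedPrecisionMutator later in this patch):

    def apply_always_rule(call_node, out_dtype="float32"):
        # generic_always_op (registered just below) returns a three-element list.
        category, adjustable, adjusted_call = generic_always_op(call_node, out_dtype)
        assert category == MIXED_PRECISION_ALWAYS
        if adjustable:
            # The op has an out_dtype attr: re-emit the call with the rebuilt attrs;
            # the mutator then casts back to fp16 when out_dtype is not fp16.
            return adjusted_call
        # No out_dtype attr: keep the original call unchanged.
        return call_node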

 # Functions for FTVMMixedPrecisionConversionType which
-# Take in CallNodes and a DType and returns a conversion type,
-# an accumulation dtype, and an output_dtype.
+# take in a CallNode and a dtype string and return a conversion category,
+# a flag for whether the out_dtype attr was adjusted, and the adjusted call.
 @register_func_to_op_list(list_ops=DEFAULT_ALWAYS_LIST)
-def generic_always_op(call_node: "relay.Call", mixed_precision_type: str) -> List:
-    return [MIXED_PRECISION_ALWAYS] + get_generic_out_dtypes(call_node, mixed_precision_type)
+def generic_always_op(call_node: "relay.Call", expected_out_dtype: str) -> List:
+    return [MIXED_PRECISION_ALWAYS] + get_generic_out_dtypes(call_node, expected_out_dtype)


 @register_func_to_op_list(list_ops=DEFAULT_FOLLOW_LIST)
-def generic_follow_op(call_node: "relay.Call", mixed_precision_type: str) -> List:
-    return [MIXED_PRECISION_FOLLOW] + get_generic_out_dtypes(call_node, mixed_precision_type)
+def generic_follow_op(call_node: "relay.Call", expected_out_dtype: str) -> List:
+    return [MIXED_PRECISION_FOLLOW] + get_generic_out_dtypes(call_node, expected_out_dtype)


 @register_func_to_op_list(list_ops=DEFAULT_NEVER_LIST)
-def generic_never_op(call_node: "relay.Call", mixed_precision_type: str) -> List:
-    return [MIXED_PRECISION_NEVER] + get_generic_out_dtypes(call_node, mixed_precision_type)
+def generic_never_op(call_node: "relay.Call", expected_out_dtype: str) -> List:
+    return [MIXED_PRECISION_NEVER] + get_generic_out_dtypes(call_node, expected_out_dtype)
diff --git a/python/tvm/relax/transform/op_legalizer.py b/python/tvm/relax/transform/op_legalizer.py
index ca86a4d829..7906cfc696 100644
--- a/python/tvm/relax/transform/op_legalizer.py
+++ b/python/tvm/relax/transform/op_legalizer.py
@@ -21,6 +21,7 @@
 from tvm import ir, te, topi
 from tvm.ir import Attrs
 from tvm.ir.module import IRModule
+from tvm.tir.generic import cast

 from ..analysis import remove_all_unused
 from ..expr import Call, Expr, Function, Tuple, TupleGetItem
@@ -80,11 +81,18 @@ def _nn_relu(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Exp

 def _nn_gelu(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr):
     def gelu(x):
+        dtype = x.dtype
         return te.compute(
             x.shape,
-            lambda *i: 0.5
+            lambda *i: cast(0.5, dtype)
             * x(*i)
-            * (1 + te.tanh(math.sqrt(2 / math.pi) * (x(*i) + 0.044715 * te.power(x(*i), 3)))),
+            * (
+                cast(1, dtype)
+                + te.tanh(
+                    cast(math.sqrt(2 / math.pi), dtype)
+                    * (x(*i) + cast(0.044715, dtype) * te.power(x(*i), 3))
+                )
+            ),
         )

     return bb.call_te(gelu, args[0])
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 60a95b6222..bc21bda882 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -354,8 +354,8 @@ def CutlassCodegen() -> tvm.ir.transform.Pass:


 def SplitCutlass() -> tvm.ir.transform.Pass:
-    """Split a PrimFunc into 2 parts: the first part is a TIR PrimFunc which is
-    matched with some cutlass kernels, and the second part is the rest of the original
+    """Split a PrimFunc into 2 parts: the first part is a TIR PrimFunc that is
+    matched with some cutlass kernels, and the second part is the rest of the original
     PrimFunc that is not fused with cutlass kernels.

     Returns
@@ -366,15 +366,20 @@ def SplitCutlass() -> tvm.ir.transform.Pass:
     return _ffi_api.SplitCutlass()


-def ToMixedPrecision() -> tvm.ir.transform.Pass:
+def ToMixedPrecision(out_dtype="float32") -> tvm.ir.transform.Pass:
     """Automatic mixed precision pass.

+    Parameters
+    ----------
+    out_dtype : str
+        The output data type of gemm/conv ops after conversion.
+
     Returns
     -------
     ret : tvm.transform.Pass
         The registered pass for mixed precision.
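+
+    Examples
+    --------
+    .. code-block:: python
+
+        # A minimal sketch: convert `mod` (an assumed Relax IRModule) to fp16
+        # compute while keeping fp32 outputs for gemm/conv ops.
+        mod = relax.transform.ToMixedPrecision(out_dtype="float32")(mod)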
""" - return _ffi_api.ToMixedPrecision() + return _ffi_api.ToMixedPrecision(out_dtype) def _wrap_class_function_pass(pass_cls, pass_info): @@ -508,8 +513,7 @@ def transform(func, mod, ctx): required = required if required else [] if not isinstance(required, (list, tuple)): - raise TypeError( - "Required is expected to be the type of " + "list/tuple.") + raise TypeError("Required is expected to be the type of " + "list/tuple.") def create_function_pass(pass_arg): """Internal function that creates a function pass""" @@ -658,13 +662,11 @@ def transform(block, mod, ctx): """ if opt_level is None: - raise ValueError( - "Please provide opt_level for the dataflowblock pass.") + raise ValueError("Please provide opt_level for the dataflowblock pass.") required = required if required else [] if not isinstance(required, (list, tuple)): - raise TypeError( - "Required is expected to be the type of " + "list/tuple.") + raise TypeError("Required is expected to be the type of " + "list/tuple.") def create_dataflowblock_pass(pass_arg): """Internal function that creates a dataflowblock pass""" @@ -673,8 +675,7 @@ def create_dataflowblock_pass(pass_arg): if inspect.isclass(pass_arg): return _wrap_class_dataflowblock_pass(pass_arg, info) if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): - raise TypeError( - "pass_func must be a callable for DataflowBlock pass") + raise TypeError("pass_func must be a callable for DataflowBlock pass") return _ffi_api.MakeDataflowBlockPass(pass_arg, info) # type: ignore if pass_func: diff --git a/src/relax/transform/split_cutlass.cc b/src/relax/transform/split_cutlass.cc index 00b1fc4290..47e7d86055 100644 --- a/src/relax/transform/split_cutlass.cc +++ b/src/relax/transform/split_cutlass.cc @@ -378,6 +378,7 @@ std::pair> SplitFunctions( } } if (!has_second_func) { + func = WithAttr(func, "cutlass_codegen", Bool(true)); return {WithAttr(func, "cutlass_kernel", matcher.cutlass_annotation), NullOpt}; } // Step 2. Split the function into two functions. @@ -406,6 +407,7 @@ std::pair> SplitFunctions( new_buffer_map1.Set(new_params1.back(), matcher.intermediate_buffer); PrimFunc func1 = PrimFunc(new_params1, body1, func->ret_type, new_buffer_map1, func->attrs); func1 = WithAttr(func1, "cutlass_kernel", matcher.cutlass_annotation); + func1 = WithAttr(func1, "cutlass_codegen", Bool(true)); // Step 4. Craft the second function. 
 Array<tir::Var> new_params2;
 std::vector arg_partition2;
diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc
index c0aba534c2..2802c33792 100644
--- a/src/relax/transform/to_mixed_precision.cc
+++ b/src/relax/transform/to_mixed_precision.cc
@@ -22,6 +22,7 @@
  */

 #include
+#include
 #include

 #include "../op/make_op.h"
@@ -35,14 +36,16 @@ enum MixedTypeConversionCategory : int {
   MIXED_PRECISION_NEVER = 2
 };

-// Return array is of type : [MixedTypeConversionCategory (int), String, String]
-// The fields are : [ConversionCategory, accumulation_datatype, output_datatype]
-// Call is a call node, DataType is the mixed precision type
+// Return array is of type : [MixedTypeConversionCategory (int), bool, Call]
+// Call is a call node, out_dtype_str is the expected output dtype string
 using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
-    const Call& call_node, const std::string& target_dtype_str)>;
+    const Call& call_node, const std::string& out_dtype_str)>;

 class ToMixedPrecisionMutator : public ExprMutator {
  public:
+  explicit ToMixedPrecisionMutator(DLDataType output_dtype)
+      : expected_output_dtype_(output_dtype) {}
+
   void InitVarMap(const relax::Function& func) {
     for (const auto& param : func->params) {
       if (const auto* type = param->checked_type_.as<DynTensorTypeNode>()) {
@@ -112,7 +115,7 @@ class ToMixedPrecisionMutator : public ExprMutator {
     if (attr_map.count(op)) {
       FTVMMixedPrecisionConversionType func = attr_map[op];
       Array<ObjectRef> op_descriptor =
-          func(GetRef<Call>(call_node), DLDataType2String(low_precision_type_));
+          func(GetRef<Call>(call_node), DLDataType2String(expected_output_dtype_));
       ICHECK(op_descriptor.size() == 3)
           << "got the wrong number of returned arguments (expected 3 got " << op_descriptor.size()
           << ") from FTVMMixedPrecisionConversionType for " << op->name;
@@ -121,21 +124,23 @@ class ToMixedPrecisionMutator : public ExprMutator {
       int64_t op_conversion_type = Downcast<IntImm>(op_descriptor[0])->value;
       MixedTypeConversionCategory category =
           static_cast<MixedTypeConversionCategory>(op_conversion_type);
-      DataType accumulation_dtype =
-          DataType(String2DLDataType(Downcast<String>(op_descriptor[1])));
+      bool out_dtype_adjustable = Downcast<Bool>(op_descriptor[1])->value;
+      Call adjusted_call = Downcast<Call>(op_descriptor[2]);

       if (category == MIXED_PRECISION_ALWAYS) {
         // LOG(INFO) << "MIXED_PRECISION_ALWAYS";
         // Cast inputs to fp16
         std::vector<Expr> new_args;
         CastArgsToType(call_node->args, low_precision_type_, &new_args);
-        // Cast output according to out_dtype (if necessary)
-        if (accumulation_dtype != low_precision_type_) {
-          // LOG(INFO) << "RECAST";
-          relax::Var accmulate =
-              emit(relax::Call(call_node->op, new_args, call_node->attrs, call_node->type_args),
-                   binding->var);
-          relax::Var cast_back =
-              emit(relax::MakeCast(accmulate, low_precision_type_), binding->var);
+        if (out_dtype_adjustable) {
+          // Cast output according to out_dtype
+          relax::Var accumulate = emit(
+              relax::Call(call_node->op, new_args, adjusted_call->attrs, call_node->type_args),
+              binding->var);
+          if (expected_output_dtype_ != low_precision_type_) {
+            // LOG(INFO) << "RECAST";
+            relax::Var cast_back =
+                emit(relax::MakeCast(accumulate, low_precision_type_), binding->var);
+          }
           return;
         } else {
           relax::Var new_var =
@@ -293,25 +298,26 @@ class ToMixedPrecisionMutator : public ExprMutator {
                      ObjectPtrEqual> var_map_;
   std::unordered_map const_map_;
+  DataType expected_output_dtype_;
 };  // namespace relax

-Expr ToMixedPrecision(const relax::Function& f) {
-  ToMixedPrecisionMutator mutator;
+Expr ToMixedPrecision(const relax::Function& f, const runtime::String& out_dtype) {
+  ToMixedPrecisionMutator mutator(runtime::String2DLDataType(out_dtype));
   mutator.InitVarMap(f);
   return mutator.VisitExpr(f);
 }

 namespace transform {

-Pass ToMixedPrecision() {
+Pass ToMixedPrecisionPass(const runtime::String& out_dtype) {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
       [=](Function f, IRModule m, PassContext pc) {
-        return Downcast<Function>(ToMixedPrecision(f));
+        return Downcast<Function>(ToMixedPrecision(f, out_dtype));
       };
   return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {});
 }

-TVM_REGISTER_GLOBAL("relax.transform.ToMixedPrecision").set_body_typed(ToMixedPrecision);
+TVM_REGISTER_GLOBAL("relax.transform.ToMixedPrecision").set_body_typed(ToMixedPrecisionPass);

 }  // namespace transform
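With the registration above in place, the pass is reachable from the Python wrapper shown earlier in transform.py. A minimal end-to-end sketch of the intended flow (illustrative only: `mod` is an assumed pre-existing Relax IRModule with fp32 dense/conv2d calls, and the pass ordering is inferred from the passes this patch touches):

    import tvm
    from tvm import relax

    seq = tvm.transform.Sequential(
        [
            # fp16 compute with fp32 gemm/conv out_dtype, cast back to fp16 after
            relax.transform.ToMixedPrecision(out_dtype="float32"),
            # carve out the CUTLASS-matched TIR portions
            relax.transform.SplitCutlass(),
            # emit the kernels from GRAPH_PATTERN_CODE_LIST for the matched parts
            relax.transform.CutlassCodegen(),
        ]
    )
    mod = seq(mod)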