fp16 (tlc-pack#33)
spectrometerHBH authored and Hzfengsy committed Dec 8, 2022
1 parent 238dadb commit 65c2d2b
Showing 6 changed files with 245 additions and 92 deletions.
163 changes: 140 additions & 23 deletions python/tvm/relax/cutlass/pattern.py
@@ -30,6 +30,7 @@ def op_pattern_stitch(evaluated_symbols, evaluated_buffers, matched_pattern_names):
and matched_pattern_names[1] == "bias_row"
and matched_pattern_names[2] == "relu"
):
# dense_row_row_row + bias_row + relu
m_dense, n_dense, k_dense = evaluated_symbols[0]
m_bias, n_bias = evaluated_symbols[1]
m_relu, n_relu = evaluated_symbols[2]
@@ -45,12 +46,27 @@ def op_pattern_stitch(evaluated_symbols, evaluated_buffers, matched_pattern_names):
and C_bias == A_relu
):
return matched_pattern_names[:3]
if len(matched_pattern_names) == 2:
assert len(evaluated_symbols) == 2
assert len(evaluated_buffers) == 2
if (
matched_pattern_names[0] == "dense_row_row_row"
and matched_pattern_names[1] == "bias_row"
):
# dense_row_row_row + bias_row
m_dense, n_dense, k_dense = evaluated_symbols[0]
m_bias, n_bias = evaluated_symbols[1]
A_dense, B_dense, C_dense = evaluated_buffers[0]
A_bias, B_bias, C_bias = evaluated_buffers[1]
if m_dense == m_bias and n_dense == n_bias and C_dense == A_bias:
return matched_pattern_names[:2]
if len(matched_pattern_names) == 1:
assert len(evaluated_symbols) == 1
assert len(evaluated_buffers) == 1
if matched_pattern_names[0] == "dense_row_row_row":
# dense_row_row_row
return matched_pattern_names[:1]
return 0
return []
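As a sanity check on the stitching contract above, a minimal sketch of a call (the shapes are arbitrary and the strings stand in for the TIR buffer objects the matcher actually passes):

# Hypothetical example: three matched patterns whose shapes agree and whose
# output buffers chain dense -> bias -> relu, so all three names are stitched.
symbols = [(64, 128, 32), (64, 128), (64, 128)]  # (m, n, k), (m, n), (m, n)
buffers = [("A", "B", "C"), ("C", "Bias", "D"), ("D", "E")]
names = ["dense_row_row_row", "bias_row", "relu"]
assert op_pattern_stitch(symbols, buffers, names) == names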


A_TYPE = "float16"
@@ -77,21 +93,17 @@ def dense_row_row_row():
with I.ir_module() as frame:
with T.prim_func():
T.func_name("dense_row_row_row")
A = T.arg("A", T.buffer_decl((m, k), A_TYPE)
) # pylint: disable=invalid-name
B = T.arg("B", T.buffer_decl((k, n), B_TYPE)
) # pylint: disable=invalid-name
C = T.arg("C", T.buffer_decl((m, n), C_TYPE)
) # pylint: disable=invalid-name
A = T.arg("A", T.buffer_decl((m, k), A_TYPE)) # pylint: disable=invalid-name
B = T.arg("B", T.buffer_decl((k, n), B_TYPE)) # pylint: disable=invalid-name
C = T.arg("C", T.buffer_decl((m, n), C_TYPE)) # pylint: disable=invalid-name
with T.grid(m, n, k) as (l0, l1, l2):
with T.block("dense_row_row_row"):
vi, vj, vk = T.axis.remap("SSR", [l0, l1, l2])
T.reads(A[vi, vk], B[vk, vj])
T.writes(C[vi, vj])
with T.init():
T.buffer_store(C, T.cast(0.0, C_TYPE), [vi, vj])
T.buffer_store(
C, C[vi, vj] + A[vi, vk] * B[vk, vj], [vi, vj])
T.buffer_store(C, C[vi, vj] + A[vi, vk] * B[vk, vj], [vi, vj])
return ib.get()["dense_row_row_row"]


@@ -103,12 +115,9 @@ def bias_row():
with I.ir_module() as frame:
with T.prim_func():
T.func_name("bias_row")
A = T.arg("A", T.buffer_decl((m, n), A_TYPE)
) # pylint: disable=invalid-name
B = T.arg("B", T.buffer_decl((0, n), B_TYPE)
) # pylint: disable=invalid-name
C = T.arg("C", T.buffer_decl((m, n), C_TYPE)
) # pylint: disable=invalid-name
A = T.arg("A", T.buffer_decl((m, n), A_TYPE)) # pylint: disable=invalid-name
B = T.arg("B", T.buffer_decl((0, n), B_TYPE)) # pylint: disable=invalid-name
C = T.arg("C", T.buffer_decl((m, n), C_TYPE)) # pylint: disable=invalid-name
with T.grid(m, n) as (l0, l1):
with T.block("bias_row"):
i, j = T.axis.remap("SS", [l0, l1])
@@ -126,27 +135,29 @@ def relu():
with I.ir_module() as frame:
with T.prim_func():
T.func_name("relu")
A = T.arg("A", T.buffer_decl((m, n), A_TYPE)
) # pylint: disable=invalid-name
C = T.arg("C", T.buffer_decl((m, n), C_TYPE)
) # pylint: disable=invalid-name
A = T.arg("A", T.buffer_decl((m, n), A_TYPE)) # pylint: disable=invalid-name
C = T.arg("C", T.buffer_decl((m, n), C_TYPE)) # pylint: disable=invalid-name
with T.grid(m, n) as (l0, l1):
with T.block("relu"):
i, j = T.axis.remap("SS", [l0, l1])
T.reads(A[i, j])
T.writes(C[i, j])
T.buffer_store(
C, T.max(A[i, j], T.cast(0, A_TYPE)), [i, j])
T.buffer_store(C, T.max(A[i, j], T.cast(0, A_TYPE)), [i, j])
return ib.get()["relu"]


@register_func("tvm.relax.cutlass.get_graph_pattern_code")
def get_graph_pattern_code(cutlass_op):
cutlass_op = [str(st) for st in cutlass_op]
pattern = "/".join(cutlass_op)
if pattern not in GRAPH_PATTERN_CODE_LIST:
raise tvm.TVMError("Cannot find graph pattern code for cutlass op: {}".format(cutlass_op))
return GRAPH_PATTERN_CODE_LIST[pattern]
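Once registered, the packed function can be fetched by name; a minimal usage sketch (the op list here is illustrative):

# Hypothetical lookup of the registered packed function.
import tvm

get_code = tvm.get_global_func("tvm.relax.cutlass.get_graph_pattern_code")
cxx_source = get_code(["dense_row_row_row", "bias_row"])  # CUTLASS C++ source string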


GRAPH_PATTERN_CODE_LIST["dense_row_row_row"] = """
GRAPH_PATTERN_CODE_LIST[
"dense_row_row_row"
] = """
#define CUTLASS_ENABLE_CUBLAS 1
#define CUTLASS_NAMESPACE cutlass
#define CUTLASS_ENABLE_TENSOR_CORE_MMA 1
@@ -219,8 +230,114 @@ def get_graph_pattern_code(cutlass_op):
TVM_DLL_EXPORT_TYPED_FUNC({global_symbol}, _GEMM);
"""

GRAPH_PATTERN_CODE_LIST[
"dense_row_row_row/bias_row"
] = """
#define CUTLASS_ENABLE_CUBLAS 1
#define CUTLASS_NAMESPACE cutlass
#define CUTLASS_ENABLE_TENSOR_CORE_MMA 1
#define NDEBUG
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/numeric_types.h>
#include <fstream>
#include <iostream>
#include <sstream>
#include <vector>
#define DMLC_USE_LOGGING_LIBRARY <tvm/runtime/logging.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
namespace {
using namespace tvm;
using namespace tvm::runtime;
// simple specialized impl, can be replaced by
// call into libraries.
void _HGEMM_BIAS(NDArray A, NDArray B, NDArray Bias, NDArray C) {
// A: [M, K], B: [K, N], BIAS: [1, N], C: [M, N]
CHECK_EQ(A->ndim, 2);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(Bias->ndim, 2);
CHECK_EQ(C->ndim, 2);
CHECK_EQ(A->shape[1], B->shape[0]);
int M = A->shape[0];
int K = A->shape[1];
int N = B->shape[1];
CHECK_EQ(C->shape[0], M);
CHECK_EQ(C->shape[1], N);
CHECK_EQ(Bias->shape[0], 1);
CHECK_EQ(Bias->shape[1], N);
CHECK_EQ(A.DataType(), DataType::Float(16));
CHECK_EQ(B.DataType(), DataType::Float(16));
CHECK_EQ(Bias.DataType(), DataType::Float(16));
CHECK_EQ(C.DataType(), DataType::Float(16));
// Define the GEMM operation
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<
cutlass::half_t,
8,
cutlass::half_t,
cutlass::half_t,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
8,
8,
false,
cutlass::arch::OpMultiplyAdd
>;
Gemm gemm_op;
cutlass::half_t alpha(1.0);
cutlass::half_t beta(0.0);
cutlass::layout::ColumnMajor::Stride::Index lda(K);
cutlass::layout::ColumnMajor::Stride::Index ldb(N);
cutlass::layout::ColumnMajor::Stride::Index ld_bias(0);  // stride 0 broadcasts the 1 x N bias row
cutlass::layout::ColumnMajor::Stride::Index ldc(N);
cutlass::half_t* a = reinterpret_cast<cutlass::half_t*>(A->data);
cutlass::half_t* b = reinterpret_cast<cutlass::half_t*>(B->data);
cutlass::half_t* c = reinterpret_cast<cutlass::half_t*>(C->data);
cutlass::half_t* bias = reinterpret_cast<cutlass::half_t*>(Bias->data);
cutlass::Status status = gemm_op({
{M, N, K}, // GemmCoord problem_size_
{a, lda}, // TensorRef<ElementA const, LayoutA> ref_A_
{b, ldb}, // TensorRef<ElementB const, LayoutB> ref_B_
{bias, ld_bias}, // TensorRef<ElementC const, LayoutC> ref_C_
{c, ldc}, // TensorRef<ElementC, LayoutC> ref_D_
{alpha, beta} // typename EpilogueOutputOp::Params epilogue_
});
CHECK(status == cutlass::Status::kSuccess);
}
} // namespace
TVM_DLL_EXPORT_TYPED_FUNC({global_symbol}, _HGEMM_BIAS);
"""

GRAPH_PATTERN_CODE_LIST["dense_row_row_row/bias_row/relu"] = """
GRAPH_PATTERN_CODE_LIST[
"dense_row_row_row/bias_row/relu"
] = """
#define CUTLASS_ENABLE_CUBLAS 1
#define CUTLASS_NAMESPACE cutlass
#define CUTLASS_ENABLE_TENSOR_CORE_MMA 1
89 changes: 54 additions & 35 deletions python/tvm/relax/transform/mixed_precision.py
@@ -16,8 +16,11 @@
# under the License.
# pylint: disable=line-too-long,unused-argument
"""Default behavior for ops in mixed_precision pass. Import this file to use."""
import copy
from typing import List

from tvm.relay import Call
from tvm.relax.op import dense, matmul, conv2d
from tvm.relay.op import register_mixed_precision_conversion

# MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
@@ -31,11 +34,7 @@
# Default lists inspired from TF's classifications:
# github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
# They have a bias toward Nvidia Tensor Cores so modify lists per your hardware choice.
DEFAULT_ALWAYS_LIST = [
"relax.nn.dense",
"relax.nn.conv2d",
"relax.nn.matmul"
]
DEFAULT_ALWAYS_LIST = ["relax.nn.dense", "relax.nn.conv2d", "relax.nn.matmul"]
DEFAULT_FOLLOW_LIST = [
"relax.nn.flatten",
"relax.nn.batch_norm",
@@ -63,12 +62,7 @@
"relax.cast",
"relax.broadcast_to",
]
DEFAULT_NEVER_LIST = [
"relax.nn.softmax",
"relax.nn.layer_norm",
"relax.sum",
"relax.mean"
]
DEFAULT_NEVER_LIST = ["relax.nn.softmax", "relax.nn.layer_norm", "relax.sum", "relax.mean"]

# Returns a decorator which registers the given function under FTVMMixedPrecisionConversionType for every listed op.

@@ -81,49 +75,74 @@ def decorator(func):
return decorator


def get_generic_out_dtypes(call_node: "relay.Call", mixed_precision_type: str) -> List[str]:
def get_generic_out_dtypes(call_node: "relay.Call", expected_out_dtype: str) -> List:
"""A function which returns output dtypes in a way which works for most ops.
Parameters
---------
call_node: relay.Call
The call node containing the op.
mixed_precision_type: str
The target type to run the operation in.
expected_out_dtype: str
The output dtype to use.
Returns
-------
output_dtypes : [str, str]
A list of two strings. The first represents the datatype used for accumulation
in the operation. The second represents the actual output datatype.
output : [bool, relay.Call]
A flag indicating whether the call node was rewritten, followed by the
(possibly rewritten) call node carrying the requested out_dtype.
"""
# Assume an op supports a separate accumulation dtype if and only if it has an out_dtype attr.
# This is because there is no better way right now to tell which ops support accumulating
# at different data types.
# Some discussion about making this better is here:
# https://discuss.tvm.apache.org/t/rfc-relay-fp32-fp16-model-support/9994/4?u=andrewzhaoluo
if hasattr(call_node.attrs, "out_dtype"):
# TODO (AndrewZhaoLuo): evaluate consistent support for mixed_type accumulators
# return ["float32", mixed_precision_type]
out_dtype = "float32" if call_node.attrs.out_dtype == "" else call_node.attrs.out_dtype
return [out_dtype, mixed_precision_type]

# [accumulation_dtype, output_dtype] for the operations
return [mixed_precision_type, mixed_precision_type]
if call_node.op.name == "relax.nn.dense":
adjust_call = dense(
call_node.args[0],
call_node.args[1],
units=call_node.attrs.units,
out_dtype=expected_out_dtype,
)
elif call_node.op.name == "relax.nn.conv2d":
adjust_call = conv2d(
call_node.args[0],
call_node.args[1],
channels=call_node.attrs.channels,
kernel_size=call_node.attrs.kernel_size,
strides=call_node.attrs.strides,
padding=call_node.attrs.padding,
dilation=call_node.attrs.dilation,
groups=call_node.attrs.groups,
data_layout=call_node.attrs.data_layout,
kernel_layout=call_node.attrs.kernel_layout,
out_layout=call_node.attrs.out_layout,
out_dtype=expected_out_dtype,
)
elif call_node.op.name == "relax.nn.matmul":
adjust_call = matmul(
call_node.args[0],
call_node.args[1],
out_dtype=expected_out_dtype,
)
else:
raise ValueError("Unsupported op for get_generic_out_dtypes", call_node.op.name)
return [
True,
adjust_call,
]
else:
return [False, call_node]
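A sketch of the resulting contract (the call node here is hypothetical): ops carrying an out_dtype attribute are rebuilt with the requested dtype; anything else is returned untouched:

# Hypothetical: rewrite a relax.nn.dense call to produce float16 output.
rewritten, call = get_generic_out_dtypes(dense_call, "float16")
if rewritten:
    dense_call = call  # same args, but attrs.out_dtype is now "float16"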


# Functions for FTVMMixedPrecisionConversionType which
# take in CallNodes and a dtype and return a conversion type,
# a rewrite flag, and the (possibly rewritten) call node.
@register_func_to_op_list(list_ops=DEFAULT_ALWAYS_LIST)
def generic_always_op(call_node: "relay.Call", mixed_precision_type: str) -> List:
return [MIXED_PRECISION_ALWAYS] + get_generic_out_dtypes(call_node, mixed_precision_type)
def generic_always_op(call_node: "relay.Call", expected_out_dtype: str) -> List:
return [MIXED_PRECISION_ALWAYS] + get_generic_out_dtypes(call_node, expected_out_dtype)


@register_func_to_op_list(list_ops=DEFAULT_FOLLOW_LIST)
def generic_follow_op(call_node: "relay.Call", mixed_precision_type: str) -> List:
return [MIXED_PRECISION_FOLLOW] + get_generic_out_dtypes(call_node, mixed_precision_type)
def generic_follow_op(call_node: "relay.Call", expected_out_dtype: str) -> List:
return [MIXED_PRECISION_FOLLOW] + get_generic_out_dtypes(call_node, expected_out_dtype)


@register_func_to_op_list(list_ops=DEFAULT_NEVER_LIST)
def generic_never_op(call_node: "relay.Call", mixed_precision_type: str) -> List:
return [MIXED_PRECISION_NEVER] + get_generic_out_dtypes(call_node, mixed_precision_type)
def generic_never_op(call_node: "relay.Call", expected_out_dtype: str) -> List:
return [MIXED_PRECISION_NEVER] + get_generic_out_dtypes(call_node, expected_out_dtype)
12 changes: 10 additions & 2 deletions python/tvm/relax/transform/op_legalizer.py
@@ -21,6 +21,7 @@
from tvm import ir, te, topi
from tvm.ir import Attrs
from tvm.ir.module import IRModule
from tvm.tir.generic import cast

from ..analysis import remove_all_unused
from ..expr import Call, Expr, Function, Tuple, TupleGetItem
@@ -80,11 +81,18 @@ def _nn_relu(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr):

def _nn_gelu(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr):
def gelu(x):
dtype = x.dtype
return te.compute(
x.shape,
lambda *i: 0.5
lambda *i: cast(0.5, dtype)
* x(*i)
* (1 + te.tanh(math.sqrt(2 / math.pi) * (x(*i) + 0.044715 * te.power(x(*i), 3)))),
* (
cast(1, dtype)
+ te.tanh(
cast(math.sqrt(2 / math.pi), dtype)
* (x(*i) + cast(0.044715, dtype) * te.power(x(*i), 3))
)
),
)

return bb.call_te(gelu, args[0])
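The expression above lowers the tanh approximation GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), with every constant cast to the input dtype so float16 inputs stay float16 end to end. A NumPy reference (illustrative only):

import numpy as np

def gelu_ref(x):
    # Tanh approximation of GELU, matching the TE lambda above.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * np.power(x, 3))))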
