```cuda
__global__ void matmul(half *A, half *B, half *C, int M, int N, int K,
                       float alpha, float beta) {
  // A is row-major, B is col-major
  // 128 threads [x, y, z] = [32, 2, 2]
  // threadblock mma: 128x128x32
  // warp mma: 64x64x16
  extern __shared__ uint8_t shared_storage[];
  half *SA = reinterpret_cast<half *>(shared_storage);
  half *SB =
      reinterpret_cast<half *>(shared_storage + MI * KI * sizeof(half));
  float *SC = reinterpret_cast<float *>(shared_storage);

  // FragA is split into MII / wmmaM fragments
  nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, wmmaM, wmmaN, wmmaK, half,
                         nvcuda::wmma::row_major>
      FragA[MII / wmmaM];
  // FragB is split into NII / wmmaN fragments
  nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, wmmaM, wmmaN, wmmaK, half,
                         nvcuda::wmma::col_major>
      FragB[NII / wmmaN];
  // The accumulator is split into (MII / wmmaM) * (NII / wmmaN) fragments
  nvcuda::wmma::fragment<nvcuda::wmma::accumulator, wmmaM, wmmaN, wmmaK,
                         float>
      Accum[MII / wmmaM * NII / wmmaN];

  // Initialize the accumulators
  for (int mii = 0; mii < MII / wmmaM; mii += 1) {
    for (int nii = 0; nii < NII / wmmaN; nii += 1) {
      nvcuda::wmma::fill_fragment(Accum[mii * (NII / wmmaN) + nii], 0.0);
    }
  }

  // Tile along the K dimension
  for (int ko = 0; ko < K / KI; ko += 1) {
    // Load the A and B tiles into shared memory
    loadSmemA(SA, A, M, K, ko);
    loadSmemB(SB, B, N, K, ko);
    __syncthreads();
    // Iterate over the K tile
    for (int ki = 0; ki < KI / KII; ki += 1) {
      // 64x64x16 mma for each warp: load the fragments
      loadFragA(FragA, SA, ki);
      loadFragB(FragB, SB, ki);
      for (int mii = 0; mii < MII / wmmaM; mii += 1) {
        for (int nii = 0; nii < NII / wmmaN; nii += 1) {
          // 16x16x16 for each wmma
          nvcuda::wmma::mma_sync(Accum[mii * (NII / wmmaN) + nii],
                                 FragA[mii], FragB[nii],
                                 Accum[mii * (NII / wmmaN) + nii]);
        }
      }
    }
  }
  storeAccum(SC, Accum);
  __syncthreads();
  storeSmemC(C, SC, M, N);
}
```
```cpp
// Data types for the input and output matrices and for the computation between them.
using ElementAccumulator = float;                  // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
using ElementInputA = cutlass::half_t;             // <- data type of elements in input matrix A
using ElementInputB = cutlass::half_t;             // <- data type of elements in input matrix B
using ElementOutput = float;                       // <- data type of elements in output matrix D
// Matrix layouts of the input and output matrices: column major for A, B and C.
using LayoutInputA = cutlass::layout::ColumnMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::ColumnMajor;
// Whether to use tensor cores or regular SIMT cores on the GPU SM.
using MMAOp = cutlass::arch::OpClassTensorOp;
// CUDA SM architecture number.
using SmArch = cutlass::arch::Sm80;
// Tile size a thread block will compute.
using ShapeMMAThreadBlock =
    cutlass::gemm::GemmShape<128, 256, 64>; // <- threadblock tile M = 128, N = 256, K = 64
// Tile size a warp will compute.
using ShapeMMAWarp =
    cutlass::gemm::GemmShape<64, 64, 64>; // <- warp tile M = 64, N = 64, K = 64
// Size of the MMA op.
using ShapeMMAOp =
    cutlass::gemm::GemmShape<16, 8, 16>; // <- MMA op tile M = 16, N = 8, K = 16
// How threadblocks are scheduled on the GPU.
using SwizzleThreadBlock =
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
// The epilogue: a linear combination of the accumulator with alpha/beta.
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput, // <- data type of output matrix
    128 / cutlass::sizeof_bits<ElementOutput>::value, // <- the number of elements per vectorized
                                                      //    memory access. For half precision it is
                                                      //    8 elements; this also becomes the vector
                                                      //    width of math instructions in the epilogue.
    ElementAccumulator,      // <- data type of accumulator
    ElementComputeEpilogue>; // <- data type for alpha/beta in the linear combination function
// Number of pipeline stages to use.
constexpr int NumStages = 2;

using Gemm = cutlass::gemm::device::Gemm<
    ElementInputA, LayoutInputA, ElementInputB, LayoutInputB, ElementOutput,
    LayoutOutput, ElementAccumulator, MMAOp, SmArch, ShapeMMAThreadBlock,
    ShapeMMAWarp, ShapeMMAOp, EpilogueOp, SwizzleThreadBlock, NumStages>;
```
```python
import multiprocessing
import os
import re
import subprocess
import tempfile


class ProfilerEngine(object):
    """Compile and run a given profiler executable."""

    def __init__(self, cuda_arch, cutlass_path, binary_prefix):
        self.cuda_arch = cuda_arch
        self.binary_prefix = binary_prefix
        self.cutlass = cutlass_path
        self.cflags = "-I{cutlass}/include -I{cutlass}/tools/util/include -O3 -std=c++11".format(
            cutlass=cutlass_path
        )
        self.cflags += " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
        self.cflags += " -gencode=arch=compute_{arch},code=[sm_{arch},compute_{arch}]".format(
            arch=cuda_arch
        )
        self.cflags += " -Xcompiler=-Wconversion -Xcompiler=-fno-strict-aliasing"
        self.cmd = "nvcc {cflags} {src} -o {output}"

    def _compile(self, op):
        os.makedirs(self.binary_prefix, exist_ok=True)
        opath = os.path.join(self.binary_prefix, op["name"])
        if os.path.exists(opath):
            return
        fi = tempfile.NamedTemporaryFile("w", delete=False, suffix=".cu")
        fi.write(op["src"])
        fi.close()
        cmd = self.cmd.format(cflags=self.cflags, src=fi.name, output=opath)
        os.system(cmd)
        os.unlink(fi.name)

    def compile_all(self, ops, use_multiprocessing=False):
        """Compile all profiler executables."""
        if use_multiprocessing:
            pool = multiprocessing.Pool(multiprocessing.cpu_count())
            pool.map(self._compile, ops)
        else:
            for op in ops:
                self._compile(op)

    def evaluate(self, op_name, args):
        """Run the profiler executable corresponding to op_name with args."""
        opath = os.path.join(self.binary_prefix, op_name)
        cmd = [opath]
        if args is not None:
            cmd.append(str(args[0]))
            cmd.append(str(args[1]))
            cmd.append(str(args[2]))
            if len(args) > 3:
                cmd.append(str(args[3]))
        try:
            sp = subprocess.run(cmd, capture_output=True, check=True)
            rt = float(sp.stdout)
            print(op_name, rt)
        except subprocess.CalledProcessError:
            rt = -1
        return rt


class CutlassGemmProfiler(object):
    """Profile all candidate kernels and select the best one."""

    def __init__(self, sm, cutlass_path, binary_path):
        assert sm in GENERATOR_FUNC_TABLE, "sm%d not supported yet." % sm
        self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
        self.sm = sm

    def check_align(self, op_name, M):
        """Filter out kernels that cannot be supported."""
        aligns = re.findall(r"align[1|2|4|8]", op_name)
        assert len(aligns) == 1
        align = int(aligns[0][-1])
        if M % align != 0:
            return False
        return True

    def profile(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False):
        """Profile and select the best kernel from candidate kernels.

        If profile_all is False, return immediately after the first applicable kernel is
        found. If use_multiprocessing is True, compile all profiler executables in parallel.
        """
        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype)
        ops = list(filter(lambda op: self.check_align(op["name"], M), ops))

        for op in ops:
            op["runtime"] = -1

        self.engine.compile_all(ops, use_multiprocessing)

        for op in ops:
            out = self.engine.evaluate(op["name"], [M, N, K])
            op["runtime"] = out
            if out > 0 and profile_all is False:
                break

        valid_ops = filter(lambda op: op["runtime"] > 0, ops)
        output = sorted(valid_ops, key=lambda i: i["runtime"])
        return output[0]
```
The profile function above is called from build.py, which produces the best-performing parameters:
```python
def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"):
    """Given a module partitioned for CUTLASS offloading, profile each workload to select which
    kernels to emit.

    Parameters
    ----------
    mod : IRModule
        The Relay module with cutlass partitions.
    sm : int
        An integer specifying the compute capability. For example, 75 for Turing and 80 or 86 for Ampere.
    profile_all : bool
        Whether or not profile all candidate kernels, or stop profiling after the first applicable kernel is found.
    use_multiprocessing : bool
        Whether or not compile profiler executables for different kernels in parallel.
    tmp_dir : string, optional
        A temporary directory where intermediate compiled artifacts will be stored.

    Returns
    -------
    mod : IRModule
        The updated module annotated with cutlass profiling information.
    num_cutlass_partition : int
        The number of partitioned functions created for CUTLASS.
    """
    cutlass_profiler = CutlassGemmProfiler(sm, "../../../3rdparty/cutlass", tmp_dir)
    num_cutlass_partition = 0
    for var in mod.get_global_vars():
        fun_name = var.name_hint
        func = mod[fun_name]
        annotator = GemmAnnotator()
        if "cutlass" in fun_name:
            num_cutlass_partition += 1
            annotator.visit(func)
            # call cutlass profiler to find best settings, update attr
            new_attrs = {}
            new_attrs.update(annotator.signature)
            for key in func.attrs.keys():
                new_attrs[key] = func.attrs[key]
            # call profiler
            arg0_shape = new_attrs["arg0_shape"]
            arg1_shape = new_attrs["arg1_shape"]
            MM = arg0_shape[0]
            KK = arg0_shape[1]
            NN = arg1_shape[0]
            out = cutlass_profiler.profile(
                MM, NN, KK, annotator.signature["ret_dtype"], profile_all, use_multiprocessing
            )
            if new_attrs["op_type"] == "cutlass.dense":
                new_attrs["cutlass_op_def"] = out["opdef"]
            elif new_attrs["op_type"] == "cutlass.dense_bias":
                new_attrs["cutlass_op_def"] = out["opdef_bias"]
            elif new_attrs["op_type"] == "cutlass.dense_bias_relu":
                new_attrs["cutlass_op_def"] = out["opdef_bias_relu"]
            elif "cutlass.dense_bias_gelu" in new_attrs["op_type"]:
                new_attrs["cutlass_op_def"] = out["opdef_bias_gelu"]
            else:
                raise ValueError("%s pattern is not implemented." % new_attrs["op_type"])
            new_attrs["cutlass_op_name"] = out["name"]
            print("The best kernel is " + new_attrs["cutlass_op_name"])
            if new_attrs["cutlass_op_name"].find("_tn_align") > 0:
                new_attrs["lda"] = "K"
                new_attrs["ldb"] = "K"
                new_attrs["ldc"] = "N"
            elif new_attrs["cutlass_op_name"].find("_nt_align") > 0:
                new_attrs["lda"] = "M"
                new_attrs["ldb"] = "N"
                new_attrs["ldc"] = "N"
            else:
                raise ValueError("%s unsupported operation" % new_attrs["cutlass_op_name"])
            new_attrs = tvm.ir.make_node("DictAttrs", **new_attrs)
            new_func = relay.Function(
                func.params, func.body, ret_type=func.ret_type,
                type_params=func.type_params, attrs=new_attrs,
            )
            mod.update_func(var, new_func)
    return mod, num_cutlass_partition


def build_cutlass_kernels(lib, sm, tmp_dir="./tmp", lib_path="compile.so"):
    """Compile CUTLASS kernels in lib and return the runtime module ready to run.

    Parameters
    ----------
    lib : GraphExecutorFactoryModule
        The output from relay.build containing compiled host code and non-cutlass kernels.
    sm : int
        An integer specifying the compute capability. For example, 75 for Turing and 80 or 86 for Ampere.
    tmp_dir : string, optional
        A temporary directory where intermediate compiled artifacts will be stored.
    lib_path : string, optional
        The path to a shared library which will be generated as the result of the build process.

    Returns
    -------
    updated_lib : runtime.Module
        The updated module with compiled cutlass kernels.
    """
    cutlass_path = "../../../3rdparty/cutlass/include"
    cutlass_util_path = "../../../3rdparty/cutlass/tools/util/include"
    kwargs = {}
    kwargs["cc"] = "nvcc"
    kwargs["options"] = [
        "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
        "-gencode=arch=compute_%d,code=[sm_%d,compute_%d]" % (sm, sm, sm),
        "-Xcompiler=-fPIC",
        "-Xcompiler=-Wconversion",
        "-Xcompiler=-fno-strict-aliasing",
        "-O3",
        "-std=c++14",
        "-I" + cutlass_path,
        "-I" + cutlass_util_path,
    ]
    lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs)
    return runtime.load_module(lib_path)
```
Code Generation Design Proposal
This section discusses the code generation proposal, using GEMM code generation on the Nvidia platform as an example.
Subgraph-based Generation
By default, operator-level code generation happens at the global-memory level: given the `input` and `output` tensors, a `__global__`-level kernel is generated for each operator and the kernels are stitched together across the whole compute graph. However, we sometimes need to fuse different operators to reduce synchronization and memory overhead. We therefore introduce a subgraph structure and specify the fusion level within each subgraph. For example, `subgraphs` is a recursive structure whose next level can be either another subgraph or a single operator, and lowering proceeds level by level at the fused memory level.
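The original example snippet is not reproduced here; as a purely illustrative sketch (the `OpNode`/`Subgraph` names and the `fuse_level` field are assumptions, not the proposal's actual schema), such a recursive description might look like this:

```python
from dataclasses import dataclass, field
from typing import List, Union


@dataclass
class OpNode:
    """A single operator, e.g. a GEMM or a GELU."""
    name: str
    inputs: List[str]
    outputs: List[str]


@dataclass
class Subgraph:
    """A recursive subgraph: children are either operators or nested subgraphs.

    fuse_level is the memory level at which the children are fused,
    e.g. "global", "shared", or "register".
    """
    fuse_level: str
    children: List[Union["Subgraph", OpNode]] = field(default_factory=list)


# Two GEMMs and a GELU fused at the shared-memory level; the whole pipeline lives in global memory.
graph = Subgraph(
    fuse_level="global",
    children=[
        Subgraph(
            fuse_level="shared",
            children=[
                OpNode("gemm0", inputs=["A", "B"], outputs=["C0"]),
                OpNode("gelu", inputs=["C0"], outputs=["C1"]),
                OpNode("gemm1", inputs=["C1", "D"], outputs=["E"]),
            ],
        ),
    ],
)
```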
Primitive Design
This code generation proposal aims to design a set of primitives that can be composed to generate code for different fused operators. The following uses Nvidia's Tensor Core, Cute, and Cutlass stacks as illustrations. Going from Tensor Core to Cute to Cutlass, ease of use increases step by step, while flexibility decreases.
Tensor Core
GEMM
Above is a simple GEMM kernel implemented with Tensor Core in the sliced-K style (the first code listing at the top of this note). This code involves a series of primitives; a sketch of how they might compose follows the list.

- `CudaVar`: a CUDA variable, covering declaration, initialization, filling, and so on; it has to be defined for each CUDA memory level.
- `iteration`: a loop object that takes the loop variable, start, end, and step as members, and can emit the loop as well as expose the loop variable.
- `sync`: synchronization primitives generated per memory level, e.g. `__syncthreads()` and `cudaDeviceSynchronize()`.
- `mma`: MMA is the built-in Tensor Core operation in CUDA. It comes in two flavors: the WMMA API provided by Nvidia, a plain library call that is easy to use but slower than raw MMA, and hand-written CUDA PTX MMA instructions, which are hard to implement but deliver high performance.
- `Load`/`Store`: memory load and store operations, the most complex primitives to design, since they involve the thread layout of the launched kernel as well as the memory layout. For example, to satisfy MMA's memory-layout requirements, a two-dimensional [MTILE, NTILE] shared-memory tile must be laid out as a four-dimensional [MTILE/WmmaM, NTILE/WmmaN, WmmaM, WmmaN] buffer.
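As a hedged sketch of how such primitives might compose into an emitter (the `CudaVar`, `Iteration`, and `sync` definitions below, and the string-emission approach, are illustrative assumptions rather than a fixed design):

```python
class CudaVar:
    """A variable bound to a CUDA memory level: global, shared, or fragment."""

    def __init__(self, name, dtype, scope, size=None):
        self.name, self.dtype, self.scope, self.size = name, dtype, scope, size

    def decl(self):
        if self.scope == "shared":
            return "__shared__ {} {}[{}];".format(self.dtype, self.name, self.size)
        return "{} {};".format(self.dtype, self.name)


class Iteration:
    """A loop object holding the loop variable, start, end, and step."""

    def __init__(self, var, start, end, step=1):
        self.var, self.start, self.end, self.step = var, start, end, step

    def begin(self):
        return "for (int {v} = {s}; {v} < {e}; {v} += {st}) {{".format(
            v=self.var, s=self.start, e=self.end, st=self.step
        )

    def close(self):
        return "}"


def sync(level):
    """Synchronization primitive for a given memory level."""
    return "__syncthreads();" if level == "shared" else "cudaDeviceSynchronize();"


# Compose the primitives to emit the skeleton of the ko loop from the sliced-K kernel above.
smem_a = CudaVar("SA", "half", "shared", size="MI * KI")
ko = Iteration("ko", 0, "K / KI")
print("\n".join([smem_a.decl(), ko.begin(),
                 "  loadSmemA(SA, A, M, K, ko);",
                 "  " + sync("shared"),
                 ko.close()]))
```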
GEMM fused
This is a hand-written Tensor Core implementation that fuses a GEMM, a GELU, and a second GEMM at the shared-memory level. To generate such code, we need to specify the fusion level and the memory level of each input and output variable.
For example, for the first-stage GEMM, both input variables live in global memory and the output variable lives in shared memory. This means that for the first stage we should not synchronize back to global memory; code generation should stop once the result has been written to shared memory.
For the second stage, input A is in shared memory, input B is in global memory, and output C is in shared memory, with the fusion level set to the shared-memory level. This means A does not need to be loaded from global memory; we only need to load B into shared memory and run the sliced-K GEMM one more time. Of course, this requires the first-stage and second-stage GEMMs to have the same thread-block organization, i.e. ThreadBlock_0 = ThreadBlock_1.
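To make this concrete, here is a small, purely illustrative description of the two stages' memory levels (the dictionary layout and field names are assumptions; only the memory assignments follow the text above):

```python
# Memory level of every input/output in the fused GEMM -> GELU -> GEMM kernel.
fused_plan = {
    "fuse_level": "shared",
    "stages": [
        # Stage 0: GEMM with a GELU epilogue. Inputs come from global memory and
        # the result stays in shared memory, so no global synchronization is emitted.
        {"op": "gemm0", "epilogue": "gelu", "A": "global", "B": "global", "C": "shared"},
        # Stage 1: reuses stage 0's output as A directly from shared memory;
        # only B is staged through shared memory before the sliced-K GEMM runs again.
        {"op": "gemm1", "epilogue": None, "A": "shared", "B": "global", "C": "shared"},
    ],
    # Both stages must use the same thread-block organization.
    "constraint": "ThreadBlock_0 == ThreadBlock_1",
}
```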
TVM support related to Tensor Core:

- PR 4136 mainly adds intrinsics such as `mma_sync` and `fill_fragment`.
- Issue 4105 initially proposed detecting Tensor Core opportunities from the shape of the AST, but that cannot identify Tensor Core usage precisely. Since 4136 had already introduced the new intrinsics, Tianqi Chen argued that this is a new class of hardware and that codegen should be driven by new primitives.
- 4234 reorganizes the AST around these primitives to do the generation.

Cute
TODO
Cutlass
Performance Evaluation
TVM support related to CUTLASS:
PR 9261 introduces CUTLASS code generation in TVM through BYOC; this was discussed earlier in 9147. The backend code added lives in `src/relay/backend/contrib/cutlass/codegen.cc`. The implementation is still fairly rudimentary: it mainly prints code based on the `attrs` passed down from the upper layers, and it handles simple fusions, such as gelu and biasadd, as epilogues. `GenerateCompositeFunctionCall` dispatches on the different patterns and calls `GenerateBody`, and `GenerateBody` does some further analysis and calls `DenseOp`.
`python/tvm/contrib/cutlass/gen_gemm.py` describes how the CUTLASS code generation is done on the Python side. `create_gemm_operator` takes `layouts`, `tile_descriptions`, and `alignment_constraints` and iterates over them in three nested loops, apparently to enumerate every CUTLASS layout variant to generate; it then creates a `GemmOperation` and generates the code:
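The actual body of `create_gemm_operator` is not reproduced in this note; as a rough, hedged sketch of the enumeration structure it describes (the function and helper names below are hypothetical, not TVM's API):

```python
def make_kernel_name(op):
    # Hypothetical helper: build a readable name such as "gemm_128x256x64_align8".
    return "gemm_{}x{}x{}_align{}".format(*op["tile"], op["alignment"])


def emit_kernel_source(op):
    # Hypothetical stand-in for the emitter that renders the CUTLASS C++ template.
    return "// CUTLASS GEMM kernel for " + make_kernel_name(op)


def generate_candidates(layouts, tile_descriptions, alignment_constraints, data_type):
    """Enumerate one candidate kernel per (layout, tile, alignment) combination.

    This mirrors the three nested loops described above: each combination yields a
    GemmOperation-like description whose source is emitted for later profiling.
    """
    candidates = []
    for layout in layouts:                       # e.g. (ColumnMajor, ColumnMajor, ColumnMajor)
        for tile in tile_descriptions:           # threadblock / warp / instruction tile shapes
            for align in alignment_constraints:  # vectorized-access alignment: 1, 2, 4, or 8
                op = {"layout": layout, "tile": tile, "alignment": align, "data_type": data_type}
                op["name"] = make_kernel_name(op)
                op["src"] = emit_kernel_source(op)
                candidates.append(op)
    return candidates


# Example: two tile shapes and two alignments give four candidates for one layout.
ops = generate_candidates(
    layouts=[("col", "col", "col")],
    tile_descriptions=[(128, 256, 64), (64, 64, 64)],
    alignment_constraints=[8, 4],
    data_type=("float16", "float16", "float32", "float32"),
)
print([op["name"] for op in ops])
```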
We can see that `create_gemm_operator` calls an `emit` function, mainly so that the candidates can be profiled and the best layout selected; the concrete code is specified in `gemm_profiler`: it first defines the performance-measurement template needed for profiling, then does the actual generation from the parameters and evaluates the performance. `gen_gemm.py` also specifies how to run the profiler and pick out the best-performing parameters (the `CutlassGemmProfiler` listing above), and that function is in turn called from `build.py`, which produces the best-performing parameters (the `tune_cutlass_kernels` listing above).
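Putting these pieces together, here is a hedged end-to-end sketch of the profile-and-build flow around this PR (module paths and defaults follow the TVM tests of that era and may differ in other versions; note that `tune_cutlass_kernels` above hard-codes a CUTLASS path relative to the TVM source tree):

```python
import tvm
from tvm import relay
from tvm.contrib import graph_executor
from tvm.contrib.cutlass import tune_cutlass_kernels, build_cutlass_kernels
from tvm.relay.op.contrib.cutlass import partition_for_cutlass

# A single fp16 dense workload with an fp32 accumulator/output.
data = relay.var("data", shape=(16, 64), dtype="float16")
weight = relay.var("weight", shape=(32, 64), dtype="float16")
mod = tvm.IRModule.from_expr(relay.nn.dense(data, weight, out_dtype="float32"))

# Offload matching patterns to CUTLASS, then profile each partition (sm=80 -> Ampere)
# so the winning kernel's definition and name are attached as function attributes.
mod = partition_for_cutlass(mod)
mod, num_cutlass_partition = tune_cutlass_kernels(mod, sm=80, tmp_dir="./tmp")

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="cuda")

# Compile the emitted CUTLASS sources with nvcc and load the final runtime module.
lib = build_cutlass_kernels(lib, sm=80, tmp_dir="./tmp", lib_path="compile.so")
rt_mod = graph_executor.GraphModule(lib["default"](tvm.cuda(0)))
```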
GEMM Optimization