diff --git a/src/04kernel/include/kernel/attributes/transpose_info.h b/src/04kernel/include/kernel/attributes/transpose_info.h index 00c49315..8c8a30bb 100644 --- a/src/04kernel/include/kernel/attributes/transpose_info.h +++ b/src/04kernel/include/kernel/attributes/transpose_info.h @@ -23,6 +23,8 @@ namespace refactor::kernel { TransposeInfo(DataType, Shape const &, Permutation const &); dim_t locate(dim_t) const noexcept; + TransposeInfo reform(dim_t maxblockSize) const noexcept; + void reformAssign(dim_t maxblockSize) noexcept; }; }// namespace refactor::kernel diff --git a/src/04kernel/src/attributes/transpose_info.cc b/src/04kernel/src/attributes/transpose_info.cc index 9ae385a9..44a32b23 100644 --- a/src/04kernel/src/attributes/transpose_info.cc +++ b/src/04kernel/src/attributes/transpose_info.cc @@ -118,4 +118,24 @@ namespace refactor::kernel { return ans; } + TransposeInfo TransposeInfo::reform(dim_t maxblockSize) const noexcept { + auto ans = *this; + ans.reformAssign(maxblockSize); + return ans; + } + void TransposeInfo::reformAssign(dim_t maxblockSize) noexcept { + if (dims.empty()) { return; } + auto blockSize_ = std::gcd(blockSize, maxblockSize); + if (blockSize_ == blockSize) { return; } + auto times = blockSize / blockSize_; + blockCount *= times; + blockSize = blockSize_; + for (auto &s : dims) { + s.strideI *= times; + s.strideO *= times; + } + dims.resize(dims.size() + 1); + dims.rbegin()[0] = {1, 1}; + } + }// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/transpose/cuda_kernel.cc b/src/04kernel/src/kernels/transpose/cuda_kernel.cc index aed8d7b6..a557dc82 100644 --- a/src/04kernel/src/kernels/transpose/cuda_kernel.cc +++ b/src/04kernel/src/kernels/transpose/cuda_kernel.cc @@ -5,7 +5,7 @@ namespace refactor::kernel { using Info = TransposeInfo; K::TransposeCuda(Info info_) noexcept - : Kernel(), info(std::move(info_)) {} + : Kernel(), info(info_.reform(16)) {} auto K::build(Info info) noexcept -> KernelBox { #ifndef USE_CUDA