Skip to content

Commit

Permalink
fix(kernel): 为 Transpose 提供 reform
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Apr 24, 2024
1 parent 85b5d82 commit eca616f
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/04kernel/include/kernel/attributes/transpose_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ namespace refactor::kernel {

TransposeInfo(DataType, Shape const &, Permutation const &);
dim_t locate(dim_t) const noexcept;
/// Returns a copy of this info with `reformAssign(maxblockSize)` applied;
/// the object it is called on is left unmodified.
TransposeInfo reform(dim_t maxblockSize) const noexcept;
/// Re-tiles this info in place so that `blockSize` becomes
/// `gcd(blockSize, maxblockSize)`. `blockCount` and every dimension stride
/// are multiplied by the shrink factor, and a trailing `{1, 1}` dimension is
/// appended. Does nothing when `dims` is empty or when `blockSize` already
/// divides `maxblockSize`.
void reformAssign(dim_t maxblockSize) noexcept;
};

}// namespace refactor::kernel
Expand Down
20 changes: 20 additions & 0 deletions src/04kernel/src/attributes/transpose_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,24 @@ namespace refactor::kernel {
return ans;
}

/// Non-mutating counterpart of `reformAssign`: clones this info, re-tiles the
/// clone for `maxblockSize`, and returns it.
TransposeInfo TransposeInfo::reform(dim_t maxblockSize) const noexcept {
    TransposeInfo reformed(*this);
    reformed.reformAssign(maxblockSize);
    return reformed;
}
/// Shrinks `blockSize` in place to `gcd(blockSize, maxblockSize)`.
///
/// When the block size actually changes, the number of blocks and every
/// existing stride are multiplied by the shrink factor, and one extra
/// innermost dimension with strides `{1, 1}` is appended to `dims`.
/// The call is a no-op when `dims` is empty or when `blockSize` already
/// divides `maxblockSize`.
void TransposeInfo::reformAssign(dim_t maxblockSize) noexcept {
    // No dimension entries: leave the info untouched.
    if (dims.empty()) { return; }

    // Largest block size that divides both the current size and the limit.
    auto const shrunk = std::gcd(blockSize, maxblockSize);
    if (shrunk == blockSize) { return; }

    // How many new blocks each old block is split into.
    auto const factor = blockSize / shrunk;
    blockSize = shrunk;
    blockCount *= factor;

    // Rescale the existing strides by the same factor.
    // NOTE(review): presumably strides are counted in blocks — confirm.
    for (auto &dim : dims) {
        dim.strideI *= factor;
        dim.strideO *= factor;
    }

    // New trailing dimension with unit strides on both input and output.
    dims.push_back({1, 1});
}

}// namespace refactor::kernel
2 changes: 1 addition & 1 deletion src/04kernel/src/kernels/transpose/cuda_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace refactor::kernel {
using Info = TransposeInfo;

K::TransposeCuda(Info info_) noexcept
: Kernel(), info(std::move(info_)) {}
: Kernel(), info(info_.reform(16)) {}

auto K::build(Info info) noexcept -> KernelBox {
#ifndef USE_CUDA
Expand Down

0 comments on commit eca616f

Please sign in to comment.