diff --git a/src/04kernel/CMakeLists.txt b/src/04kernel/CMakeLists.txt index 77b655c0..e63009a0 100644 --- a/src/04kernel/CMakeLists.txt +++ b/src/04kernel/CMakeLists.txt @@ -7,6 +7,9 @@ if(USE_CUDA) file(GLOB_RECURSE KERNEL_CUDA_SRC src/*.cu) add_subdirectory(cuda) endif() +if(USE_BANG) + file(GLOB_RECURSE KERNEL_BANG_SRC src/*.mlu) +endif() add_library(kernel STATIC ${KERNEL_SRC} ${KERNEL_CUDA_SRC}) target_link_libraries(kernel PUBLIC runtime) diff --git a/src/04kernel/src/collectors/softmax.cc b/src/04kernel/src/collectors/softmax.cc index 020bc6de..6f7b0a0c 100644 --- a/src/04kernel/src/collectors/softmax.cc +++ b/src/04kernel/src/collectors/softmax.cc @@ -1,4 +1,5 @@ #include "kernel/collectors/softmax.h" +#include "../kernels/softmax/bang_kernel.hh" #include "../kernels/softmax/cnnl_kernel.hh" #include "../kernels/softmax/cpu_kernel.hh" #include "../kernels/softmax/cuda_kernel.hh" @@ -33,6 +34,9 @@ namespace refactor::kernel { if (auto ptr = SoftmaxCnnl::build(cnnl::SoftmaxAlgo::ACCURATE, info); ptr) { ans.emplace_back(std::move(ptr)); } + if (auto ptr = SoftmaxBang::build(info); ptr) { + ans.emplace_back(std::move(ptr)); + } break; } default: diff --git a/src/04kernel/src/kernels/softmax/bang_kernel.cc b/src/04kernel/src/kernels/softmax/bang_kernel.cc new file mode 100644 index 00000000..e39e25a2 --- /dev/null +++ b/src/04kernel/src/kernels/softmax/bang_kernel.cc @@ -0,0 +1,29 @@ +#include "bang_kernel.hh" + +namespace refactor::kernel { + using K = SoftmaxBang; + + K::SoftmaxBang(SoftmaxInfo info_) noexcept + : Kernel(), info(std::move(info_)) {} + + auto K::build(SoftmaxInfo info) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif + + return info.type.isFloat() + ? std::make_unique(std::move(info)) + : nullptr; + } + + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing Softmax using BANG"; + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/softmax/bang_kernel.hh b/src/04kernel/src/kernels/softmax/bang_kernel.hh new file mode 100644 index 00000000..b2b5fc03 --- /dev/null +++ b/src/04kernel/src/kernels/softmax/bang_kernel.hh @@ -0,0 +1,26 @@ +#ifndef KERNEL_SOFTMAX_BANG_HH +#define KERNEL_SOFTMAX_BANG_HH + +#include "cnnl.h" +#include "cnrt.h" +#include "kernel/attributes/softmax_info.h" +#include "kernel/collectors/softmax.h" +namespace refactor::kernel { + + struct SoftmaxBang final : public Kernel { + SoftmaxInfo info; + + SoftmaxBang(SoftmaxInfo) noexcept; + static KernelBox build(SoftmaxInfo) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const noexcept final; +#endif + }; + +}// namespace refactor::kernel + +#endif//KERNEL_SOFTMAX_BANG_HH diff --git a/src/04kernel/src/kernels/softmax/bang_kernel.mlu b/src/04kernel/src/kernels/softmax/bang_kernel.mlu new file mode 100644 index 00000000..f19c898c --- /dev/null +++ b/src/04kernel/src/kernels/softmax/bang_kernel.mlu @@ -0,0 +1,599 @@ +#include "bang_kernel.hh" +#include +#include +#define EPS 1e-7 +const int NRAM_MAX_SIZE = 1024 * 256;//Apply for maximum memory in advance from NRAM +const int nramNum = NRAM_MAX_SIZE / sizeof(float); +const int SRC_MAX_SIZE = 1024 * 32;//The subsequent tree summation must ensure that SRC-MAX-SIZE is a power of 2 +const int maxNum = SRC_MAX_SIZE / sizeof(float); +const int warpSize = 32; + +namespace refactor::kernel { + using namespace runtime; + + __mlu_device__ void softmaxKernelAxis_m(float *destination, float *source, int frontsize, int dimsize, int stride, int strideS) { + 0= maxNum) { + float *src = nram_buffer; + float *tmpSum = src + maxNum; + float *tmpNewMax = tmpSum + maxNum; + float *tmpOldMax = tmpNewMax + maxNum; + + int remain = stride % maxNum; + int repeat = (stride - remain) / maxNum; + + int taskRemain = frontsize % taskDim; + int stepEasy = (frontsize - taskRemain) / taskDim; + int stepHard = stepEasy + 1; + + int indStart = (taskId < taskRemain ? taskId * stepHard : taskRemain * stepHard + (taskId - taskRemain) * stepEasy); + source = source + indStart * dimsize * stride; + destination = destination + indStart * dimsize * stride; + + for (int ind = taskId; ind < frontsize; ind += taskDim) { + int frontIdx = ind * dimsize * stride; + for (int j = 0; j < repeat; j++) { + __bang_write_value(tmpNewMax, maxNum, -INFINITY); + __bang_write_zero(tmpSum, maxNum); + __bang_write_zero(src, maxNum); + for (int i = 0; i < dimsize; i++) { + __memcpy(src, source + frontIdx + i * stride + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM); + __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);//Continuously updating the maximum value + __bang_sub(src, src, tmpNewMax, maxNum); //x - M + __bang_active_exp_less_0(src, src, maxNum); //exp(x - M) + if (i > 0) { + __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum); //oldM = oldM - newM + __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM) + __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum); //sum = sum * exp(oldM - newM) + } + __bang_add(tmpSum, tmpSum, src, maxNum); //sum += exp(x - M) + __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM + } + __bang_active_recip_greater_1(tmpSum, tmpSum, maxNum);//compute 1/sum + //Start exponential transformation and write back to GDRAM + __bang_mul(src, src, tmpSum, maxNum);//The data stored in the src at the end of the loop above can be utilized + __memcpy(destination + (dimsize - 1) * stride + frontIdx + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM); + for (int i = 0; i < dimsize - 1; i++) { + __memcpy(src, source + frontIdx + i * stride + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM); + __bang_sub(src, src, tmpNewMax, maxNum); //x - M + __bang_active_exp_less_0(src, src, maxNum);//exp(x - M) + __bang_mul(src, src, tmpSum, maxNum); + __memcpy(destination + frontIdx + i * stride + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM); + } + } + if (remain) { + __bang_write_value(tmpNewMax, maxNum, -INFINITY); + __bang_write_zero(tmpSum, maxNum); + __bang_write_value(src, maxNum, -INFINITY); + for (int i = 0; i < dimsize; i++) { + __memcpy(src, source + frontIdx + i * stride + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM); + __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum); + __bang_sub(src, src, tmpNewMax, maxNum); //x - M + __bang_active_exp_less_0(src, src, maxNum);//exp(x - M) + if (i > 0) { + __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum); //oldM = oldM - newM + __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM) + __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum); //sum = sum * exp(oldM - newM) + } + __bang_add(tmpSum, tmpSum, src, maxNum); //sum += exp(x - M) + __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM + } + //------------------- + __bang_active_recip_greater_1(tmpSum, tmpSum, maxNum);//compute 1/sum + //Start exponential transformation and write back to GDRAM + __bang_mul(src, src, tmpSum, maxNum);//The data stored in the src at the end of the loop above can be utilized + __memcpy(destination + (dimsize - 1) * stride + frontIdx + repeat * maxNum, src, remain * sizeof(float), NRAM2GDRAM); + for (int i = 0; i < dimsize - 1; i++) { + __memcpy(src, source + i * stride + frontIdx + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM); + __bang_sub(src, src, tmpNewMax, maxNum); //x - M + __bang_active_exp_less_0(src, src, maxNum);//exp(x - M) + __bang_mul(src, src, tmpSum, maxNum); + __memcpy(destination + i * stride + frontIdx + repeat * maxNum, src, remain * sizeof(float), NRAM2GDRAM); + } + //--------------------- + } + } + } else if (stride < maxNum && dimsize * stride >= maxNum) { + + + float *src = nram_buffer; + float *tmp = src + maxNum; + float *tmpOldMax = tmp + strideS; + float *tmpNewMax = tmpOldMax + strideS; + float *tmpSum = tmpNewMax + strideS; + + int multiple = maxNum / stride; + int size = multiple * stride; //The maximum amount of data that can be stored in an SRC + int remain = dimsize % multiple; //If it cannot be divisible, this part of the data needs special processing + int repeat = (dimsize - remain) / multiple;//The total number of loops required to load the entire dimsize + + int taskRemain = frontsize % taskDim; + int stepEasy = (frontsize - taskRemain) / taskDim; + int stepHard = stepEasy + 1; + int step = (taskId < taskRemain ? stepHard : stepEasy);//The number of frontsize processed per taskId + int indStart = (taskId < taskRemain ? taskId * stepHard : taskRemain * stepHard + (taskId - taskRemain) * stepEasy); + source = source + indStart * dimsize * stride; + destination = destination + indStart * dimsize * stride; + //printf("maxNum:%d, dimsize * stride:%d, multiple:%d, size:%d, repeat:%d,remain:%d\n",maxNum, dimsize * stride, multiple, size, repeat,remain); + for (int ind = 0; ind < step; ind++) { + int frontIdx = ind * dimsize * stride; + + __bang_write_value(tmpNewMax, strideS, -INFINITY);//Must be initialized to negative infinity + __bang_write_value(tmp, strideS, -INFINITY); //Must be initialized to negative infinity + __bang_write_zero(tmpSum, strideS); //Must be initialized to zero + + for (int j = 0; j < repeat; j++) { + __memcpy(src, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM); + for (int m = 0; m < multiple; m++) { + __memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM); + + __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);//Although the stream S stream section after tmpNewMax is 0, there is no need to write back to GDRAM, which does not affect the result + + __bang_sub(tmp, tmp, tmpNewMax, strideS);//The stripe S stripe section after tmp is 0 + __bang_active_exp_less_0(tmp, tmp, strideS); + if (j != 0 || m != 0) { + __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS); //oldM = oldM - newM + __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM) + __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS); //sum = sum * exp(oldM - newM) + } + __bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M) + //if(m == 0) __bang_printf("tmp:%.2f, tmpMax[0]:%.2f,tmpSum[0]:%.2f\n", tmp[1], tmpNewMax[1],tmpSum[0]); + __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM + } + } + //__bang_printf("tmpOldMax[0]:%.2f,tmpSum[0]:%.2f\n", tmpNewMax[0],tmpSum[0]); + if (remain) { + __memcpy(src, source + frontIdx + repeat * multiple * stride, remain * stride * sizeof(float), GDRAM2NRAM); + for (int m = 0; m < remain; m++) { + __memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM); + __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS); + __bang_sub(tmp, tmp, tmpNewMax, strideS);//The stripe S stripe section after tmp is 0 + __bang_active_exp_less_0(tmp, tmp, strideS); + if (repeat != 0 || m != 0) { + __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS); //oldM = oldM - newM + __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM) + __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS); //sum = sum * exp(oldM - newM) + } + __bang_add(tmpSum, tmpSum, tmp, strideS); //sum += exp(x - M) + __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM + } + } + + //At this point, tmpNewMax stores the maximum value of the data corresponding to a fixed frontIdx and bedsize, while tmpSum stores the corresponding value sum + //__bang_printf("tmpOldMax[0]:%.2f,tmpSum[0]:%.2f\n", tmpNewMax[2],tmpSum[2]); + __bang_active_recip_greater_1(tmpSum, tmpSum, strideS); + //__bang_printf("tmpOldMax[0]:%.2f,tmpSum[0]:%.2f\n", tmpNewMax[2],tmpSum[2]); + if (remain) { + for (int m = 0; m < remain; m++) { + __memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM); + __bang_sub(tmp, tmp, tmpNewMax, strideS); + __bang_active_exp_less_0(tmp, tmp, strideS); + __bang_mul(tmp, tmp, tmpSum, strideS); + __memcpy(destination + frontIdx + repeat * multiple * stride + m * stride, tmp, stride * sizeof(float), NRAM2GDRAM); + } + } + for (int j = 0; j < repeat; j++) { + __memcpy(src, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM); + for (int m = 0; m < multiple; m++) { + __memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM); + + __bang_sub(tmp, tmp, tmpNewMax, strideS); + __bang_active_exp_less_0(tmp, tmp, strideS); + __bang_mul(tmp, tmp, tmpSum, strideS); + __memcpy(destination + frontIdx + j * multiple * stride + m * stride, tmp, stride * sizeof(float), NRAM2GDRAM); + } + } + } + } else if (dimsize * stride < maxNum) { + + float *src = nram_buffer; + float *tmp = src + maxNum; + float *tmpOldMax = tmp + strideS; + float *tmpNewMax = tmpOldMax + strideS; + float *tmpSum = tmpNewMax + strideS; + int behindsize = dimsize * stride; + int multiple = maxNum / behindsize;//Represents the amount that a maxNum can share in frontsize + + int remainF = frontsize % (taskDim * multiple); + int remainT = remainF % taskDim; + int stepEasy = (remainF - remainT) / taskDim; + int stepHard = stepEasy + 1; + int step = (taskId < remainT ? stepHard : stepEasy); + int taskRepeat = (frontsize - remainF) / (taskDim * multiple); + //At this point, corresponding to frontsize, the amount of data processed by each taskId is taskRepeat * multiple+step + int startHard = taskId * (taskRepeat * multiple + stepHard); + int startEasy = remainT * (taskRepeat * multiple + stepHard) + (taskId - remainT) * (taskRepeat * multiple + stepEasy); + int indStart = (taskId < remainT ? startHard : startEasy); + source = source + indStart * behindsize;//indStart * behindsize Indicates the offset corresponding to different taskIds + destination = destination + indStart * behindsize; + int tid; + for (int s = 0; s < taskRepeat; s++) { + tid = s * multiple * behindsize; + __memcpy(src, source + tid, multiple * behindsize * sizeof(float), GDRAM2NRAM); + for (int m = 0; m < multiple; m++) { + __bang_write_zero(tmpSum, strideS); + __bang_write_value(tmp, strideS, -INFINITY); + __bang_write_value(tmpNewMax, strideS, -INFINITY); + for (int i = 0; i < dimsize; i++) { + __memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM); + __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS); + __bang_sub(tmp, tmp, tmpNewMax, strideS); //x - M + __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M) + if (i > 0) { + __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS); //oldM = oldM - newM + __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM) + __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS); //sum = sum * exp(oldM - newM) + } + __bang_add(tmpSum, tmpSum, tmp, strideS); //sum += exp(x - M) + __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM + } + __bang_active_recip_greater_1(tmpSum, tmpSum, strideS); + __bang_mul(tmp, tmp, tmpSum, strideS);//The data stored in tmp at the end of the loop above can be utilized + //__memcpy(destination + tid + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2GDRAM); + __memcpy(src + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2NRAM); + for (int i = 0; i < dimsize - 1; i++) { + __memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM); + __bang_sub(tmp, tmp, tmpNewMax, strideS); //x - M + __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M) + __bang_mul(tmp, tmp, tmpSum, strideS); + //__memcpy(destination + tid + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2GDRAM); + __memcpy(src + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2NRAM); + } + } + __memcpy(destination + tid, src, multiple * behindsize * sizeof(float), NRAM2GDRAM); + } + //__bang_printf("taskId:%d, multiple:%d, taskRepeat:%d, step:%d, indStart:%d\n",taskId, multiple, taskRepeat, step, indStart * behindsize); + if (step) { + tid = taskRepeat * multiple * behindsize; + __memcpy(src, source + tid, step * behindsize * sizeof(float), GDRAM2NRAM); + for (int m = 0; m < step; m++) { + __bang_write_zero(tmpSum, strideS); + __bang_write_value(tmp, strideS, -INFINITY); + __bang_write_value(tmpNewMax, strideS, -INFINITY); + for (int i = 0; i < dimsize; i++) { + __memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM); + __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS); + __bang_sub(tmp, tmp, tmpNewMax, strideS); //x - M + __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M) + if (i > 0) { + __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS); //oldM = oldM - newM + __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM) + __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS); //sum = sum * exp(oldM - newM) + } + __bang_add(tmpSum, tmpSum, tmp, strideS); //sum += exp(x - M) + __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM + } + //__bang_printf("max:%.2f,%.2f, sum:%.2f,sum:%.2f\n", tmpNewMax[0], tmpNewMax[1], tmpSum[0], tmpSum[0]); + __bang_active_recip_greater_1(tmpSum, tmpSum, strideS); + __bang_mul(tmp, tmp, tmpSum, strideS);//The data stored in tmp at the end of the loop above can be utilized + //__memcpy(destination + tid + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2GDRAM); + __memcpy(src + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2NRAM); + for (int i = 0; i < dimsize - 1; i++) { + __memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM); + __bang_sub(tmp, tmp, tmpNewMax, strideS); //x - M + __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M) + __bang_mul(tmp, tmp, tmpSum, strideS); + //__memcpy(destination + tid + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2GDRAM); + __memcpy(src + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2NRAM); + } + } + __memcpy(destination + tid, src, step * behindsize * sizeof(float), NRAM2GDRAM); + } + } + } + + __mlu_device__ void softmaxKernelAxis_e(float *destination, float *source, int othersize, int dimsize, int dimS) {axis = -1 + int multiple = maxNum / dimsize; + int size = taskDim * multiple; + int remainS = othersize % size; + int taskRepeat = (othersize - remainS) / size; + int remainT = remainS % taskDim; + int stepEasy = (remainS - remainT) / taskDim; + int stepHard = stepEasy + 1; + int step = (taskId < remainT ? stepHard : stepEasy); + //The amount allocated for processing othersize for each taskId is taskRepeat * multiple+step + //Overall, the amount of data processed by each taskId is (taskRepeat * multiple+step) * dimsize + int startHard = taskId * (taskRepeat * multiple + stepHard); + int startEasy = remainT * (taskRepeat * multiple + stepHard) + (taskId - remainT) * (taskRepeat * multiple + stepEasy); + int indStart = (taskId < remainT ? startHard : startEasy); + source = source + indStart * dimsize; + destination = destination + indStart * dimsize; + + __nram__ float nram_buffer[nramNum]; + + float *src = nram_buffer; + float *tmp = src + maxNum; + float *destSum = tmp + dimS; + int remainDim = dimsize % dimS;//Dimsize may not be a power of 2 + int repeatDim = (dimsize - remainDim) / dimS; + __nram__ float destSumFinal[warpSize];//Reduce destSum to destFinal [0] + __nram__ float srcMax[2]; + __nram__ float destOldMax; + __nram__ float destNewMax; + //printf("taskId:%d, taskRepeat:%d, step:%d, repeatDim:%d, indstart:%d, %d\n", taskId, taskRepeat, step, repeatDim, indStart, indStart * dimsize); + int tid; + for (int s = 0; s < taskRepeat; s++) { + tid = s * multiple * dimsize; + __memcpy(src, source + tid, multiple * dimsize * sizeof(float), GDRAM2NRAM); + for (int j = 0; j < multiple; j++) { + __bang_write_zero(destSum, dimS); + __bang_write_zero(destSumFinal, warpSize); + destNewMax = -INFINITY; + + for (int i = 0; i < repeatDim; i++) { + __memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM); + __bang_argmax(srcMax, tmp, dimS); + if (destNewMax < srcMax[0]) { + destNewMax = srcMax[0]; + } + __bang_sub_scalar(tmp, tmp, destNewMax, dimS); + __bang_active_exp_less_0(tmp, tmp, dimS); + if (i > 0) { + __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS); + } + __bang_add(destSum, destSum, tmp, dimS); + destOldMax = destNewMax; + } + if (remainDim) { + __bang_write_value(tmp, dimS, -INFINITY); + __memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM); + __bang_argmax(srcMax, tmp, dimS); + if (destNewMax < srcMax[0]) { + destNewMax = srcMax[0]; + } + __bang_write_value(tmp, dimS, destNewMax);//Must be reinitialized to NewMax + __memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM); + __bang_sub_scalar(tmp, tmp, destNewMax, dimS); + __bang_active_exp_less_0(tmp, tmp, dimS); + if (repeatDim > 0) { + __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS); + } + __bang_add(destSum, destSum, tmp, dimS); + destOldMax = destNewMax; + } + + int segNum = dimS / warpSize;//Starting numerical summation + for (int strip = segNum / 2; strip > 0; strip = strip / 2) { + for (int i = 0; i < strip; i++) { + __bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize); + } + } + __bang_reduce_sum(destSumFinal, destSum, warpSize);//At this point, destSumFinal [0] saves the numerical value of the current dimsize length data sum + if (remainDim) { + destSumFinal[0] = destSumFinal[0] - (dimS - remainDim); + } + //Now let's start writing back the data + float globalSumInv = 1.0 / destSumFinal[0]; + if (remainDim) { + __bang_mul_scalar(tmp, tmp, globalSumInv, dimS); + __memcpy(destination + tid + j * dimsize + repeatDim * dimS, tmp, remainDim * sizeof(float), NRAM2GDRAM); + } + for (int i = 0; i < repeatDim; i++) { + __memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM); + __bang_sub_scalar(tmp, tmp, destNewMax, dimS); + __bang_active_exp_less_0(tmp, tmp, dimS); + __bang_mul_scalar(tmp, tmp, globalSumInv, dimS); + __memcpy(destination + tid + j * dimsize + i * dimS, tmp, dimS * sizeof(float), NRAM2GDRAM); + } + } + //it is necessary to write back to GDRAM immediately. If you first write back to src and then write back to GDRAM, + //there may be a situation where src writes back to GDRAM before modifying the src data + } + if (step) {//Step targets parts of othersize that cannot be divided by multiple * dimsize + tid = taskRepeat * multiple * dimsize; + __memcpy(src, source + tid, step * dimsize * sizeof(float), GDRAM2NRAM); + for (int j = 0; j < step; j++) { + __bang_write_zero(destSum, dimS); + __bang_write_zero(destSumFinal, warpSize); + destNewMax = -INFINITY; + for (int i = 0; i < repeatDim; i++) {//RepeatDim refers to the total number of cycles required to read the current dimsize data using dimS after fixing otherIdx + __memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM); + __bang_argmax(srcMax, tmp, dimS); + if (destNewMax < srcMax[0]) { + destNewMax = srcMax[0]; + } + __bang_sub_scalar(tmp, tmp, destNewMax, dimS); + __bang_active_exp_less_0(tmp, tmp, dimS); + if (i > 0) { + __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS); + } + __bang_add(destSum, destSum, tmp, dimS); + destOldMax = destNewMax; + } + if (remainDim) {//RemainDim refers to the part of dimsize that cannot be divided by dimS after fixing otherIdx + __bang_write_value(tmp, dimS, -INFINITY); + __memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM); + __bang_argmax(srcMax, tmp, dimS); + if (destNewMax < srcMax[0]) { + destNewMax = srcMax[0]; + } + + __bang_write_value(tmp, dimS, destNewMax);//Must be reinitialized to NewMax + __memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM); + __bang_sub_scalar(tmp, tmp, destNewMax, dimS); + __bang_active_exp_less_0(tmp, tmp, dimS); + if (repeatDim > 0) { + __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS); + } + __bang_add(destSum, destSum, tmp, dimS); + destOldMax = destNewMax; + } + int segNum = dimS / warpSize;//Starting numerical summation + for (int strip = segNum / 2; strip > 0; strip = strip / 2) { + for (int i = 0; i < strip; i++) { + __bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize); + } + } + __bang_reduce_sum(destSumFinal, destSum, warpSize); + //At this point, destSumFinal [0] saves the numerical value of the current dimsize length data sum + if (remainDim) { + destSumFinal[0] = destSumFinal[0] - (dimS - remainDim); + } + //__bang_printf("taskId:%d, max:%.2f, sum:%.2f\n", taskId, destNewMax, destSumFinal[0]); + float globalSumInv = 1.0 / destSumFinal[0]; + if (remainDim) { + __bang_mul_scalar(tmp, tmp, globalSumInv, dimS); + __memcpy(destination + tid + j * dimsize + repeatDim * dimS, tmp, remainDim * sizeof(float), NRAM2GDRAM); + } + for (int i = 0; i < repeatDim; i++) { + __memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM); + __bang_sub_scalar(tmp, tmp, destNewMax, dimS); + __bang_active_exp_less_0(tmp, tmp, dimS); + __bang_mul_scalar(tmp, tmp, globalSumInv, dimS); + __memcpy(destination + tid + j * dimsize + i * dimS, tmp, dimS * sizeof(float), NRAM2GDRAM); + } + } + } + } + __mlu_device__ void softmaxKernelAxis_s(float *destination, float *source, int othersize, int dimsize, int stride) {axis = 0 + __nram__ float src[maxNum]; //Transfer maxNum data to NRAM every time + __nram__ float tmpSum[maxNum]; + __nram__ float tmpNewMax[maxNum]; + __nram__ float tmpOldMax[maxNum]; + + int remain = othersize % taskDim; + int stepEasy = (othersize - remain) / taskDim; + int stepHard = stepEasy + 1; + int step = (taskId < remain ? stepHard : stepEasy);//The first part of taskId handles an additional element + int indStart = (taskId < remain ? taskId * stepHard : remain * stepHard + (taskId - remain) * stepEasy); + int remainNram = step % maxNum; + int repeat = (step - remainNram) / maxNum; + + //__bang_printf("taskId:%d, repeat:%d, step:%d, indStart:%d, remainNram:%d\n", taskId, repeat, step, indStart, remainNram); + for (int j = 0; j < repeat; j++) { + __bang_write_value(tmpNewMax, maxNum, -INFINITY); + __bang_write_zero(tmpSum, maxNum); + for (int i = 0; i < dimsize; i++) { + __memcpy(src, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM); + __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);//Continuously updating the maximum value + __bang_sub(src, src, tmpNewMax, maxNum); //x - M + __bang_active_exp_less_0(src, src, maxNum); //exp(x - M) + if (i > 0) { + __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum); //oldM = oldM - newM + __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM) + __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum); //sum = sum * exp(oldM - newM) + } + __bang_add(tmpSum, tmpSum, src, maxNum); //sum += exp(x - M) + __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM + } + __bang_active_recip_greater_1(tmpSum, tmpSum, maxNum);//compute 1/sum + //Start exponential transformation and write back to GDRAM + __bang_mul(src, src, tmpSum, maxNum);//The data stored in the src at the end of the loop above can be utilized + __memcpy(destination + (dimsize - 1) * stride + indStart + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM); + for (int i = 0; i < dimsize - 1; i++) { + __memcpy(src, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM); + __bang_sub(src, src, tmpNewMax, maxNum); //x - M + __bang_active_exp_less_0(src, src, maxNum);//exp(x - M) + __bang_mul(src, src, tmpSum, maxNum); + __memcpy(destination + i * stride + indStart + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM); + } + } + if (remainNram) { + __bang_write_value(tmpNewMax, maxNum, -INFINITY); + __bang_write_zero(tmpSum, maxNum); + __bang_write_zero(src, maxNum); + + + for (int i = 0; i < dimsize; i++) { + __memcpy(src, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM); + __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum); + __bang_sub(src, src, tmpNewMax, maxNum); //x - M + __bang_active_exp_less_0(src, src, maxNum);//exp(x - M) + if (i > 0) { + __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum); //oldM = oldM - newM + __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM) + __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum); //sum = sum * exp(oldM - newM) + } + __bang_add(tmpSum, tmpSum, src, maxNum); //sum += exp(x - M) + __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM + } + + __bang_active_recip_greater_1(tmpSum, tmpSum, maxNum);//compute 1/sum + //Start exponential transformation and write back to GDRAM + __bang_mul(src, src, tmpSum, maxNum);//The data stored in the src at the end of the loop above can be utilized + __memcpy(destination + (dimsize - 1) * stride + indStart + repeat * maxNum, src, remainNram * sizeof(float), NRAM2GDRAM); + for (int i = 0; i < dimsize - 1; i++) { + __memcpy(src, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM); + __bang_sub(src, src, tmpNewMax, maxNum); //x - M + __bang_active_exp_less_0(src, src, maxNum);//exp(x - M) + __bang_mul(src, src, tmpSum, maxNum); + __memcpy(destination + i * stride + indStart + repeat * maxNum, src, remainNram * sizeof(float), NRAM2GDRAM); + } + } + } + + __mlu_global__ void softmaxUnion1(float *mlu_destination, float *mlu_src, int nDim, int axis, int othersize, int frontsize, int dimsize, int stride) { + if (axis == nDim - 1) { + int dimS; + float mi = log2(dimsize); + if (floor(mi) == mi) { + dimS = dimsize; + } else { + dimS = pow(2, floor(mi) + 1); + } + if (dimS < warpSize) { + dimS = warpSize; + } + softmaxKernelAxis_e(mlu_destination, mlu_src, othersize, dimsize, dimS); + } else if (axis == 0) { + softmaxKernelAxis_s(mlu_destination, mlu_src, othersize, dimsize, stride); + } else { + float mi = log2(stride); + int strideS; + if (floor(mi) == mi) { + strideS = stride; + } else { + strideS = pow(2, floor(mi) + 1); + } + softmaxKernelAxis_m(mlu_destination, mlu_src, frontsize, dimsize, stride, strideS); + } + } + + + template + Routine lowerTypedBang(SoftmaxInfo info) { + using namespace runtime; + + return [info](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + auto mlu_src = reinterpret_cast(inputs[0]); + auto mlu_destination = reinterpret_cast(outputs[0]); + int dimsize = info.mid; + int stride = info.post; + int frontsize = info.pre; + int othersize = frontsize * stride; + int numBlocks = info.pre * info.post; + int nDim = 4; + int axis = 1; + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + res.fetchOrStore(); + cnrtQueue_t queue; + cnnlHandle_t handle = res.fetchOrStore()->handle; + cnnlGetQueue(handle, &queue); + k_dim.x = 4; + k_dim.y = 1; + k_dim.z = 1; + k_type = CNRT_FUNC_TYPE_UNION1; + + softmaxUnion1<<>>(mlu_destination, mlu_src, nDim, axis, othersize, frontsize, dimsize, stride); + }; + } + + auto SoftmaxBang::lower(Resources &res) const noexcept -> RoutineWorkspace { + switch (info.type.internal) { + case DataType::F32: + return lowerTypedBang(info); + case DataType::F64: + return lowerTypedBang(info); + case DataType::FP16: + return lowerTypedBang(info); + case DataType::BF16: + return lowerTypedBang(info); + default: + UNREACHABLE(); + } + } + + +}//namespace refactor::kernel diff --git a/src/04kernel/test/kernels/softmax/test_bang.cpp b/src/04kernel/test/kernels/softmax/test_bang.cpp new file mode 100644 index 00000000..806f169f --- /dev/null +++ b/src/04kernel/test/kernels/softmax/test_bang.cpp @@ -0,0 +1,54 @@ +#ifdef USE_BANG + +#include "../../../src/kernels/softmax/bang_kernel.hh" +#include "../../../src/kernels/softmax/cpu_kernel.hh" +#include "hardware/device_manager.h" +#include + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, SoftmaxBang) { + // build routine + auto xTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 5, 4}); + auto outTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 5, 4}); + dim_t axis = 1; + int nDim = 5; + auto kCpu = SoftmaxCpu::build(SoftmaxInfo(*xTensor, axis)); + auto kBang = SoftmaxBang::build(SoftmaxInfo(*xTensor, axis)); + ASSERT_TRUE(kCpu && kBang); + auto res = runtime::Resources(); + auto rCpu = kCpu->lower(res).routine; + auto rBang = kBang->lower(res).routine; + // malloc + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + auto gpuIn = dev.malloc(xTensor->bytesSize()), + gpuOut = dev.malloc(outTensor->bytesSize()); + // put input data + std::vector + data(xTensor->elementsSize(), 0), + cpuOut(outTensor->elementsSize()); + gpuIn->copyFromHost(data.data(), xTensor->bytesSize()); + // inference + { + void const *inputs[]{data.data()}; + void *outputs[]{cpuOut.data()}; + rCpu(res, nullptr, inputs, outputs); + } + { + void const *inputs[]{*gpuIn}; + void *outputs[]{*gpuOut}; + + rBang(res, nullptr, inputs, outputs); + } + // take output data + std::vector result(outTensor->elementsSize()); + gpuOut->copyToHost(result.data(), outTensor->bytesSize()); + // check + for (auto i : range0_(result.size())) { + EXPECT_FLOAT_EQ(cpuOut[i], result[i]); + } +} + +#endif