From 50a02aab9d8917e7be1359eb1d47b8462757e21c Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Sun, 15 Oct 2023 14:26:43 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=8C=E6=88=90=20batch=20normalizat?= =?UTF-8?q?ion=20=E7=9A=84=20cpu=20kernel?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- src/01graph_topo/src/container.cc | 1 - src/01graph_topo/src/searcher.cc | 2 - src/04kernel/include/kernel/tensor.h | 4 +- .../kernels/batch_normalization/cpu_kernel.cc | 91 ++++++++++++++++--- .../kernels/batch_normalization/cpu_kernel.hh | 7 +- 5 files changed, 85 insertions(+), 20 deletions(-) diff --git a/src/01graph_topo/src/container.cc b/src/01graph_topo/src/container.cc index b9a00282..de86c18d 100644 --- a/src/01graph_topo/src/container.cc +++ b/src/01graph_topo/src/container.cc @@ -3,7 +3,6 @@ #include #include #include -#include namespace refactor::graph_topo { diff --git a/src/01graph_topo/src/searcher.cc b/src/01graph_topo/src/searcher.cc index 226345fd..07d94a45 100644 --- a/src/01graph_topo/src/searcher.cc +++ b/src/01graph_topo/src/searcher.cc @@ -1,8 +1,6 @@ #include "graph_topo/searcher.h" #include "refactor/common.h" #include -#include -#include namespace refactor::graph_topo { constexpr static idx_t EXTERNAL = std::numeric_limits::max(); diff --git a/src/04kernel/include/kernel/tensor.h b/src/04kernel/include/kernel/tensor.h index 176b61cc..9ddf6068 100644 --- a/src/04kernel/include/kernel/tensor.h +++ b/src/04kernel/include/kernel/tensor.h @@ -1,13 +1,13 @@ #ifndef KERNEL_TENSOR_H #define KERNEL_TENSOR_H -#include "refactor/common.h" #include "mem_manager/blob.hh" +#include "refactor/common.h" #include namespace refactor::kernel { - using Shape = absl::InlinedVector; + using Shape = absl::InlinedVector; enum class LayoutType : uint8_t { NCHW, diff --git a/src/04kernel/src/kernels/batch_normalization/cpu_kernel.cc b/src/04kernel/src/kernels/batch_normalization/cpu_kernel.cc index 3ff55963..d076a107 100644 --- a/src/04kernel/src/kernels/batch_normalization/cpu_kernel.cc +++ b/src/04kernel/src/kernels/batch_normalization/cpu_kernel.cc @@ -1,5 +1,6 @@ #include "cpu_kernel.hh" #include "refactor/common.h" +#include namespace refactor::kernel { using K = BatchNormalization; @@ -7,21 +8,25 @@ namespace refactor::kernel { K::BatchNormalization( float epsilon_, - std::array dts_, - Shape shape_, - uint32_t paramSize_) noexcept + DT dt0, + DT dt1, + DT dt2, + Shape shape_) noexcept : Kernel(), epsilon(epsilon_), - dts(dts_), - shape(std::move(shape_)), - paramSize(paramSize_) {} + dts{dt0, dt1, dt2}, + shape(std::move(shape_)) {} auto K::build(float epsilon, TensorRefs inputs) noexcept -> KernelBox { auto const &x = inputs[0].get(); auto const &scale = inputs[1].get(); auto const &mean = inputs[3].get(); - std::array dts{x.dataType, scale.dataType, mean.dataType}; - return std::make_unique(epsilon, dts, x.shape, scale.shape[0]); + if (!x.dataType.isCpuNumberic() || + !scale.dataType.isCpuNumberic() || + !mean.dataType.isCpuNumberic()) { + return nullptr; + } + return std::make_unique(epsilon, x.dataType, scale.dataType, mean.dataType, x.shape); } auto K::typeId() noexcept -> size_t { static uint8_t ID = 1; @@ -32,9 +37,23 @@ namespace refactor::kernel { auto K::description() const noexcept -> std::string_view { return "Performing batch normalization for non-training-mode on generic cpu"; } - auto K::lower() const noexcept -> Routine { + + template + Routine lowerTyped(Shape const &shape, float epsilon) { using namespace runtime; - return [](Resources &, Addresses inputs, Addresses outputs) { + using dt = typename primitive_t::type; + using t1 = typename primitive_t::type; + using t2 = typename primitive_t::type; + + auto n = shape[0], + c = shape[1], + dims = std::accumulate(shape.begin() + 2, shape.end(), 1u, std::multiplies<>()), + sn = c * dims, + sc = dims; + return [n, c, sn, sc, epsilon]( + Resources &, + Addresses inputs, + Addresses outputs) { auto x = inputs[0], scale = inputs[1], bias = inputs[2], @@ -42,8 +61,58 @@ namespace refactor::kernel { var = inputs[4]; auto y = outputs[0]; - TODO(""); + struct Channel { + dt mean, scale, bias; + }; + std::vector channels(c); + auto scale_ = reinterpret_cast(scale), + bias_ = reinterpret_cast(bias); + auto mean_ = reinterpret_cast(mean), + var_ = reinterpret_cast(var); + for (auto i : range0_(c)) { + channels[i] = { + static_cast
(mean_[i]), + static_cast
(scale_[i]) / std::sqrt(static_cast
(var_[i]) + epsilon), + static_cast
(bias_[i]), + }; + } + // Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + B + auto x_ = reinterpret_cast
(x), + y_ = reinterpret_cast
(y); + for (auto in : range0_(n)) + for (auto ic : range0_(c)) + for (auto j : range0_(sc)) { + auto idx = in * sn + ic * sc + j; + auto [_, a, b] = channels[ic]; + y_[idx] = (x_[idx] - _) * a + b; + } }; } + auto K::lower() const noexcept -> Routine { + // clang-format off + static_assert(sizeof(decltype(DT::internal)) == 1); + #define MERGE(DT0, DT1, DT2) \ + (static_cast(DT0) ) \ + + (static_cast(DT1) << (1 * 8)) \ + + (static_cast(DT2) << (2 * 8)) + + #define CASE(DT0, DT1, DT2) \ + case MERGE(DT::DT0, DT::DT1, DT::DT2): \ + return lowerTyped(shape, epsilon) + + switch (MERGE(dts[0], dts[1], dts[2])) { + CASE(F32, F32, F32); + CASE(F32, F32, F64); + CASE(F32, F64, F32); + CASE(F32, F64, F64); + CASE(F64, F32, F32); + CASE(F64, F32, F64); + CASE(F64, F64, F32); + CASE(F64, F64, F64); + default: UNREACHABLE(); + } + // clang-format on + } + }// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/batch_normalization/cpu_kernel.hh b/src/04kernel/src/kernels/batch_normalization/cpu_kernel.hh index a8a317c2..e9a1aa48 100644 --- a/src/04kernel/src/kernels/batch_normalization/cpu_kernel.hh +++ b/src/04kernel/src/kernels/batch_normalization/cpu_kernel.hh @@ -1,19 +1,18 @@ #ifndef KERNEL_BATCH_NORMALIZATION_CPU_KERNEL_HH #define KERNEL_BATCH_NORMALIZATION_CPU_KERNEL_HH -#include "refactor/common.h" #include "kernel/kernel.h" #include "kernel/tensor.h" +#include "refactor/common.h" namespace refactor::kernel { struct BatchNormalization final : public Kernel { float epsilon; - std::array dts; + DataType dts[3]; Shape shape; - uint32_t paramSize; - BatchNormalization(float, std::array, Shape, uint32_t) noexcept; + BatchNormalization(float, DataType, DataType, DataType, Shape) noexcept; static KernelBox build(float, TensorRefs) noexcept; static size_t typeId() noexcept;