Skip to content

Commit

Permalink
feat: 完成 batch normalization 的 cpu kernel
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Oct 15, 2023
1 parent e0e930b commit 50a02aa
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 20 deletions.
1 change: 0 additions & 1 deletion src/01graph_topo/src/container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include <algorithm>
#include <numeric>
#include <sstream>
#include <utility>

namespace refactor::graph_topo {

Expand Down
2 changes: 0 additions & 2 deletions src/01graph_topo/src/searcher.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#include "graph_topo/searcher.h"
#include "refactor/common.h"
#include <algorithm>
#include <unordered_set>
#include <utility>

namespace refactor::graph_topo {
constexpr static idx_t EXTERNAL = std::numeric_limits<idx_t>::max();
Expand Down
4 changes: 2 additions & 2 deletions src/04kernel/include/kernel/tensor.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#ifndef KERNEL_TENSOR_H
#define KERNEL_TENSOR_H

#include "refactor/common.h"
#include "mem_manager/blob.hh"
#include "refactor/common.h"
#include <absl/container/inlined_vector.h>

namespace refactor::kernel {

using Shape = absl::InlinedVector<int64_t, 4>;
using Shape = absl::InlinedVector<uint_lv2, 4>;

enum class LayoutType : uint8_t {
NCHW,
Expand Down
91 changes: 80 additions & 11 deletions src/04kernel/src/kernels/batch_normalization/cpu_kernel.cc
Original file line number Diff line number Diff line change
@@ -1,27 +1,32 @@
#include "cpu_kernel.hh"
#include "refactor/common.h"
#include <numeric>

namespace refactor::kernel {
using K = BatchNormalization;
using DT = DataType;

/// Constructs the CPU batch-normalization kernel descriptor.
/// @param epsilon_ numerical-stability term added to the variance.
/// @param dt0 data type of the input/output tensor X (and Y).
/// @param dt1 data type of scale and bias.
/// @param dt2 data type of running mean and variance.
/// @param shape_ full shape of X; layout assumed [N, C, ...spatial] — confirmed
///               by how lowerTyped indexes it (shape[0]=N, shape[1]=C).
K::BatchNormalization(
    float epsilon_,
    DT dt0,
    DT dt1,
    DT dt2,
    Shape shape_) noexcept
    : Kernel(),
      epsilon(epsilon_),
      dts{dt0, dt1, dt2},
      shape(std::move(shape_)) {}

/// Factory: builds the kernel if every relevant tensor has a type this CPU
/// implementation can compute with; otherwise returns nullptr so another
/// backend can be tried.
/// Input order follows the ONNX BatchNormalization operator:
/// 0 = X, 1 = scale, 2 = bias, 3 = mean, 4 = variance.
/// Only X/scale/mean are inspected because bias shares scale's type and
/// variance shares mean's type by operator contract.
auto K::build(float epsilon, TensorRefs inputs) noexcept -> KernelBox {
    auto const &x = inputs[0].get();
    auto const &scale = inputs[1].get();
    auto const &mean = inputs[3].get();
    // NOTE(review): "isCpuNumberic" spelling matches the existing project API.
    if (!x.dataType.isCpuNumberic() ||
        !scale.dataType.isCpuNumberic() ||
        !mean.dataType.isCpuNumberic()) {
        return nullptr;
    }
    return std::make_unique<K>(epsilon, x.dataType, scale.dataType, mean.dataType, x.shape);
}
auto K::typeId() noexcept -> size_t {
static uint8_t ID = 1;
Expand All @@ -32,18 +37,82 @@ namespace refactor::kernel {
// Human-readable description of this kernel; the returned literal has static
// storage duration, so the string_view never dangles.
auto K::description() const noexcept -> std::string_view {
    return "Performing batch normalization for non-training-mode on generic cpu";
}
auto K::lower() const noexcept -> Routine {

template<decltype(DT::internal) T, decltype(DT::internal) T1, decltype(DT::internal) T2>
Routine lowerTyped(Shape const &shape, float epsilon) {
using namespace runtime;
return [](Resources &, Addresses inputs, Addresses outputs) {
using dt = typename primitive_t<T>::type;
using t1 = typename primitive_t<T1>::type;
using t2 = typename primitive_t<T2>::type;

auto n = shape[0],
c = shape[1],
dims = std::accumulate(shape.begin() + 2, shape.end(), 1u, std::multiplies<>()),
sn = c * dims,
sc = dims;
return [n, c, sn, sc, epsilon](
Resources &,
Addresses inputs,
Addresses outputs) {
auto x = inputs[0],
scale = inputs[1],
bias = inputs[2],
mean = inputs[3],
var = inputs[4];
auto y = outputs[0];

TODO("");
struct Channel {
dt mean, scale, bias;
};
std::vector<Channel> channels(c);
auto scale_ = reinterpret_cast<t1 *>(scale),
bias_ = reinterpret_cast<t1 *>(bias);
auto mean_ = reinterpret_cast<t2 *>(mean),
var_ = reinterpret_cast<t2 *>(var);
for (auto i : range0_(c)) {
channels[i] = {
static_cast<dt>(mean_[i]),
static_cast<dt>(scale_[i]) / std::sqrt(static_cast<dt>(var_[i]) + epsilon),
static_cast<dt>(bias_[i]),
};
}
// Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + B
auto x_ = reinterpret_cast<dt *>(x),
y_ = reinterpret_cast<dt *>(y);
for (auto in : range0_(n))
for (auto ic : range0_(c))
for (auto j : range0_(sc)) {
auto idx = in * sn + ic * sc + j;
auto [_, a, b] = channels[ic];
y_[idx] = (x_[idx] - _) * a + b;
}
};
}

/// Dispatches to the statically-typed implementation: the three runtime data
/// type ids are packed into one uint32_t (one byte each — guaranteed by the
/// static_assert) so a single switch selects the lowerTyped instantiation.
auto K::lower() const noexcept -> Routine {
    // clang-format off
    static_assert(sizeof(decltype(DT::internal)) == 1);
#define MERGE(DT0, DT1, DT2)                      \
      (static_cast<uint32_t>(DT0)          )     \
    + (static_cast<uint32_t>(DT1) << (1 * 8))    \
    + (static_cast<uint32_t>(DT2) << (2 * 8))

#define CASE(DT0, DT1, DT2)                       \
    case MERGE(DT::DT0, DT::DT1, DT::DT2):        \
        return lowerTyped<DT::DT0, DT::DT1, DT::DT2>(shape, epsilon)

    switch (MERGE(dts[0], dts[1], dts[2])) {
        CASE(F32, F32, F32);
        CASE(F32, F32, F64);
        CASE(F32, F64, F32);
        CASE(F32, F64, F64);
        CASE(F64, F32, F32);
        CASE(F64, F32, F64);
        CASE(F64, F64, F32);
        CASE(F64, F64, F64);
        // build() already rejected non-numeric types, so any other
        // combination is a programming error.
        default: UNREACHABLE();
    }
    // Undefine the helper macros so they do not leak into the rest of the
    // translation unit.
#undef CASE
#undef MERGE
    // clang-format on
}

}// namespace refactor::kernel
7 changes: 3 additions & 4 deletions src/04kernel/src/kernels/batch_normalization/cpu_kernel.hh
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
#ifndef KERNEL_BATCH_NORMALIZATION_CPU_KERNEL_HH
#define KERNEL_BATCH_NORMALIZATION_CPU_KERNEL_HH

#include "refactor/common.h"
#include "kernel/kernel.h"
#include "kernel/tensor.h"
#include "refactor/common.h"

namespace refactor::kernel {

struct BatchNormalization final : public Kernel {
float epsilon;
std::array<DataType, 3> dts;
DataType dts[3];
Shape shape;
uint32_t paramSize;

BatchNormalization(float, std::array<DataType, 3>, Shape, uint32_t) noexcept;
BatchNormalization(float, DataType, DataType, DataType, Shape) noexcept;

static KernelBox build(float, TensorRefs) noexcept;
static size_t typeId() noexcept;
Expand Down

0 comments on commit 50a02aa

Please sign in to comment.