feat: support onnx::layernorm
bitzyz committed Apr 26, 2024
1 parent 0464e77 commit 9ab335a
Showing 3 changed files with 72 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/07onnx/src/operators.cpp
@@ -19,6 +19,7 @@
#include "operators/gemm.hh"
#include "operators/global_pool.hh"
#include "operators/hard_sigmoid.hh"
#include "operators/layernorm.hh"
#include "operators/mat_mul.hh"
#include "operators/mat_mul_integer.hh"
#include "operators/pad.hh"
@@ -133,6 +134,7 @@ namespace refactor::onnx {
REGISTER(HardSigmoid , HardSigmoid );
REGISTER(Pad , Pad );
REGISTER(DepthToSpace , DepthToSpace );
REGISTER(LayerNormalization , Layernorm );
// clang-format on
#undef REGISTER
}
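
A note on the registration above: REGISTER is a local macro (defined earlier in operators.cpp and removed with #undef after the table) that binds an ONNX op-type name to the frontend class that builds it. Its exact definition is outside this diff; a plausible sketch of the pattern, with Operator::register_ as an assumed helper name:

// Sketch only: the real macro lives earlier in operators.cpp and the
// helper's actual name and signature may differ.
#define REGISTER(NAME, CLASS) Operator::register_<CLASS>("onnx::" #NAME)

With the table entry added in this commit, an ONNX node whose op_type is "LayerNormalization" is dispatched to Layernorm::build.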
43 changes: 43 additions & 0 deletions src/07onnx/src/operators/layernorm.cc
@@ -0,0 +1,43 @@
#include "computation/operators/layernorm.h"
#include "common.h"
#include "layernorm.hh"

namespace refactor::onnx {
using Op = Layernorm;

Op::Layernorm(Int axis_, Float epsilon_)
: Operator(), axis(axis_), epsilon(epsilon_) {}

auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox {
auto axis = attributes["axis"].int_();
auto epsilon = attributes["epsilon"].float_();
return OpBox(std::make_unique<Op>(axis, epsilon));
}
auto Op::typeId() -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}

auto Op::opTypeId() const -> size_t { return typeId(); }
auto Op::opTypeName() const -> std::string_view { return "onnx::LayerNormalization"; }
auto Op::valueDependentInputs() const -> InputVec { return {1}; }

auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult {

auto const &x = inputs[0];
auto const &scale = inputs[1];

if (!x.dataType.isFloat() ||
!scale.dataType.isFloat()) {
return Err(InferError(ERROR_MSG("Input data type not support")));
}

return Ok(Tensors{Tensor::share(x.dataType, x.shape, extractDependency(inputs))});
}

auto Op::lower(TensorRefs) const -> computation::OpBox {
using Op_ = computation::LayerNormalization;
return std::make_unique<Op_>(epsilon, axis);
}

}// namespace refactor::onnx
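
For reference, the computation this node represents: per the ONNX spec, LayerNormalization normalizes over the dimensions from axis onward (default axis = -1, epsilon = 1e-5) using mean and variance, then applies scale and an optional bias. The infer step above only propagates shape and dtype because the operation is shape-preserving. A minimal standalone sketch of the math for the common last-axis case, bias omitted; this is illustrative, not the repository's kernel:

#include <cmath>
#include <cstddef>

// y = (x - mean) / sqrt(var + epsilon) * scale, row by row, for a
// [rows x cols] buffer where the normalized axis is the last one.
void layerNormRef(const float *x, const float *scale, float *y,
                  std::size_t rows, std::size_t cols, float epsilon) {
    for (std::size_t r = 0; r < rows; ++r) {
        const float *row = x + r * cols;
        float mean = 0.0f;
        for (std::size_t c = 0; c < cols; ++c) mean += row[c];
        mean /= static_cast<float>(cols);
        float var = 0.0f;
        for (std::size_t c = 0; c < cols; ++c) {
            float d = row[c] - mean;
            var += d * d;
        }
        var /= static_cast<float>(cols);
        float inv = 1.0f / std::sqrt(var + epsilon);
        for (std::size_t c = 0; c < cols; ++c)
            y[r * cols + c] = (row[c] - mean) * inv * scale[c];
    }
}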
27 changes: 27 additions & 0 deletions src/07onnx/src/operators/layernorm.hh
@@ -0,0 +1,27 @@
#ifndef ONNX_LAYERNORM_HH
#define ONNX_LAYERNORM_HH

#include "frontend/operator.h"

namespace refactor::onnx {
    using namespace frontend;

    struct Layernorm final : public Operator {
        Int axis;
        Float epsilon;

        explicit Layernorm(Int, Float);

        static OpBox build(ModelContext const &, std::string_view, Attributes);
        static size_t typeId();

        size_t opTypeId() const final;
        std::string_view opTypeName() const final;
        InputVec valueDependentInputs() const final;
        InferResult infer(TensorRefs, InferOptions const &) const final;
        computation::OpBox lower(TensorRefs) const final;
    };

}// namespace refactor::onnx

#endif// ONNX_LAYERNORM_HH
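
A side note on the typeId pattern declared here and defined in layernorm.cc: taking the address of a function-local static yields a value that is stable for one class and distinct across classes, with no central registry. A self-contained illustration of the idiom (not repository code):

#include <cstddef>
#include <cstdint>
#include <iostream>

struct A {
    static std::size_t typeId() {
        static std::uint8_t ID = 1;                // one static per class
        return reinterpret_cast<std::size_t>(&ID); // its address is the ID
    }
};
struct B {
    static std::size_t typeId() {
        static std::uint8_t ID = 1;
        return reinterpret_cast<std::size_t>(&ID);
    }
};

int main() {
    std::cout << (A::typeId() == A::typeId()) << '\n';// 1: stable per class
    std::cout << (A::typeId() == B::typeId()) << '\n';// 0: classes get distinct IDs
}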
