From 10e983280e279699e482415a7580785dd9f566b5 Mon Sep 17 00:00:00 2001 From: zhangyunze Date: Thu, 25 Jul 2024 14:43:26 +0800 Subject: [PATCH] feat: support leakyrelu and floor op --- scripts/onnx/to_onnx.py | 12 ++++++ .../include/kernel/collectors/simple_unary.h | 1 + .../computation/operators/leaky_relu.h | 23 +++++++++++ src/05computation/src/operators/leaky_relu.cc | 20 ++++++++++ .../src/operators/simple_unary.cc | 6 +++ src/07onnx/src/operators.cpp | 3 ++ src/07onnx/src/operators/leaky_relu.cc | 39 +++++++++++++++++++ src/07onnx/src/operators/leaky_relu.hh | 25 ++++++++++++ src/07onnx/src/operators/simple_unary.cc | 9 ++++- src/07onnx/src/operators/simple_unary.hh | 1 + 10 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 src/05computation/include/computation/operators/leaky_relu.h create mode 100644 src/05computation/src/operators/leaky_relu.cc create mode 100644 src/07onnx/src/operators/leaky_relu.cc create mode 100644 src/07onnx/src/operators/leaky_relu.hh diff --git a/scripts/onnx/to_onnx.py b/scripts/onnx/to_onnx.py index a8e1a538..979cc646 100644 --- a/scripts/onnx/to_onnx.py +++ b/scripts/onnx/to_onnx.py @@ -355,6 +355,18 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]: ), [], ) + if self.type == "LeakyRelu": + return ( + make_node( + self.type, + [tensors[i].name for i in self.topo.inputs], + [tensors[i].name for i in self.topo.outputs], + self.name, + domain=DEFAULT_DOMAIN, + alpha=float(self.meta), + ), + [], + ) raise ValueError(f"Unsupported operator {self.type}") diff --git a/src/04kernel/include/kernel/collectors/simple_unary.h b/src/04kernel/include/kernel/collectors/simple_unary.h index f19b0358..445437e5 100644 --- a/src/04kernel/include/kernel/collectors/simple_unary.h +++ b/src/04kernel/include/kernel/collectors/simple_unary.h @@ -27,6 +27,7 @@ namespace refactor::kernel { Not, HardSwish, Exp, + Floor, }; std::string_view unaryName(SimpleUnaryType type); diff --git a/src/05computation/include/computation/operators/leaky_relu.h b/src/05computation/include/computation/operators/leaky_relu.h new file mode 100644 index 00000000..56b9cf03 --- /dev/null +++ b/src/05computation/include/computation/operators/leaky_relu.h @@ -0,0 +1,23 @@ +#ifndef COMPUTATION_LEAKY_RELU_H +#define COMPUTATION_LEAKY_RELU_H + +#include "../operator.h" + +namespace refactor::computation { + + struct LeakyRelu final : public Operator { + float alpha; + + constexpr LeakyRelu(float alpha_) noexcept + : Operator(), alpha(alpha_){}; + + static size_t typeId() noexcept; + size_t opTypeId() const noexcept final; + std::string_view name() const noexcept final; + kernel::CollectorBox candidateKernels(Target) const noexcept final; + std::string serialize() const noexcept final; + }; + +}// namespace refactor::computation + +#endif// COMPUTATION_LEAKY_RELU_H diff --git a/src/05computation/src/operators/leaky_relu.cc b/src/05computation/src/operators/leaky_relu.cc new file mode 100644 index 00000000..a6347329 --- /dev/null +++ b/src/05computation/src/operators/leaky_relu.cc @@ -0,0 +1,20 @@ +#include "computation/operators/leaky_relu.h" + +namespace refactor::computation { + using Op = LeakyRelu; + + auto Op::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + auto Op::opTypeId() const noexcept -> size_t { return typeId(); } + auto Op::name() const noexcept -> std::string_view { return "LeakyRelu"; } + + auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { + return nullptr; + } + auto Op::serialize() const noexcept -> std::string { + return fmt::format("{}({})", name(), alpha); + } + +}// namespace refactor::computation diff --git a/src/05computation/src/operators/simple_unary.cc b/src/05computation/src/operators/simple_unary.cc index a37fe1c7..28301158 100644 --- a/src/05computation/src/operators/simple_unary.cc +++ b/src/05computation/src/operators/simple_unary.cc @@ -89,6 +89,10 @@ namespace refactor::computation { static uint8_t ID = 21; return reinterpret_cast(&ID); } + case SimpleUnaryType::Floor: { + static uint8_t ID = 22; + return reinterpret_cast(&ID); + } default: UNREACHABLE(); } @@ -140,6 +144,8 @@ namespace refactor::computation { return "HardSwish"; case SimpleUnaryType::Exp: return "Exp"; + case SimpleUnaryType::Floor: + return "Floor"; default: UNREACHABLE(); } diff --git a/src/07onnx/src/operators.cpp b/src/07onnx/src/operators.cpp index db50dcf4..aab62eb4 100644 --- a/src/07onnx/src/operators.cpp +++ b/src/07onnx/src/operators.cpp @@ -20,6 +20,7 @@ #include "operators/global_pool.hh" #include "operators/hard_sigmoid.hh" #include "operators/layernorm.hh" +#include "operators/leaky_relu.hh" #include "operators/mat_mul.hh" #include "operators/mat_mul_integer.hh" #include "operators/pad.hh" @@ -123,6 +124,7 @@ namespace refactor::onnx { REGISTER(Identity , SimpleUnary ); REGISTER(HardSwish , SimpleUnary ); REGISTER(Exp , SimpleUnary ); + REGISTER(Floor , SimpleUnary ); REGISTER(Slice , Slice ); REGISTER(Softmax , Softmax ); REGISTER(Split , Split ); @@ -135,6 +137,7 @@ namespace refactor::onnx { REGISTER(Pad , Pad ); REGISTER(DepthToSpace , DepthToSpace ); REGISTER(LayerNormalization , Layernorm ); + REGISTER(LeakyRelu , LeakyRelu ); // clang-format on #undef REGISTER } diff --git a/src/07onnx/src/operators/leaky_relu.cc b/src/07onnx/src/operators/leaky_relu.cc new file mode 100644 index 00000000..1b6a598d --- /dev/null +++ b/src/07onnx/src/operators/leaky_relu.cc @@ -0,0 +1,39 @@ +#include "leaky_relu.hh" +#include "common.h" +#include "computation/operators/leaky_relu.h" +#include + +namespace refactor::onnx { + using Op = LeakyRelu; + + Op::LeakyRelu(Float alpha) + : Operator(), alpha(alpha) {} + + auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { + auto alpha = attributes.getOrInsert("alpha", {0.01f}).float_(); + return OpBox(std::make_unique(alpha)); + } + auto Op::typeId() -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto Op::opTypeId() const -> size_t { return typeId(); } + auto Op::opTypeName() const -> std::string_view { return "onnx::LeakyRelu"; } + + auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult { + EXPECT_SIZE(1) + auto dataType = inputs[0].dataType; + if (!dataType.isFloat()) { + return Err(InferError(ERROR_MSG("Data type not support"))); + } + auto ans = Tensor::share(dataType, inputs[0].shape, extractDependency(inputs)); + return Ok(Tensors{std::move(ans)}); + } + auto Op::lower(TensorRefs) const -> computation::OpBox { + using Op_ = computation::LeakyRelu; + return std::make_unique(alpha); + } + + +}// namespace refactor::onnx diff --git a/src/07onnx/src/operators/leaky_relu.hh b/src/07onnx/src/operators/leaky_relu.hh new file mode 100644 index 00000000..96c4770d --- /dev/null +++ b/src/07onnx/src/operators/leaky_relu.hh @@ -0,0 +1,25 @@ +#ifndef ONNX_LEAKY_RELU_HH +#define ONNX_LEAKY_RELU_HH + +#include "frontend/operator.h" + +namespace refactor::onnx { + using namespace frontend; + + struct LeakyRelu final : public Operator { + Float alpha; + + explicit LeakyRelu(Float); + + static OpBox build(ModelContext const &, std::string_view, Attributes); + static size_t typeId(); + + size_t opTypeId() const final; + std::string_view opTypeName() const final; + InferResult infer(TensorRefs, InferOptions const &) const final; + computation::OpBox lower(TensorRefs) const final; + }; + +}// namespace refactor::onnx + +#endif// ONNX_LEAKY_RELU_HH diff --git a/src/07onnx/src/operators/simple_unary.cc b/src/07onnx/src/operators/simple_unary.cc index 9192c916..e26c296e 100644 --- a/src/07onnx/src/operators/simple_unary.cc +++ b/src/07onnx/src/operators/simple_unary.cc @@ -39,6 +39,7 @@ namespace refactor::onnx { opType == "onnx::Identity"? Ty::Identity: opType == "onnx::HardSwish" ? Ty::HardSwish : opType == "onnx::Exp" ? Ty::Exp : + opType == "onnx::Floor" ? Ty::Floor : UNREACHABLEX(Ty, "Unsupported unary operator: {}", opType); // clang-format on @@ -139,6 +140,10 @@ namespace refactor::onnx { static uint8_t ID = 23; return reinterpret_cast(&ID); } + case Ty::Floor: { + static uint8_t ID = 24; + return reinterpret_cast(&ID); + } default: UNREACHABLE(); } @@ -171,6 +176,7 @@ namespace refactor::onnx { case Ty::Identity : return "onnx::Identity"; case Ty::HardSwish : return "onnx::HardSwish"; case Ty::Exp : return "onnx::Exp"; + case Ty::Floor : return "onnx::Floor"; default: UNREACHABLE(); } // clang-format on @@ -200,7 +206,7 @@ namespace refactor::onnx { Ty::Cos, Ty::Cosh, Ty::Sin, Ty::Sinh, Ty::Tan, Ty::HardSwish}, - {Ty::Tanh, Ty::Sqrt, Ty::Sigmoid, Ty::Log, Ty::Exp}, + {Ty::Tanh, Ty::Sqrt, Ty::Sigmoid, Ty::Log, Ty::Exp, Ty::Floor}, {Ty::Neg}, {Ty::Identity}}; if (SET[0].contains(type)) { @@ -301,6 +307,7 @@ namespace refactor::onnx { case Ty::Identity : return std::make_unique(); case Ty::HardSwish : type_ = Ty_::HardSwish ; break; case Ty::Exp : type_ = Ty_::Exp ; break; + case Ty::Floor : type_ = Ty_::Floor ; break; default: UNREACHABLE(); } // clang-format on diff --git a/src/07onnx/src/operators/simple_unary.hh b/src/07onnx/src/operators/simple_unary.hh index 93e3f116..0d142be1 100644 --- a/src/07onnx/src/operators/simple_unary.hh +++ b/src/07onnx/src/operators/simple_unary.hh @@ -30,6 +30,7 @@ namespace refactor::onnx { Sigmoid, Tan, Tanh, + Floor, }; struct SimpleUnary final : public Operator {