Skip to content

Commit

Permalink
feat: support LeakyRelu and Floor ops
Browse files Browse the repository at this point in the history
  • Loading branch information
bitzyz committed Jul 25, 2024
1 parent 908c795 commit 10e9832
Show file tree
Hide file tree
Showing 10 changed files with 138 additions and 1 deletion.
12 changes: 12 additions & 0 deletions scripts/onnx/to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,18 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
),
[],
)
if self.type == "LeakyRelu":
return (
make_node(
self.type,
[tensors[i].name for i in self.topo.inputs],
[tensors[i].name for i in self.topo.outputs],
self.name,
domain=DEFAULT_DOMAIN,
alpha=float(self.meta),
),
[],
)

raise ValueError(f"Unsupported operator {self.type}")

Expand Down
1 change: 1 addition & 0 deletions src/04kernel/include/kernel/collectors/simple_unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace refactor::kernel {
Not,
HardSwish,
Exp,
Floor,
};

std::string_view unaryName(SimpleUnaryType type);
Expand Down
23 changes: 23 additions & 0 deletions src/05computation/include/computation/operators/leaky_relu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef COMPUTATION_LEAKY_RELU_H
#define COMPUTATION_LEAKY_RELU_H

#include "../operator.h"

namespace refactor::computation {

    /// Computation-graph operator for ONNX LeakyRelu:
    /// y = x for x >= 0, y = alpha * x otherwise.
    struct LeakyRelu final : public Operator {
        /// Negative-slope coefficient (the ONNX `alpha` attribute).
        float alpha;

        constexpr LeakyRelu(float alpha_) noexcept
            : Operator(), alpha(alpha_) {}

        static size_t typeId() noexcept;
        size_t opTypeId() const noexcept final;
        std::string_view name() const noexcept final;
        kernel::CollectorBox candidateKernels(Target) const noexcept final;
        std::string serialize() const noexcept final;
    };

}// namespace refactor::computation

#endif// COMPUTATION_LEAKY_RELU_H
20 changes: 20 additions & 0 deletions src/05computation/src/operators/leaky_relu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "computation/operators/leaky_relu.h"

namespace refactor::computation {
    using Op = LeakyRelu;

    /// Unique per-operator-type id: the address of a function-local
    /// static is distinct for every operator class.
    auto Op::typeId() noexcept -> size_t {
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }
    auto Op::opTypeId() const noexcept -> size_t { return typeId(); }
    auto Op::name() const noexcept -> std::string_view { return "LeakyRelu"; }

    /// TODO: no backend kernel collector is wired up yet; nullptr means
    /// the runtime finds no candidate kernels for this operator.
    /// Parameter left unnamed to avoid an unused-parameter warning.
    auto Op::candidateKernels(Target /*target*/) const noexcept -> kernel::CollectorBox {
        return nullptr;
    }
    /// Serializes as "LeakyRelu(<alpha>)" for graph dumps and logging.
    auto Op::serialize() const noexcept -> std::string {
        return fmt::format("{}({})", name(), alpha);
    }

}// namespace refactor::computation
6 changes: 6 additions & 0 deletions src/05computation/src/operators/simple_unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ namespace refactor::computation {
static uint8_t ID = 21;
return reinterpret_cast<size_t>(&ID);
}
case SimpleUnaryType::Floor: {
static uint8_t ID = 22;
return reinterpret_cast<size_t>(&ID);
}
default:
UNREACHABLE();
}
Expand Down Expand Up @@ -140,6 +144,8 @@ namespace refactor::computation {
return "HardSwish";
case SimpleUnaryType::Exp:
return "Exp";
case SimpleUnaryType::Floor:
return "Floor";
default:
UNREACHABLE();
}
Expand Down
3 changes: 3 additions & 0 deletions src/07onnx/src/operators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "operators/global_pool.hh"
#include "operators/hard_sigmoid.hh"
#include "operators/layernorm.hh"
#include "operators/leaky_relu.hh"
#include "operators/mat_mul.hh"
#include "operators/mat_mul_integer.hh"
#include "operators/pad.hh"
Expand Down Expand Up @@ -123,6 +124,7 @@ namespace refactor::onnx {
REGISTER(Identity , SimpleUnary );
REGISTER(HardSwish , SimpleUnary );
REGISTER(Exp , SimpleUnary );
REGISTER(Floor , SimpleUnary );
REGISTER(Slice , Slice );
REGISTER(Softmax , Softmax );
REGISTER(Split , Split );
Expand All @@ -135,6 +137,7 @@ namespace refactor::onnx {
REGISTER(Pad , Pad );
REGISTER(DepthToSpace , DepthToSpace );
REGISTER(LayerNormalization , Layernorm );
REGISTER(LeakyRelu , LeakyRelu );
// clang-format on
#undef REGISTER
}
Expand Down
39 changes: 39 additions & 0 deletions src/07onnx/src/operators/leaky_relu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include "leaky_relu.hh"
#include "common.h"
#include "computation/operators/leaky_relu.h"

namespace refactor::onnx {
    using Op = LeakyRelu;

    Op::LeakyRelu(Float alpha)
        : Operator(), alpha(alpha) {}

    /// Builds the operator from ONNX attributes; `alpha` defaults to
    /// 0.01f per the ONNX LeakyRelu specification.
    auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox {
        auto alpha = attributes.getOrInsert("alpha", {0.01f}).float_();
        return OpBox(std::make_unique<Op>(alpha));
    }
    auto Op::typeId() -> size_t {
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }

    auto Op::opTypeId() const -> size_t { return typeId(); }
    auto Op::opTypeName() const -> std::string_view { return "onnx::LeakyRelu"; }

    /// Inference: exactly one floating-point input; the output shares
    /// the input's data type and shape (element-wise op).
    /// InferOptions is unused here, so the parameter is left unnamed.
    auto Op::infer(TensorRefs inputs, InferOptions const &) const -> InferResult {
        EXPECT_SIZE(1)
        auto dataType = inputs[0].dataType;
        if (!dataType.isFloat()) {
            return Err(InferError(ERROR_MSG("Data type not support")));
        }
        auto ans = Tensor::share(dataType, inputs[0].shape, extractDependency(inputs));
        return Ok(Tensors{std::move(ans)});
    }
    /// Lowers to the computation-layer LeakyRelu, forwarding `alpha`.
    auto Op::lower(TensorRefs) const -> computation::OpBox {
        using Op_ = computation::LeakyRelu;
        return std::make_unique<Op_>(alpha);
    }

}// namespace refactor::onnx
25 changes: 25 additions & 0 deletions src/07onnx/src/operators/leaky_relu.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef ONNX_LEAKY_RELU_HH
#define ONNX_LEAKY_RELU_HH

#include "frontend/operator.h"

namespace refactor::onnx {
    using namespace frontend;

    /// Frontend operator for ONNX LeakyRelu
    /// (y = x for x >= 0, y = alpha * x otherwise).
    struct LeakyRelu final : public Operator {
        /// Negative-slope attribute `alpha`; defaults to 0.01 when the
        /// ONNX node omits it (see build()).
        Float alpha;

        explicit LeakyRelu(Float);

        /// Registry factory: reads `alpha` from the node's attributes.
        static OpBox build(ModelContext const &, std::string_view, Attributes);
        static size_t typeId();

        size_t opTypeId() const final;
        std::string_view opTypeName() const final;
        /// Validates the single floating-point input and propagates
        /// its shape/dtype to the output.
        InferResult infer(TensorRefs, InferOptions const &) const final;
        /// Lowers to computation::LeakyRelu.
        computation::OpBox lower(TensorRefs) const final;
    };

}// namespace refactor::onnx

#endif// ONNX_LEAKY_RELU_HH
9 changes: 8 additions & 1 deletion src/07onnx/src/operators/simple_unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ namespace refactor::onnx {
opType == "onnx::Identity"? Ty::Identity:
opType == "onnx::HardSwish" ? Ty::HardSwish :
opType == "onnx::Exp" ? Ty::Exp :
opType == "onnx::Floor" ? Ty::Floor :
UNREACHABLEX(Ty, "Unsupported unary operator: {}", opType);
// clang-format on

Expand Down Expand Up @@ -139,6 +140,10 @@ namespace refactor::onnx {
static uint8_t ID = 23;
return reinterpret_cast<size_t>(&ID);
}
case Ty::Floor: {
static uint8_t ID = 24;
return reinterpret_cast<size_t>(&ID);
}
default:
UNREACHABLE();
}
Expand Down Expand Up @@ -171,6 +176,7 @@ namespace refactor::onnx {
case Ty::Identity : return "onnx::Identity";
case Ty::HardSwish : return "onnx::HardSwish";
case Ty::Exp : return "onnx::Exp";
case Ty::Floor : return "onnx::Floor";
default: UNREACHABLE();
}
// clang-format on
Expand Down Expand Up @@ -200,7 +206,7 @@ namespace refactor::onnx {
Ty::Cos, Ty::Cosh,
Ty::Sin, Ty::Sinh,
Ty::Tan, Ty::HardSwish},
{Ty::Tanh, Ty::Sqrt, Ty::Sigmoid, Ty::Log, Ty::Exp},
{Ty::Tanh, Ty::Sqrt, Ty::Sigmoid, Ty::Log, Ty::Exp, Ty::Floor},
{Ty::Neg},
{Ty::Identity}};
if (SET[0].contains(type)) {
Expand Down Expand Up @@ -301,6 +307,7 @@ namespace refactor::onnx {
case Ty::Identity : return std::make_unique<computation::Identity>();
case Ty::HardSwish : type_ = Ty_::HardSwish ; break;
case Ty::Exp : type_ = Ty_::Exp ; break;
case Ty::Floor : type_ = Ty_::Floor ; break;
default: UNREACHABLE();
}
// clang-format on
Expand Down
1 change: 1 addition & 0 deletions src/07onnx/src/operators/simple_unary.hh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ namespace refactor::onnx {
Sigmoid,
Tan,
Tanh,
Floor,
};

struct SimpleUnary final : public Operator {
Expand Down

0 comments on commit 10e9832

Please sign in to comment.