feat: support gelu fuse
bitzyz committed Apr 26, 2024
1 parent 9ab335a commit 908c795
Showing 6 changed files with 148 additions and 0 deletions.
11 changes: 11 additions & 0 deletions scripts/onnx/to_onnx.py
@@ -344,6 +344,17 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
                ),
                [],
            )
        if self.type == "Gelu":
            return (
                make_node(
                    self.type,
                    [tensors[i].name for i in self.topo.inputs],
                    [tensors[i].name for i in self.topo.outputs],
                    self.name,
                    domain="refactor",
                ),
                [],
            )

        raise ValueError(f"Unsupported operator {self.type}")

21 changes: 21 additions & 0 deletions src/05computation/include/computation/operators/gelu.h
@@ -0,0 +1,21 @@
#ifndef COMPUTATION_GELU_H
#define COMPUTATION_GELU_H

#include "../operator.h"

namespace refactor::computation {

    struct Gelu final : public Operator {

        constexpr explicit Gelu() noexcept
            : Operator() {}

        static size_t typeId() noexcept;
        size_t opTypeId() const noexcept final;
        std::string_view name() const noexcept final;
        // kernel::CollectorBox candidateKernels(Target) const final;
        std::string serialize() const noexcept final;
    };

}// namespace refactor::computation
#endif//COMPUTATION_GELU_H
97 changes: 97 additions & 0 deletions src/05computation/include/computation/pass/gelu_fuse.h
@@ -0,0 +1,97 @@
#ifndef COMPUTATION_GELU_FUSE_H
#define COMPUTATION_GELU_FUSE_H

#include "../graph.h"
#include "computation/operators/gelu.h"
#include "computation/operators/reshape.h"
#include "computation/operators/simple_binary.h"
#include "computation/operators/simple_unary.h"
#include "computation/pass/converter.h"

namespace refactor::computation {
    class GeluFuse : public Converter {
    public:
        virtual bool execute(const std::shared_ptr<GraphMutant> &g) const override {
            auto nodesList = g->internal().nodes();
            size_t count = 0;
            for (auto opMatch : nodesList) {
                if (opMatch->info().op == nullptr) {
                    continue;
                }
                size_t optype = opMatch->info().op->opTypeId();
                // the fused pattern is rooted at an Add or Reshape node
                if (optype != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Add) &&
                    optype != Reshape::typeId()) {
                    continue;
                }
                auto input = opMatch->outputs()[0]->info().tensor;
                auto targets = opMatch->outputs()[0]->targets();
                if (opMatch->successors().size() >= 3) {
                    // three or more consumers: no fusion pattern is handled here yet
                } else if (opMatch->successors().size() >= 2) {
                    // with two consumers, expect op1 to be Div and op2 to be Mul
                    auto op1 = *targets.begin();
                    auto op2 = *(std::next(targets.begin()));
                    if (op1 == nullptr || op2 == nullptr ||
                        op1->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Div) ||
                        op2->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Mul)) {
                        continue;
                    }
                    if (op1->successors().size() != 1 || op2->successors().size() != 1) {
                        continue;
                    }
                    auto ErfOp = *(op1->outputs()[0]->targets().begin());
                    auto MulOp = *(op2->outputs()[0]->targets().begin());
                    if (ErfOp == nullptr || MulOp == nullptr ||
                        ErfOp->info().op->opTypeId() != SimpleUnary::typeId(refactor::kernel::SimpleUnaryType::Erf) ||
                        MulOp->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Mul)) {
                        continue;
                    }
                    // the final Mul must scale by the constant 0.5
                    if (auto alpha = MulOp->inputs()[1]->info().tensor->data; alpha) {
                        float alphaVal = *alpha->get<float>();
                        if (alphaVal != 0.5f) {
                            continue;
                        }
                    } else {
                        continue;
                    }
                    if (ErfOp->successors().size() != 1) {
                        continue;
                    }
                    auto AddOp = *(ErfOp->outputs()[0]->targets().begin());
                    if (AddOp == nullptr || AddOp->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Add)) {
                        continue;
                    }
                    // the Add after Erf must add the constant 1.0
                    if (auto beta = AddOp->inputs()[1]->info().tensor->data; beta) {
                        float betaVal = *beta->get<float>();
                        if (betaVal != 1.0f) {
                            continue;
                        }
                    } else {
                        continue;
                    }
                    if (AddOp->successors().size() != 1 || *(AddOp->outputs()[0]->targets().begin()) != op2) {
                        continue;
                    }
                    // replace the matched subgraph with a single Gelu node
                    auto geluOp = g->internal().pushNode(
                        {std::make_unique<Gelu>(), fmt::format("Gelu_{}", count)},
                        {g->internal().shareEdge({Tensor::share(input->dataType, input->shape), fmt::format("Gelu_{}_out", count)})});
                    geluOp->connect(0, opMatch->outputs()[0]);
                    if (MulOp->outputs()[0]->targets().size() == 0) {
                        // the old output was a graph output: redirect it to the fused node
                        g->internal().replaceOutput(MulOp->outputs()[0], geluOp->outputs()[0]);
                    } else {
                        // rewire every consumer of the old output to the Gelu output
                        for (auto node : MulOp->outputs()[0]->targets()) {
                            auto it = std::find(node->inputs().begin(), node->inputs().end(), MulOp->outputs()[0]);
                            node->reconnect(node->inputs()[std::distance(node->inputs().begin(), it)], geluOp->outputs()[0]);
                        }
                    }
                    count++;
                    g->internal().cleanup();
                }
            }
            return true;
        };
    };
}// namespace refactor::computation

#endif//COMPUTATION_GELU_FUSE_H
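
For context (not part of this commit): the subgraph the pass collapses is the exact, erf-based GELU decomposition. Starting from the output x of the matched Add/Reshape node, one branch runs Div (by √2) → Erf → Add (+1), op2 multiplies that result by x, and the final Mul scales the product by 0.5, so the whole chain computes GELU(x) = 0.5 · x · (1 + erf(x / √2)). A minimal standalone C++ sketch that checks this algebra against the node order the pass matches:

// Standalone sanity check of the algebra behind GeluFuse (illustrative only,
// not part of the commit): the matched chain equals the exact GELU formula.
#include <cassert>
#include <cmath>
#include <cstdio>

static double geluExact(double x) {
    return 0.5 * x * (1.0 + std::erf(x / std::sqrt(2.0)));
}

static double geluDecomposed(double x) {
    double div = x / std::sqrt(2.0);// Div node (op1)
    double erf = std::erf(div);     // Erf node
    double add = erf + 1.0;         // Add node, beta == 1.0
    double mul = x * add;           // Mul node (op2)
    return 0.5 * mul;               // final Mul, alpha == 0.5
}

int main() {
    for (double x = -4.0; x <= 4.0; x += 0.5) {
        assert(std::abs(geluExact(x) - geluDecomposed(x)) < 1e-12);
    }
    std::puts("decomposed subgraph matches exact GELU");
    return 0;
}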
2 changes: 2 additions & 0 deletions src/05computation/include/computation/pass_register.h
@@ -2,6 +2,7 @@
#define COMPUTATION_PASS_REGISTER_H
#include "pass/conv_to_matmul.h"
#include "pass/converter.h"
#include "pass/gelu_fuse.h"
#include "pass/layernorm_fuse.h"
#include "pass/matmul_transpose.h"

@@ -12,6 +13,7 @@ namespace refactor::computation {
REGISTER(MatMulTransposeFuse, MatMulTransposeFuse)
REGISTER(ConvToMatmul, ConvToMatmul)
REGISTER(LayernormFuse, LayernormFuse)
REGISTER(GeluFuse, GeluFuse)
};


1 change: 1 addition & 0 deletions src/05computation/src/graph.cc
@@ -221,6 +221,7 @@ namespace refactor::computation {
        auto graphMutant = GraphMutant(*this);
        std::vector<std::string_view> passes = {
            "LayernormFuse",
            "GeluFuse",
            // "MatMulTransposeFuse",
            // "ConvToMatmul",
        };
16 changes: 16 additions & 0 deletions src/05computation/src/operators/gelu.cc
@@ -0,0 +1,16 @@
#include "computation/operators/gelu.h"

namespace refactor::computation {
    using Op = Gelu;

    auto Op::typeId() noexcept -> size_t {
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }
    auto Op::opTypeId() const noexcept -> size_t { return typeId(); }
    auto Op::name() const noexcept -> std::string_view { return "Gelu"; }
    auto Op::serialize() const noexcept -> std::string {
        return fmt::format(("{}()"), name());
    }

}// namespace refactor::computation
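
The typeId above relies on a common C++ idiom: a function-local static byte whose address doubles as a process-unique identifier for the type, so no central enum has to be maintained. A small self-contained sketch of the idiom, using hypothetical FooOp/BarOp types rather than anything from the repository:

// Illustration of the address-of-a-static idiom behind Op::typeId()
// (hypothetical FooOp/BarOp types, not from the repository).
#include <cassert>
#include <cstddef>
#include <cstdint>

struct FooOp {
    static std::size_t typeId() noexcept {
        static std::uint8_t ID = 1;               // one byte per operator type
        return reinterpret_cast<std::size_t>(&ID);// its unique address is the ID
    }
};

struct BarOp {
    static std::size_t typeId() noexcept {
        static std::uint8_t ID = 1;
        return reinterpret_cast<std::size_t>(&ID);
    }
};

int main() {
    // Same type -> same ID; different types -> different statics, different IDs.
    assert(FooOp::typeId() == FooOp::typeId());
    assert(FooOp::typeId() != BarOp::typeId());
    return 0;
}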
