diff --git a/scripts/onnx/to_onnx.py b/scripts/onnx/to_onnx.py
index 45fd0f12..a8e1a538 100644
--- a/scripts/onnx/to_onnx.py
+++ b/scripts/onnx/to_onnx.py
@@ -344,6 +344,17 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
                 ),
                 [],
             )
+        if self.type == "Gelu":
+            return (
+                make_node(
+                    self.type,
+                    [tensors[i].name for i in self.topo.inputs],
+                    [tensors[i].name for i in self.topo.outputs],
+                    self.name,
+                    domain="refactor",
+                ),
+                [],
+            )
         raise ValueError(f"Unsupported operator {self.type}")
diff --git a/src/05computation/include/computation/operators/gelu.h b/src/05computation/include/computation/operators/gelu.h
new file mode 100644
index 00000000..fb5d9b08
--- /dev/null
+++ b/src/05computation/include/computation/operators/gelu.h
@@ -0,0 +1,21 @@
+#ifndef COMPUTATION_GELU_H
+#define COMPUTATION_GELU_H
+
+#include "../operator.h"
+
+namespace refactor::computation {
+
+    struct Gelu final : public Operator {
+
+        constexpr explicit Gelu() noexcept
+            : Operator() {}
+
+        static size_t typeId() noexcept;
+        size_t opTypeId() const noexcept final;
+        std::string_view name() const noexcept final;
+        // kernel::CollectorBox candidateKernels(Target) const final;
+        std::string serialize() const noexcept final;
+    };
+
+}// namespace refactor::computation
+#endif//COMPUTATION_GELU_H
\ No newline at end of file
diff --git a/src/05computation/include/computation/pass/gelu_fuse.h b/src/05computation/include/computation/pass/gelu_fuse.h
new file mode 100644
index 00000000..f6fe3c2c
--- /dev/null
+++ b/src/05computation/include/computation/pass/gelu_fuse.h
@@ -0,0 +1,97 @@
+#ifndef COMPUTATION_GELU_FUSE_H
+#define COMPUTATION_GELU_FUSE_H
+
+#include "../graph.h"
+#include "computation/operators/gelu.h"
+#include "computation/operators/reshape.h"
+#include "computation/operators/simple_binary.h"
+#include "computation/operators/simple_unary.h"
+#include "computation/pass/converter.h"
+
+namespace refactor::computation {
+    class GeluFuse : public Converter {
+    public:
+        virtual bool execute(const std::shared_ptr<GraphMutant> &g) const override {
+            auto nodesList = g->internal().nodes();
+            size_t count = 0;
+            for (auto opMatch : nodesList) {
+                if (opMatch->info().op == nullptr) {
+                    continue;
+                }
+                size_t optype = opMatch->info().op->opTypeId();
+                if (optype != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Add) &&
+                    optype != Reshape::typeId()) {
+                    continue;
+                }
+                auto input = opMatch->outputs()[0]->info().tensor;
+                auto targets = opMatch->outputs()[0]->targets();
+                if (opMatch->successors().size() >= 3) {
+
+                } else if (opMatch->successors().size() >= 2) {
+                    // op1 is Div, op2 is Mul
+                    auto op1 = *targets.begin();
+                    auto op2 = *(std::next(targets.begin()));
+                    if (op1 == nullptr || op2 == nullptr ||
+                        op1->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Div) ||
+                        op2->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Mul)) {
+                        continue;
+                    }
+                    if (op1->successors().size() != 1 || op2->successors().size() != 1) {
+                        continue;
+                    }
+                    auto ErfOp = *(op1->outputs()[0]->targets().begin());
+                    auto MulOp = *(op2->outputs()[0]->targets().begin());
+                    if (ErfOp == nullptr || MulOp == nullptr ||
+                        ErfOp->info().op->opTypeId() != SimpleUnary::typeId(refactor::kernel::SimpleUnaryType::Erf) ||
+                        MulOp->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Mul)) {
+                        continue;
+                    }
+                    if (auto alpha = MulOp->inputs()[1]->info().tensor->data; alpha) {
+                        float alphaVal = *alpha->get<float>();
+                        if (alphaVal != 0.5f) {
+                            continue;
+                        }
+                    } else {
+                        continue;
+                    }
+                    if (ErfOp->successors().size() != 1) {
+                        continue;
+                    }
+                    auto AddOp = *(ErfOp->outputs()[0]->targets().begin());
+                    if (AddOp == nullptr || AddOp->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Add)) {
+                        continue;
+                    }
+                    if (auto beta = AddOp->inputs()[1]->info().tensor->data; beta) {
+                        float betaVal = *beta->get<float>();
+                        if (betaVal != 1.0f) {
+                            continue;
+                        }
+                    } else {
+                        continue;
+                    }
+                    if (AddOp->successors().size() != 1 || *(AddOp->outputs()[0]->targets().begin()) != op2) {
+                        continue;
+                    }
+                    // replace
+                    auto geluOp = g->internal().pushNode(
+                        {std::make_unique<Gelu>(), fmt::format("Gelu_{}", count)},
+                        {g->internal().shareEdge({Tensor::share(input->dataType, input->shape), fmt::format("Gelu_{}_out", count)})});
+                    geluOp->connect(0, opMatch->outputs()[0]);
+                    if (MulOp->outputs()[0]->targets().size() == 0) {
+                        g->internal().replaceOutput(MulOp->outputs()[0], geluOp->outputs()[0]);
+                    } else {
+                        for (auto node : MulOp->outputs()[0]->targets()) {
+                            auto it = std::find(node->inputs().begin(), node->inputs().end(), MulOp->outputs()[0]);
+                            node->reconnect(node->inputs()[std::distance(node->inputs().begin(), it)], geluOp->outputs()[0]);
+                        }
+                    }
+                    count++;
+                    g->internal().cleanup();
+                }
+            }
+            return true;
+        };
+    };
+}// namespace refactor::computation
+
+#endif//COMPUTATION_GELU_FUSE_H
\ No newline at end of file
diff --git a/src/05computation/include/computation/pass_register.h b/src/05computation/include/computation/pass_register.h
index 6f883023..6ecacb3e 100644
--- a/src/05computation/include/computation/pass_register.h
+++ b/src/05computation/include/computation/pass_register.h
@@ -2,6 +2,7 @@
 #define COMPUTATION_PASS_REGISTER_H
 #include "pass/conv_to_matmul.h"
 #include "pass/converter.h"
+#include "pass/gelu_fuse.h"
 #include "pass/layernorm_fuse.h"
 #include "pass/matmul_transpose.h"
@@ -12,6 +13,7 @@ namespace refactor::computation {
     REGISTER(MatMulTransposeFuse, MatMulTransposeFuse)
     REGISTER(ConvToMatmul, ConvToMatmul)
     REGISTER(LayernormFuse, LayernormFuse)
+    REGISTER(GeluFuse, GeluFuse)
 };
diff --git a/src/05computation/src/graph.cc b/src/05computation/src/graph.cc
index caf4761b..0fcd2000 100644
--- a/src/05computation/src/graph.cc
+++ b/src/05computation/src/graph.cc
@@ -221,6 +221,7 @@ namespace refactor::computation {
         auto graphMutant = GraphMutant(*this);
         std::vector passes = {
             "LayernormFuse",
+            "GeluFuse",
             // "MatMulTransposeFuse",
             // "ConvToMatmul",
         };
diff --git a/src/05computation/src/operators/gelu.cc b/src/05computation/src/operators/gelu.cc
new file mode 100644
index 00000000..14483997
--- /dev/null
+++ b/src/05computation/src/operators/gelu.cc
@@ -0,0 +1,16 @@
+#include "computation/operators/gelu.h"
+
+namespace refactor::computation {
+    using Op = Gelu;
+
+    auto Op::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+    auto Op::opTypeId() const noexcept -> size_t { return typeId(); }
+    auto Op::name() const noexcept -> std::string_view { return "Gelu"; }
+    auto Op::serialize() const noexcept -> std::string {
+        return fmt::format("{}()", name());
+    }
+
+}// namespace refactor::computation
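
Note (illustration only, not part of the patch): the subgraph that GeluFuse collapses is the erf-based GELU, 0.5 * x * (1 + erf(x / sqrt(2))). The sketch below, with hypothetical helper names, shows the equivalence the pass relies on when it matches Div -> Erf -> Add -> Mul -> Mul and checks alpha == 0.5 and beta == 1.0:

    # Numerical sanity check of the fusion (illustrative sketch, not project code).
    import math

    def gelu_decomposed(x: float) -> float:
        # Mirrors the ONNX subgraph matched by GeluFuse.
        t = x / math.sqrt(2.0)   # Div by sqrt(2)
        t = math.erf(t)          # Erf
        t = t + 1.0              # Add, constant beta == 1.0
        t = t * 0.5              # Mul, constant alpha == 0.5
        return x * t             # Mul by the original input

    def gelu_fused(x: float) -> float:
        # Closed-form erf GELU, i.e. what the single Gelu node computes.
        return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

    for v in (-2.0, -0.5, 0.0, 0.3, 1.7):
        assert abs(gelu_decomposed(v) - gelu_fused(v)) < 1e-12

Because the pass only rewrites exact matches of this pattern and rejects any other alpha/beta constants, replacing the matched subgraph with one Gelu node is value-preserving.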