Skip to content

Commit

Permalink
feat:支持layernorm融合算子
Browse files Browse the repository at this point in the history
  • Loading branch information
bitzyz committed Mar 22, 2024
1 parent 28f6fb0 commit c237660
Show file tree
Hide file tree
Showing 8 changed files with 339 additions and 7 deletions.
18 changes: 15 additions & 3 deletions scripts/onnx/make_serialize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from refactor_graph.onnx import make_compiler
from onnx import load
import argparse
from onnx.external_data_helper import load_external_data_for_model


def parse_args():
parser = argparse.ArgumentParser(
Expand All @@ -9,17 +11,27 @@ def parse_args():
parser.add_argument(
"--model", type=str, required=True, help="Path to the model file file."
)
parser.add_argument("--output", type=str, default="./", help="Path to save the output file.")
parser.add_argument(
"--output", type=str, default="./", help="Path to save the output file."
)
args = parser.parse_args()
return (
args.model,
args.output,
)


def main():
model_path, output_path = parse_args()
compiler = make_compiler(load(model_path))
model = load(model_path)
# model = load(model_path, load_external_data=False)
# load_external_data_for_model(
# model,
# "/home/zhangyunze/workspace/RefactorGraph/scripts/onnx/bert_bs1.pb",
# )
compiler = make_compiler(model)
compiler.serialize(output_path)


if __name__ == "__main__":
main()
main()
32 changes: 28 additions & 4 deletions scripts/onnx/to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
),
[],
)
if self.type == "Relu":
if self.type in ["Relu", "Tanh"]:
return (
make_node(
self.type,
Expand Down Expand Up @@ -166,6 +166,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
"Log",
"Neg",
"Sigmoid",
"Where",
]:
return (
make_node(
Expand Down Expand Up @@ -235,14 +236,14 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
),
[shape],
)
if self.type in ["Gather", "Concat", "Softmax"]:
if self.type in ["Gather", "Concat", "Softmax", "Split"]:
meta = self.meta.split(b"/")
axis = int(meta[0])
return (
make_node(
self.type,
[tensors[i].name for i in self.topo.inputs],
[tensors[self.topo.outputs[0]].name],
[tensors[i].name for i in self.topo.outputs],
self.name,
domain=DEFAULT_DOMAIN,
axis=axis,
Expand All @@ -251,7 +252,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
)
if self.type == "ReduceMean":
meta = self.meta.split(b",")
keepDims = meta[2] == b"true"
keepDims = meta[2] == b" true"
axes = [int(x) for x in split_array(meta[0])]
return (
make_node(
Expand Down Expand Up @@ -315,6 +316,22 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
),
[],
)
if self.type == "LayerNormalization":
meta = self.meta.split(b",")
epsilon = float(meta[0].split(b"=")[0].strip())
axis = int(meta[1])
return (
make_node(
self.type,
[tensors[i].name for i in self.topo.inputs],
[tensors[i].name for i in self.topo.outputs],
self.name,
domain="refactor",
epsilon=epsilon,
axis=axis,
),
[],
)

raise ValueError(f"Unsupported operator {self.type}")

Expand Down Expand Up @@ -391,6 +408,13 @@ def main():
for t in (tensors[i] for i in graph.outputs)
],
initializer,
value_info=[
make_tensor_value_info(t.name, t.dt, t.shape)
for t in tensors
if t.size == 0
and t.name not in graph.inputs
and t.name not in graph.outputs
],
)
# model = make_model(
# graph, opset_imports=[make_opsetid(domain="", version=13)]
Expand Down
24 changes: 24 additions & 0 deletions src/05computation/include/computation/operators/layernorm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef COMPUTATION_LAYER_NORMALIZATION_H
#define COMPUTATION_LAYER_NORMALIZATION_H

#include "../operator.h"

namespace refactor::computation {

struct LayerNormalization final : public Operator {
float epsilon;
int axis;

constexpr explicit LayerNormalization(float epsilon_, int axis_) noexcept
: Operator(), epsilon(epsilon_), axis(axis_) {}

static size_t typeId() noexcept;
size_t opTypeId() const noexcept final;
std::string_view name() const noexcept final;
// kernel::CollectorBox candidateKernels(Target) const final;
std::string serialize() const noexcept final;
};

}// namespace refactor::computation

#endif// COMPUTATION_LAYER_NORMALIZATION_H
130 changes: 130 additions & 0 deletions src/05computation/include/computation/pass/layernorm_fuse.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#ifndef COMPUTATION_LAYERNORM_FUSE_H
#define COMPUTATION_LAYERNORM_FUSE_H

#include "../graph.h"
#include "computation/operators/layernorm.h"
#include "computation/operators/reduce.h"
#include "computation/operators/simple_binary.h"
#include "computation/operators/simple_unary.h"
#include "computation/pass/converter.h"

namespace refactor::computation {

class LayernormFuse : public Converter {
public:
virtual bool execute(const std::shared_ptr<GraphMutant> &g) const override {
auto nodesList = g->internal().nodes();
size_t count = 0;
for (auto opMatch : nodesList) {
if (opMatch->info().op == nullptr) {
continue;
}
size_t optype = opMatch->info().op->opTypeId();
if (optype != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Add)) {
continue;
}
if (opMatch->successors().size() < 2) {
continue;
}
auto input = opMatch->inputs()[0]->info().tensor;
auto targets = opMatch->outputs()[0]->targets();
auto ReduceMeanOp = *targets.begin();
auto SubOp1 = *(std::next(targets.begin()));
if (ReduceMeanOp == nullptr || SubOp1 == nullptr ||
ReduceMeanOp->info().op->opTypeId() != Reduce::typeId(refactor::kernel::ReduceType::Mean) ||
SubOp1->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Sub)) {
continue;
}
auto reduceOp = dynamic_cast<Reduce *>(ReduceMeanOp->info().op.get());
auto axes = reduceOp->axes;
if (axes.size() != 1) {
continue;
}
auto keepDims = reduceOp->keepDims;
if (ReduceMeanOp->successors().size() != 1 || *(ReduceMeanOp->outputs()[0]->targets().begin()) != SubOp1) {
continue;
}
if (SubOp1->successors().size() != 2) {
continue;
}
auto targets1 = SubOp1->outputs()[0]->targets();
auto PowOp = *targets1.begin();
auto DivOp = *(std::next(targets1.begin()));
if (PowOp == nullptr || DivOp == nullptr ||
PowOp->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Pow) ||
DivOp->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Div)) {
continue;
}
if (PowOp->successors().size() != 1 || DivOp->successors().size() != 1) {
continue;
}
auto pow_value = PowOp->inputs()[1]->info().tensor->data;
if (!pow_value || *pow_value->get<float>() != 2.0f) {
continue;
}
auto ReduceMeanOp1 = *(PowOp->outputs()[0]->targets().begin());
auto MulOp = *(DivOp->outputs()[0]->targets().begin());
if (ReduceMeanOp1 == nullptr || MulOp == nullptr ||
ReduceMeanOp1->info().op->opTypeId() != Reduce::typeId(refactor::kernel::ReduceType::Mean) ||
MulOp->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Mul)) {
continue;
}
auto reduce1Op = dynamic_cast<Reduce *>(ReduceMeanOp1->info().op.get());
auto axes1 = reduce1Op->axes;
if (axes != axes1) {
continue;
}
if (auto keepDims1 = reduce1Op->keepDims; keepDims != keepDims1) {
continue;
}
if (MulOp->successors().size() != 1 || ReduceMeanOp1->successors().size() != 1) {
continue;
}
auto AddOp = *(ReduceMeanOp1->outputs()[0]->targets().begin());
auto AddOp2 = *(MulOp->outputs()[0]->targets().begin());
if (AddOp == nullptr || AddOp2 == nullptr ||
AddOp->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Add) ||
AddOp2->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Add)) {
continue;
}
if (AddOp->successors().size() != 1) {
continue;
}
auto SqrtOp = *(AddOp->outputs()[0]->targets().begin());
if (SqrtOp == nullptr || SqrtOp->info().op->opTypeId() != SimpleUnary::typeId(refactor::kernel::SimpleUnaryType::Sqrt)) {
continue;
}
if (SqrtOp->successors().size() != 1 || *(SqrtOp->outputs()[0]->targets().begin()) != DivOp) {
continue;
}
// start replace with LayernormOp
float epsilon = 0.0;
if (auto t = AddOp->inputs()[1]->info().tensor->data; t) {
epsilon = *t->get<float>();
}
int axis = axes[0];
auto layernormOp = g->internal().pushNode(
{std::make_unique<LayerNormalization>(epsilon, axis), fmt::format("Layernorm", count)},
{g->internal().shareEdge({Tensor::share(input->dataType, input->shape), fmt::format("Layernorm_{}_out", count)})});
layernormOp->connect(0, opMatch->outputs()[0]);
layernormOp->connect(1, MulOp->inputs()[1]);
layernormOp->connect(2, AddOp2->inputs()[1]);
if (AddOp2->outputs()[0]->targets().size() == 0) {//global output
g->internal().replaceOutput(AddOp2->outputs()[0], layernormOp->outputs()[0]);
} else {
for (auto node : AddOp2->outputs()[0]->targets()) {
auto it = std::find(node->inputs().begin(), node->inputs().end(), AddOp2->outputs()[0]);
node->reconnect(node->inputs()[std::distance(node->inputs().begin(), it)], layernormOp->outputs()[0]);
}
}
count++;
g->internal().cleanup();
}
return true;
};
};


}// namespace refactor::computation

#endif// COMPUTATION_LAYERNORM_FUSE_H
2 changes: 2 additions & 0 deletions src/05computation/include/computation/pass_register.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define COMPUTATION_PASS_REGISTER_H
#include "pass/conv_to_matmul.h"
#include "pass/converter.h"
#include "pass/layernorm_fuse.h"
#include "pass/matmul_transpose.h"

namespace refactor::computation {
Expand All @@ -10,6 +11,7 @@ namespace refactor::computation {
#define REGISTER(PASS, NAME) static ConverterRegister<PASS> NAME("" #NAME);
REGISTER(MatMulTransposeFuse, MatMulTransposeFuse)
REGISTER(ConvToMatmul, ConvToMatmul)
REGISTER(LayernormFuse, LayernormFuse)
};


Expand Down
1 change: 1 addition & 0 deletions src/05computation/src/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ namespace refactor::computation {
void Graph::optimize() {
auto graphMutant = GraphMutant(*this);
std::vector<std::string_view> passes = {
"LayernormFuse",
// "MatMulTransposeFuse",
// "ConvToMatmul",
};
Expand Down
22 changes: 22 additions & 0 deletions src/05computation/src/operators/layernorm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "computation/operators/layernorm.h"

namespace refactor::computation {
using Op = LayerNormalization;

auto Op::typeId() noexcept -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}
auto Op::opTypeId() const noexcept -> size_t { return typeId(); }
auto Op::name() const noexcept -> std::string_view { return "LayerNormalization"; }
auto Op::serialize() const noexcept -> std::string {
union code {
float f;
int32_t i;
};
return fmt::format(("{}({:e}={:#010x},{})"),
name(), epsilon,
code{epsilon}.i, axis);
}

}// namespace refactor::computation
Loading

0 comments on commit c237660

Please sign in to comment.