Skip to content

Commit

Permalink
feat:support opt layernorm
Browse files Browse the repository at this point in the history
  • Loading branch information
bitzyz committed Apr 26, 2024
1 parent 7f5a617 commit f51335d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
4 changes: 2 additions & 2 deletions scripts/onnx/to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
),
[],
)
if self.type in ["Relu", "Tanh"]:
if self.type in ["Relu", "Tanh", "Erf", "Max"]:
return (
make_node(
self.type,
Expand Down Expand Up @@ -448,7 +448,7 @@ def main():
save_model(
model,
outputfile,
save_as_external_data=True,
# save_as_external_data=True,
all_tensors_to_one_file=True,
)
check_model(outputfile)
Expand Down
5 changes: 3 additions & 2 deletions src/05computation/include/computation/pass/layernorm_fuse.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "../graph.h"
#include "computation/operators/layernorm.h"
#include "computation/operators/reduce.h"
#include "computation/operators/reshape.h"
#include "computation/operators/simple_binary.h"
#include "computation/operators/simple_unary.h"
#include "computation/pass/converter.h"
Expand All @@ -20,13 +21,13 @@ namespace refactor::computation {
continue;
}
size_t optype = opMatch->info().op->opTypeId();
if (optype != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Add)) {
if (optype != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Add) && optype != Reshape::typeId()) {
continue;
}
if (opMatch->successors().size() < 2) {
continue;
}
auto input = opMatch->inputs()[0]->info().tensor;
auto input = opMatch->outputs()[0]->info().tensor;
auto targets = opMatch->outputs()[0]->targets();
auto ReduceMeanOp = *targets.begin();
auto SubOp1 = *(std::next(targets.begin()));
Expand Down

0 comments on commit f51335d

Please sign in to comment.