Skip to content

Commit

Permalink
fix: 修改to_onnx脚本,能够正确导出llama模型 (fix: update the to_onnx script so it correctly exports the Llama model)
Browse files Browse the repository at this point in the history
  • Loading branch information
bitzyz committed Mar 18, 2024
1 parent bf5402e commit 28f6fb0
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 18 deletions.
160 changes: 149 additions & 11 deletions scripts/onnx/to_onnx.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import mmap
import re
import argparse
from onnx import TensorProto, NodeProto, save_model
from onnx.helper import (
Expand All @@ -10,18 +11,23 @@
make_opsetid,
)
from onnx.checker import check_model


class Topo:
    """Topology of one graph node, parsed from the serialized form
    ``%out0 %out1 <- %in0 %in1`` — output tensor ids on the LEFT of
    ``<-``, input tensor ids on the right, each id written as ``%N``.
    """

    def __init__(self, bytes: bytes):
        # parts[0] is the output side, parts[1] the input side.
        parts = bytes.strip().split(b"<-")
        self.inputs = [int(s.strip(b"%")) for s in parts[1].split()]
        self.outputs = [int(s.strip(b"%")) for s in parts[0].split()]

    def __str__(self) -> str:
        # Mirror the serialized layout (outputs <- inputs); the previous
        # version printed the two sides swapped relative to the input format.
        return f"{self.outputs} <- {self.inputs}"



class Tensor:
def __init__(self, bytes_: bytes):
list = bytes_.split(b"\t")
self.name = str(list[1].strip(), "utf-8")

def map_dt(dt: bytes) -> TensorProto.DataType:
match dt:
case b"F32":
Expand Down Expand Up @@ -58,6 +64,7 @@ def map_dt(dt: bytes) -> TensorProto.DataType:
return TensorProto.BFLOAT16
case _:
return TensorProto.UNDEFINED

self.dt = map_dt(list[2].strip())
layout = list[3].strip()
if layout != b"NCHW" and layout != b"ELSE":
Expand All @@ -66,9 +73,11 @@ def map_dt(dt: bytes) -> TensorProto.DataType:
self.offset = int(range[0], 0)
self.size = int(range[1], 0)
self.shape = [int(s) for s in split_array(list[5])]

def __str__(self) -> str:
return f"{self.name} (dt = {self.dt}) {self.shape} {self.offset}..{self.offset + self.size}"



class Operator:
def __init__(self, bytes: bytes):
list = bytes.split(b"\t")
Expand All @@ -78,16 +87,20 @@ def __init__(self, bytes: bytes):
list = list[1].rsplit(b")", 1)
self.meta = list[0].strip()
self.topo = Topo(list[1])

def __str__(self) -> str:
return f"{self.type}: {self.name}, meta = {self.meta}, topo = {self.topo}"

def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
DEFAULT_DOMAIN = ""
if self.type == "BatchNormalization":
return (
make_node(
self.type,
[tensors[i].name for i in self.topo.inputs],
[tensors[i].name for i in self.topo.outputs],
self.name,
domain=DEFAULT_DOMAIN,
epsilon=float(self.meta.split(b"=")[0]),
),
[],
Expand All @@ -101,6 +114,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
[tensors[i].name for i in self.topo.inputs],
[tensors[i].name for i in self.topo.outputs],
self.name,
domain=DEFAULT_DOMAIN,
dilations=meta[0:rank],
strides=meta[rank : 2 * rank],
pads=meta[2 * rank : 4 * rank],
Expand All @@ -114,6 +128,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
[tensors[i].name for i in self.topo.inputs],
[tensors[i].name for i in self.topo.outputs],
self.name,
domain=DEFAULT_DOMAIN,
),
[],
)
Expand All @@ -131,6 +146,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
[tensors[i].name for i in self.topo.inputs],
[tensors[i].name for i in self.topo.outputs],
self.name,
domain=DEFAULT_DOMAIN,
ceil_mode=ceil_mode,
kernel_shape=kernel_shape,
dilations=meta[0:rank],
Expand All @@ -139,13 +155,25 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
),
[],
)
if self.type == "Add":
if self.type in [
"Add",
"Pow",
"Sqrt",
"Div",
"Mul",
"Sub",
"Exp",
"Log",
"Neg",
"Sigmoid",
]:
return (
make_node(
self.type,
[tensors[i].name for i in self.topo.inputs],
[tensors[i].name for i in self.topo.outputs],
self.name,
domain=DEFAULT_DOMAIN,
),
[],
)
Expand All @@ -156,6 +184,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
[tensors[i].name for i in self.topo.inputs],
[tensors[i].name for i in self.topo.outputs],
self.name,
domain=DEFAULT_DOMAIN,
),
[],
)
Expand All @@ -172,6 +201,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
[tensors[i].name for i in self.topo.inputs],
[tensors[i].name for i in self.topo.outputs],
self.name,
domain=DEFAULT_DOMAIN,
alpha=alpha,
beta=beta,
transA=transA,
Expand All @@ -186,6 +216,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
[tensors[i].name for i in self.topo.inputs],
[tensors[i].name for i in self.topo.outputs],
self.name,
domain=DEFAULT_DOMAIN,
),
[],
)
Expand All @@ -200,11 +231,94 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
[tensors[self.topo.inputs[0]].name, shape_name],
[tensors[i].name for i in self.topo.outputs],
self.name,
domain=DEFAULT_DOMAIN,
),
[shape],
)
if self.type in ["Gather", "Concat", "Softmax"]:
meta = self.meta.split(b"/")
axis = int(meta[0])
return (
make_node(
self.type,
[tensors[i].name for i in self.topo.inputs],
[tensors[self.topo.outputs[0]].name],
self.name,
domain=DEFAULT_DOMAIN,
axis=axis,
),
[],
)
if self.type == "ReduceMean":
meta = self.meta.split(b",")
keepDims = meta[2] == b"true"
axes = [int(x) for x in split_array(meta[0])]
return (
make_node(
self.type,
[tensors[self.topo.inputs[0]].name],
[tensors[self.topo.outputs[0]].name],
self.name,
domain=DEFAULT_DOMAIN,
axes=axes,
keepdims=keepDims,
),
[],
)
if self.type == "Transpose":
meta = [int(x) for x in split_array(self.meta)]
return (
make_node(
self.type,
[tensors[self.topo.inputs[0]].name],
[tensors[self.topo.outputs[0]].name],
self.name,
domain=DEFAULT_DOMAIN,
perm=meta,
),
[],
)
if self.type == "Slice":
# starts, ends, axes, steps = split_array_slice(self.meta)
return (
make_node(
self.type,
[tensors[i].name for i in self.topo.inputs],
[tensors[self.topo.outputs[0]].name],
self.name,
domain=DEFAULT_DOMAIN,
),
[],
)
if self.type == "Cast":
to = int(tensors[self.topo.outputs[0]].dt)
return (
make_node(
self.type,
[tensors[self.topo.inputs[0]].name],
[tensors[self.topo.outputs[0]].name],
self.name,
domain=DEFAULT_DOMAIN,
to=to,
),
[],
)
if self.type == "RmsNormalization":
return (
make_node(
self.type,
[tensors[i].name for i in self.topo.inputs],
[tensors[i].name for i in self.topo.outputs],
self.name,
domain="refactor",
epsilon=1e-5,
),
[],
)

raise ValueError(f"Unsupported operator {self.type}")


def parse_args():
parser = argparse.ArgumentParser(description="Analysis serialize file.")
parser.add_argument(
Expand All @@ -214,13 +328,23 @@ def parse_args():
help="Path to save the serialize output files.",
)
args = parser.parse_args()
return (
args.input
)
return args.input


def split_array(arr: bytes):
    """Yield the whitespace-separated tokens (as bytes) of a serialized
    array such as ``b"[1 2 3]"``, with the surrounding brackets removed.
    """
    trimmed = arr.strip().strip(b"[").strip(b"]")
    yield from trimmed.split()


def split_array_slice(arr: bytes):
    """Decode serialized Slice metadata into ONNX-style per-axis lists.

    Each token produced by ``split_array`` is expected to carry three
    integers (start, length, step — TODO confirm against the serializer),
    from which we derive ``ends = start + length * step``.

    Returns a tuple ``(starts, ends, axes, steps)`` with one entry per axis.
    """
    # split_array yields bytes tokens, so the regex pattern must be a bytes
    # pattern (rb"..."); a str pattern on bytes raises TypeError in Python 3.
    # NOTE(review): \d+ ignores any sign, so negative values would be
    # misread — assumes all serialized slice numbers are non-negative.
    meta = [[int(n) for n in re.findall(rb"\d+", token)] for token in split_array(arr)]
    starts = [x[0] for x in meta]
    ends = [x[0] + x[1] * x[2] for x in meta]
    axes = list(range(len(meta)))
    steps = [x[2] for x in meta]
    return starts, ends, axes, steps


def main():
path = parse_args()
info_path = path + "/graph.info"
Expand Down Expand Up @@ -268,10 +392,24 @@ def main():
],
initializer,
)
model = make_model(graph, opset_imports=[make_opsetid(
domain="", version=13)])
check_model(model)
save_model(model, outputfile)
# model = make_model(
# graph, opset_imports=[make_opsetid(domain="", version=13)]
# )
model = make_model(
graph,
opset_imports=[
make_opsetid(domain="refactor", version=1),
make_opsetid(domain="", version=13),
],
)
save_model(
model,
outputfile,
save_as_external_data=True,
all_tensors_to_one_file=True,
)
check_model(outputfile)


if __name__ == "__main__":
main()
main()
4 changes: 2 additions & 2 deletions src/05computation/src/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ namespace refactor::computation {
void Graph::optimize() {
auto graphMutant = GraphMutant(*this);
std::vector<std::string_view> passes = {
"MatMulTransposeFuse",
"ConvToMatmul",
// "MatMulTransposeFuse",
// "ConvToMatmul",
};
register_();//all pass insert
auto g = std::make_shared<GraphMutant>(graphMutant);
Expand Down
2 changes: 1 addition & 1 deletion src/05computation/src/operators/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ namespace refactor::computation {
return std::make_unique<kernel::ReduceCollector>(target, type, axes);
}
auto Op::serialize() const noexcept -> std::string {
return fmt::format("{}({}/{}, {})",
return fmt::format("{}({}, {}, {})",
name(),
vec2str(axes),
rank,
Expand Down
9 changes: 7 additions & 2 deletions src/07onnx/src/operators/cast.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "computation/operators/cast.h"
#include "cast.hh"
#include "common.h"
#include "computation/operators/identity.h"
#include <execution>

namespace refactor::onnx {
Expand Down Expand Up @@ -30,7 +31,6 @@ namespace refactor::onnx {

auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult {
EXPECT_SIZE(1)

auto const &input = inputs[0];
auto ans = Tensor::share(to, input.shape, extractDependency(inputs));
if (!options.shouldCalculate(inputs, {*ans})) {
Expand Down Expand Up @@ -116,8 +116,13 @@ namespace refactor::onnx {
}
return Ok(Tensors{std::move(ans)});
}
auto Op::lower(TensorRefs) const -> computation::OpBox {
auto Op::lower(TensorRefs inputs) const -> computation::OpBox {
using Op_ = computation::Cast;
auto const &input = inputs[0];
auto from = input.dataType;
if (from == to) {
return std::make_unique<computation::Identity>();
}
return std::make_unique<Op_>();
}

Expand Down
2 changes: 1 addition & 1 deletion src/07onnx/src/operators/mat_mul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ namespace refactor::onnx {

auto Op::lower(TensorRefs) const -> computation::OpBox {
using Op_ = computation::MatMul;
return std::make_unique<Op_>(1.0, 1.0, false, false);
return std::make_unique<Op_>(1.0, 0.0, false, false);
}

}// namespace refactor::onnx
2 changes: 1 addition & 1 deletion src/09python_ffi/src/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ namespace refactor::python_ffi {
msg += ']';
RUNTIME_ERROR(std::move(msg));
}
_g.fillEdgeInfo(false);
_g.fillEdgeInfo(true);

namespace fs = std::filesystem;
auto path = fs::path(std::move(path_));
Expand Down

0 comments on commit 28f6fb0

Please sign in to comment.