diff --git a/scripts/onnx/to_onnx.py b/scripts/onnx/to_onnx.py
index b4eb26e3..d48b4c22 100644
--- a/scripts/onnx/to_onnx.py
+++ b/scripts/onnx/to_onnx.py
@@ -1,4 +1,5 @@
 import mmap
+import re
 import argparse
 from onnx import TensorProto, NodeProto, save_model
 from onnx.helper import (
@@ -10,18 +11,23 @@
     make_opsetid,
 )
 from onnx.checker import check_model
+
+
 class Topo:
     def __init__(self, bytes: bytes):
         list = bytes.strip().split(b"<-")
         self.inputs = [int(s.strip(b"%")) for s in list[1].split()]
         self.outputs = [int(s.strip(b"%")) for s in list[0].split()]
+
     def __str__(self) -> str:
         return f"{self.inputs} <- {self.outputs}"
-
+
+
 class Tensor:
     def __init__(self, bytes_: bytes):
         list = bytes_.split(b"\t")
         self.name = str(list[1].strip(), "utf-8")
+
         def map_dt(dt: bytes) -> TensorProto.DataType:
             match dt:
                 case b"F32":
@@ -58,6 +64,7 @@ def map_dt(dt: bytes) -> TensorProto.DataType:
                     return TensorProto.BFLOAT16
                 case _:
                     return TensorProto.UNDEFINED
+
         self.dt = map_dt(list[2].strip())
         layout = list[3].strip()
         if layout != b"NCHW" and layout != b"ELSE":
@@ -66,9 +73,11 @@ def map_dt(dt: bytes) -> TensorProto.DataType:
         self.offset = int(range[0], 0)
         self.size = int(range[1], 0)
         self.shape = [int(s) for s in split_array(list[5])]
+
     def __str__(self) -> str:
         return f"{self.name} (dt = {self.dt}) {self.shape} {self.offset}..{self.offset + self.size}"
-
+
+
 class Operator:
     def __init__(self, bytes: bytes):
         list = bytes.split(b"\t")
@@ -78,9 +87,12 @@ def __init__(self, bytes: bytes):
         list = list[1].rsplit(b")", 1)
         self.meta = list[0].strip()
         self.topo = Topo(list[1])
+
     def __str__(self) -> str:
         return f"{self.type}: {self.name}, meta = {self.meta}, topo = {self.topo}"
+
     def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
+        DEFAULT_DOMAIN = ""
         if self.type == "BatchNormalization":
             return (
                 make_node(
@@ -88,6 +100,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
                     [tensors[i].name for i in self.topo.inputs],
                     [tensors[i].name for i in self.topo.outputs],
                     self.name,
+                    domain=DEFAULT_DOMAIN,
                     epsilon=float(self.meta.split(b"=")[0]),
                 ),
                 [],
             )
@@ -101,6 +114,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
                     [tensors[i].name for i in self.topo.inputs],
                     [tensors[i].name for i in self.topo.outputs],
                     self.name,
+                    domain=DEFAULT_DOMAIN,
                     dilations=meta[0:rank],
                     strides=meta[rank : 2 * rank],
                     pads=meta[2 * rank : 4 * rank],
@@ -114,6 +128,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
                     [tensors[i].name for i in self.topo.inputs],
                     [tensors[i].name for i in self.topo.outputs],
                     self.name,
+                    domain=DEFAULT_DOMAIN,
                 ),
                 [],
             )
@@ -131,6 +146,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
                     [tensors[i].name for i in self.topo.inputs],
                     [tensors[i].name for i in self.topo.outputs],
                     self.name,
+                    domain=DEFAULT_DOMAIN,
                     ceil_mode=ceil_mode,
                     kernel_shape=kernel_shape,
                     dilations=meta[0:rank],
@@ -139,13 +155,25 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
                     strides=meta[rank : 2 * rank],
                     pads=meta[2 * rank : 4 * rank],
                 ),
                 [],
             )
-        if self.type == "Add":
+        if self.type in [
+            "Add",
+            "Pow",
+            "Sqrt",
+            "Div",
+            "Mul",
+            "Sub",
+            "Exp",
+            "Log",
+            "Neg",
+            "Sigmoid",
+        ]:
             return (
                 make_node(
                     self.type,
                     [tensors[i].name for i in self.topo.inputs],
                     [tensors[i].name for i in self.topo.outputs],
                     self.name,
+                    domain=DEFAULT_DOMAIN,
                 ),
                 [],
             )
@@ -156,6 +184,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
                     [tensors[i].name for i in self.topo.inputs],
                     [tensors[i].name for i in self.topo.outputs],
                     self.name,
+                    domain=DEFAULT_DOMAIN,
                 ),
                 [],
             )
@@ -172,6 +201,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
                     [tensors[i].name for i in self.topo.inputs],
                     [tensors[i].name for i in self.topo.outputs],
                     self.name,
+                    domain=DEFAULT_DOMAIN,
                     alpha=alpha,
                     beta=beta,
                     transA=transA,
@@ -186,6 +216,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
                     [tensors[i].name for i in self.topo.inputs],
                     [tensors[i].name for i in self.topo.outputs],
                     self.name,
+                    domain=DEFAULT_DOMAIN,
                 ),
                 [],
             )
@@ -200,11 +231,94 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
                     [tensors[self.topo.inputs[0]].name, shape_name],
                     [tensors[i].name for i in self.topo.outputs],
                     self.name,
+                    domain=DEFAULT_DOMAIN,
                 ),
                 [shape],
             )
+        if self.type in ["Gather", "Concat", "Softmax"]:
+            meta = self.meta.split(b"/")
+            axis = int(meta[0])
+            return (
+                make_node(
+                    self.type,
+                    [tensors[i].name for i in self.topo.inputs],
+                    [tensors[self.topo.outputs[0]].name],
+                    self.name,
+                    domain=DEFAULT_DOMAIN,
+                    axis=axis,
+                ),
+                [],
+            )
+        if self.type == "ReduceMean":
+            meta = self.meta.split(b",")
+            keepDims = meta[2] == b"true"
+            axes = [int(x) for x in split_array(meta[0])]
+            return (
+                make_node(
+                    self.type,
+                    [tensors[self.topo.inputs[0]].name],
+                    [tensors[self.topo.outputs[0]].name],
+                    self.name,
+                    domain=DEFAULT_DOMAIN,
+                    axes=axes,
+                    keepdims=keepDims,
+                ),
+                [],
+            )
+        if self.type == "Transpose":
+            meta = [int(x) for x in split_array(self.meta)]
+            return (
+                make_node(
+                    self.type,
+                    [tensors[self.topo.inputs[0]].name],
+                    [tensors[self.topo.outputs[0]].name],
+                    self.name,
+                    domain=DEFAULT_DOMAIN,
+                    perm=meta,
+                ),
+                [],
+            )
+        if self.type == "Slice":
+            # starts, ends, axes, steps = split_array_slice(self.meta)
+            return (
+                make_node(
+                    self.type,
+                    [tensors[i].name for i in self.topo.inputs],
+                    [tensors[self.topo.outputs[0]].name],
+                    self.name,
+                    domain=DEFAULT_DOMAIN,
+                ),
+                [],
+            )
+        if self.type == "Cast":
+            to = int(tensors[self.topo.outputs[0]].dt)
+            return (
+                make_node(
+                    self.type,
+                    [tensors[self.topo.inputs[0]].name],
+                    [tensors[self.topo.outputs[0]].name],
+                    self.name,
+                    domain=DEFAULT_DOMAIN,
+                    to=to,
+                ),
+                [],
+            )
+        if self.type == "RmsNormalization":
+            return (
+                make_node(
+                    self.type,
+                    [tensors[i].name for i in self.topo.inputs],
+                    [tensors[i].name for i in self.topo.outputs],
+                    self.name,
+                    domain="refactor",
+                    epsilon=1e-5,
+                ),
+                [],
+            )
+        raise ValueError(f"Unsupported operator {self.type}")
+


 def parse_args():
     parser = argparse.ArgumentParser(description="Analysis serialize file.")
     parser.add_argument(
@@ -214,13 +328,23 @@ def parse_args():
         help="Path to save the serialize output files.",
     )
     args = parser.parse_args()
-    return (
-        args.input
-    )
+    return args.input
+

 def split_array(arr: bytes):
     return (x for x in arr.strip().strip(b"[").strip(b"]").split())

+
+def split_array_slice(arr: bytes):
+    meta_array = split_array(arr)
+    meta = [list(map(int, re.findall(rb"\d+", x))) for x in meta_array]
+    starts = [int(x[0]) for x in meta]
+    ends = [int(x[0] + x[1] * x[2]) for x in meta]
+    axes = [x for x in range(len(meta))]
+    steps = [int(x[2]) for x in meta]
+    return starts, ends, axes, steps
+
+
 def main():
     path = parse_args()
     info_path = path + "/graph.info"
@@ -268,10 +392,24 @@ def main():
         ],
         initializer,
     )
-    model = make_model(graph, opset_imports=[make_opsetid(
-        domain="", version=13)])
-    check_model(model)
-    save_model(model, outputfile)
+    # model = make_model(
+    #     graph, opset_imports=[make_opsetid(domain="", version=13)]
+    # )
+    model = make_model(
+        graph,
+        opset_imports=[
+            make_opsetid(domain="refactor", version=1),
+            make_opsetid(domain="", version=13),
+        ],
+    )
+    save_model(
+        model,
+        outputfile,
+        save_as_external_data=True,
+        all_tensors_to_one_file=True,
+    )
+    check_model(outputfile)
+

 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
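Note on the new split_array_slice() helper above (currently only referenced from the commented-out line in the Slice branch): it assumes each dimension of the slice meta is serialized as a start/length/step triple and rebuilds the starts/ends/axes/steps lists an ONNX Slice node expects, with end = start + length * step. The bytes pattern rb"\d+" is used because split_array() yields bytes tokens. A minimal standalone sketch of that conversion, using a made-up meta string (the exact serialization format is an assumption here):

    import re

    def split_array_slice(arr: bytes):
        # Tokens are assumed to look like b"start/len/step"; negative values
        # and reverse steps are not handled in this sketch.
        tokens = arr.strip().strip(b"[").strip(b"]").split()
        meta = [list(map(int, re.findall(rb"\d+", t))) for t in tokens]
        starts = [m[0] for m in meta]
        ends = [m[0] + m[1] * m[2] for m in meta]  # end = start + length * step
        axes = list(range(len(meta)))
        steps = [m[2] for m in meta]
        return starts, ends, axes, steps

    # Hypothetical meta: rows 1..4 (step 1), columns 0..8 (step 2).
    print(split_array_slice(b"[1/3/1 0/4/2]"))  # ([1, 0], [4, 8], [0, 1], [1, 2])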
diff --git a/src/05computation/src/graph.cc b/src/05computation/src/graph.cc
index 92ec95b7..28295719 100644
--- a/src/05computation/src/graph.cc
+++ b/src/05computation/src/graph.cc
@@ -220,8 +220,8 @@ namespace refactor::computation {
     void Graph::optimize() {
         auto graphMutant = GraphMutant(*this);
         std::vector passes = {
-            "MatMulTransposeFuse",
-            "ConvToMatmul",
+            // "MatMulTransposeFuse",
+            // "ConvToMatmul",
         };
         register_();//all pass insert
         auto g = std::make_shared<GraphMutant>(graphMutant);
diff --git a/src/05computation/src/operators/reduce.cc b/src/05computation/src/operators/reduce.cc
index a3edcd21..f816cdd4 100644
--- a/src/05computation/src/operators/reduce.cc
+++ b/src/05computation/src/operators/reduce.cc
@@ -116,7 +116,7 @@ namespace refactor::computation {
         return std::make_unique(target, type, axes);
     }
     auto Op::serialize() const noexcept -> std::string {
-        return fmt::format("{}({}/{}, {})",
+        return fmt::format("{}({}, {}, {})",
                            name(),
                            vec2str(axes),
                            rank,
diff --git a/src/07onnx/src/operators/cast.cc b/src/07onnx/src/operators/cast.cc
index ac96c04d..dc6e9b0a 100644
--- a/src/07onnx/src/operators/cast.cc
+++ b/src/07onnx/src/operators/cast.cc
@@ -1,6 +1,7 @@
 #include "computation/operators/cast.h"
 #include "cast.hh"
 #include "common.h"
+#include "computation/operators/identity.h"
 #include

 namespace refactor::onnx {
@@ -30,7 +31,6 @@ namespace refactor::onnx {

     auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult {
         EXPECT_SIZE(1)
-
         auto const &input = inputs[0];
         auto ans = Tensor::share(to, input.shape, extractDependency(inputs));
         if (!options.shouldCalculate(inputs, {*ans})) {
@@ -116,8 +116,13 @@ namespace refactor::onnx {
         }
         return Ok(Tensors{std::move(ans)});
     }

-    auto Op::lower(TensorRefs) const -> computation::OpBox {
+    auto Op::lower(TensorRefs inputs) const -> computation::OpBox {
         using Op_ = computation::Cast;
+        auto const &input = inputs[0];
+        auto from = input.dataType;
+        if (from == to) {
+            return std::make_unique<computation::Identity>();
+        }
         return std::make_unique<Op_>();
     }
diff --git a/src/07onnx/src/operators/mat_mul.cc b/src/07onnx/src/operators/mat_mul.cc
index 7eb26376..9a4be75d 100644
--- a/src/07onnx/src/operators/mat_mul.cc
+++ b/src/07onnx/src/operators/mat_mul.cc
@@ -63,7 +63,7 @@ namespace refactor::onnx {

     auto Op::lower(TensorRefs) const -> computation::OpBox {
         using Op_ = computation::MatMul;
-        return std::make_unique<Op_>(1.0, 1.0, false, false);
+        return std::make_unique<Op_>(1.0, 0.0, false, false);
     }

 }// namespace refactor::onnx
diff --git a/src/09python_ffi/src/compiler.cc b/src/09python_ffi/src/compiler.cc
index bf04053e..f6d2e9da 100644
--- a/src/09python_ffi/src/compiler.cc
+++ b/src/09python_ffi/src/compiler.cc
@@ -155,7 +155,7 @@ namespace refactor::python_ffi {
             msg += ']';
             RUNTIME_ERROR(std::move(msg));
         }
-        _g.fillEdgeInfo(false);
+        _g.fillEdgeInfo(true);

         namespace fs = std::filesystem;
         auto path = fs::path(std::move(path_));
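For reference, a minimal sketch of how the artifact written by the new save_model() call in to_onnx.py is expected to be consumed: the weights now live in an external-data file next to the .onnx file, and the model imports both the default ONNX opset (version 13) and the custom "refactor" domain (version 1) used by RmsNormalization. The path below is hypothetical:

    import onnx
    from onnx.checker import check_model

    path = "exported/model.onnx"  # hypothetical output of to_onnx.py

    # Passing the path (rather than a ModelProto) lets the checker locate the
    # external tensor data written next to the model file by
    # save_model(..., save_as_external_data=True, all_tensors_to_one_file=True).
    check_model(path)

    model = onnx.load(path)  # external data is resolved relative to the file
    print([(o.domain, o.version) for o in model.opset_import])
    # expected: [('refactor', 1), ('', 13)]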