diff --git a/src/load.jl b/src/load.jl index 71ed41b..c0b401e 100644 --- a/src/load.jl +++ b/src/load.jl @@ -55,6 +55,11 @@ function load_node!(tape::Tape, ::OpConfig{:ONNX, :Cos}, args::VarVec, attrs::At return push_call!(tape, _cos, args[1]) end +function load_node!(tape::Tape, ::OpConfig{:ONNX, :Abs}, args::VarVec, attrs::AttrDict) + return push_call!(tape, _abs, args[1]) +end + + function load_node!(tape::Tape, nd::NodeProto, backend::Symbol) args = [tape.c.name2var[name] for name in nd.input] attrs = convert(Dict{Symbol, Any}, Dict(nd.attribute)) diff --git a/src/ops.jl b/src/ops.jl index a45b1fb..7ee8620 100644 --- a/src/ops.jl +++ b/src/ops.jl @@ -49,6 +49,7 @@ add(xs...) = .+(xs...) sub(xs...) = .-(xs...) _sin(x) = sin.(x) _cos(x) = cos.(x) +_abs(x) = abs.(x) mul(xs...) = .*(xs...) relu(x) = NNlib.relu.(x) leakyrelu(x;a = 0.01) = NNlib.leakyrelu.(x,a) diff --git a/src/save.jl b/src/save.jl index 9228738..a52342d 100644 --- a/src/save.jl +++ b/src/save.jl @@ -121,6 +121,11 @@ function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_cos)}, op::Umlaut.C push!(g.node, nd) end +function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_abs)}, op::Umlaut.Call) + nd = NodeProto("Abs", op) + push!(g.node, nd) +end + function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(*)}, op::Umlaut.Call) nd = NodeProto( input=[onnx_name(v) for v in reverse(op.args)], diff --git a/test/saveload.jl b/test/saveload.jl index 1278c90..82422c3 100644 --- a/test/saveload.jl +++ b/test/saveload.jl @@ -31,6 +31,11 @@ import ONNX: NodeProto, ValueInfoProto, AttributeProto, onnx_name ort_test(ONNX._cos, A) end + @testset "Abs" begin + A = rand(3, 4) + ort_test(ONNX._abs, A) + end + @testset "Gemm" begin A, B, C = (rand(3, 4), rand(3, 4), rand(3, 3)) ort_test(ONNX.onnx_gemm, A, B')