diff --git a/src/load.jl b/src/load.jl index 3985972..71ed41b 100644 --- a/src/load.jl +++ b/src/load.jl @@ -51,6 +51,10 @@ function load_node!(tape::Tape, ::OpConfig{:ONNX, :Sin}, args::VarVec, attrs::At return push_call!(tape, _sin, args[1]) end +function load_node!(tape::Tape, ::OpConfig{:ONNX, :Cos}, args::VarVec, attrs::AttrDict) + return push_call!(tape, _cos, args[1]) +end + function load_node!(tape::Tape, nd::NodeProto, backend::Symbol) args = [tape.c.name2var[name] for name in nd.input] attrs = convert(Dict{Symbol, Any}, Dict(nd.attribute)) diff --git a/src/ops.jl b/src/ops.jl index 67403e9..a45b1fb 100644 --- a/src/ops.jl +++ b/src/ops.jl @@ -48,6 +48,7 @@ end add(xs...) = .+(xs...) sub(xs...) = .-(xs...) _sin(x) = sin.(x) +_cos(x) = cos.(x) mul(xs...) = .*(xs...) relu(x) = NNlib.relu.(x) leakyrelu(x;a = 0.01) = NNlib.leakyrelu.(x,a) diff --git a/src/save.jl b/src/save.jl index 5205da4..9228738 100644 --- a/src/save.jl +++ b/src/save.jl @@ -116,6 +116,11 @@ function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_sin)}, op::Umlaut.C push!(g.node, nd) end +function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_cos)}, op::Umlaut.Call) + nd = NodeProto("Cos", op) + push!(g.node, nd) +end + function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(*)}, op::Umlaut.Call) nd = NodeProto( input=[onnx_name(v) for v in reverse(op.args)], diff --git a/test/saveload.jl b/test/saveload.jl index 3d53276..1278c90 100644 --- a/test/saveload.jl +++ b/test/saveload.jl @@ -25,6 +25,12 @@ import ONNX: NodeProto, ValueInfoProto, AttributeProto, onnx_name ort_test(ONNX._sin, A) end + @testset "Cos" begin + # ONNXRunTime has no implementation for Cos(x::Float64), using Float32 + A = rand(Float32, 3, 4) + ort_test(ONNX._cos, A) + end + @testset "Gemm" begin A, B, C = (rand(3, 4), rand(3, 4), rand(3, 3)) ort_test(ONNX.onnx_gemm, A, B')