From 7cfb9f6b82d9e33277caf7b93dbbe0d3add6e0fa Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 4 Mar 2024 21:23:40 -0600 Subject: [PATCH] add complex sqrt --- src/compiler.jl | 1 + src/compiler/interpreter.jl | 2 +- test/runtests.jl | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 03a413c8805..4a71399d7c4 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -70,6 +70,7 @@ include("compiler/utils.jl") const cmplx_known_ops = Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( typeof(Base.inv) => (:cmplx_inv, 1, nothing), + typeof(Base.sqrt) => (:cmplx_sqrt, 1, nothing), ) const known_ops = Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index a2900b3356f..5885679be57 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -94,7 +94,7 @@ function is_primitive_func(@nospecialize(TT)) end end - if ft == typeof(Base.inv) + if ft == typeof(Base.inv) || ft == typeof(Base.sqrt) if TT <: Tuple{ft, Complex{Float32}} || TT <: Tuple{ft, Complex{Float64}} return true end diff --git a/test/runtests.jl b/test/runtests.jl index 99aa3b208fe..f6100bab815 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -277,6 +277,7 @@ make3() = (1.0, 2.0, 3.0) test_scalar(x->rem(x, 1), 0.7) test_scalar(x->rem2pi(x,RoundDown), 0.7) test_scalar(x->fma(x,x+1,x/3), 2.3) + test_scalar(sqrt, 1.7+2.1im) @test autodiff(Forward, sincos, Duplicated(1.0, 1.0))[1][1] ≈ cos(1.0)