diff --git a/src/dae_default_alg.jl b/src/dae_default_alg.jl index f73ef986..72b1f65c 100644 --- a/src/dae_default_alg.jl +++ b/src/dae_default_alg.jl @@ -1,6 +1,6 @@ function default_algorithm(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tType, - isinplace}; - kwargs...) where {uType, duType, tType, isinplace} + isinplace}; + kwargs...) where {uType, duType, tType, isinplace} o = Dict{Symbol, Any}(kwargs) extra_kwargs = Any[] alg = IDA() # Standard default diff --git a/src/dde_default_alg.jl b/src/dde_default_alg.jl index 4baed71e..5d0006cc 100644 --- a/src/dde_default_alg.jl +++ b/src/dde_default_alg.jl @@ -1,6 +1,6 @@ function default_algorithm(prob::DiffEqBase.AbstractDDEProblem{uType, tType, lType, - isinplace}; - kwargs...) where {uType, tType, lType, isinplace} + isinplace}; + kwargs...) where {uType, tType, lType, isinplace} o = Dict{Symbol, Any}(kwargs) extra_kwargs = Any[] alg = MethodOfSteps(AutoTsit5(Rosenbrock23(autodiff = false))) # Standard default diff --git a/src/default_solve.jl b/src/default_solve.jl index 3e334b5e..7427eecd 100644 --- a/src/default_solve.jl +++ b/src/default_solve.jl @@ -1,6 +1,6 @@ function DiffEqBase.__solve(prob::DiffEqBase.DEProblem, - alg::Union{Nothing, DiffEqBase.DEAlgorithm}, - args...; default_set = false, kwargs...) + alg::Union{Nothing, DiffEqBase.DEAlgorithm}, + args...; default_set = false, kwargs...) if default_set == true error("The chosen algorithm, $alg, does not exist. Please verify that the appropriate solver package has been installed.") end @@ -15,8 +15,8 @@ function DiffEqBase.__solve(prob::DiffEqBase.DEProblem, end function DiffEqBase.__init(prob::DiffEqBase.DEProblem, - alg::Union{Nothing, DiffEqBase.DEAlgorithm}, - args...; default_set = false, kwargs...) + alg::Union{Nothing, DiffEqBase.DEAlgorithm}, + args...; default_set = false, kwargs...) if default_set == true error("The chosen algorithm, $alg, does not exist. Please verify that the appropriate solver package has been installed.") end diff --git a/src/ode_default_alg.jl b/src/ode_default_alg.jl index 9f1cfbc7..99e0c9bb 100644 --- a/src/ode_default_alg.jl +++ b/src/ode_default_alg.jl @@ -1,5 +1,5 @@ function default_algorithm(prob::DiffEqBase.AbstractODEProblem{uType, tType, inplace}; - kwargs...) where {uType, tType, inplace} + kwargs...) where {uType, tType, inplace} o = Dict{Symbol, Any}(kwargs) extra_kwargs = Any[] alg = AutoTsit5(Rosenbrock23(autodiff = false)) # Standard default @@ -75,8 +75,8 @@ function default_algorithm(prob::DiffEqBase.AbstractODEProblem{uType, tType, inp elseif tol_level == :low_tol if length(prob.u0) > 500 alg = AutoVern7(Rodas4(autodiff = false, - linsolve = LinearSolve.KrylovJL_GMRES()), - lazy = !callbacks) + linsolve = LinearSolve.KrylovJL_GMRES()), + lazy = !callbacks) elseif length(prob.u0) > 50 alg = AutoVern7(TRBDF2(autodiff = false), lazy = !callbacks) else diff --git a/src/sde_default_alg.jl b/src/sde_default_alg.jl index 256db911..025f202e 100644 --- a/src/sde_default_alg.jl +++ b/src/sde_default_alg.jl @@ -1,5 +1,5 @@ function default_algorithm(prob::DiffEqBase.AbstractSDEProblem{uType, tType, isinplace, ND}; - kwargs...) where {uType, tType, isinplace, ND} + kwargs...) where {uType, tType, isinplace, ND} o = Dict{Symbol, Any}(kwargs) extra_kwargs = Any[] alg = SOSRI() # Standard default diff --git a/test/default_ode_alg_test.jl b/test/default_ode_alg_test.jl index 80b92487..dbc7a121 100644 --- a/test/default_ode_alg_test.jl +++ b/test/default_ode_alg_test.jl @@ -3,7 +3,7 @@ using DifferentialEquations, Test f_2dlinear = (du, u, p, t) -> (@. du = p * u) f_2dlinear_analytic = (u0, p, t) -> @. u0 * exp(p * t) prob_ode_2Dlinear = ODEProblem(ODEFunction(f_2dlinear, analytic = f_2dlinear_analytic), - rand(4, 2), (0.0, 1.0), 1.01) + rand(4, 2), (0.0, 1.0), 1.01) alg, kwargs = default_algorithm(prob_ode_2Dlinear; dt = 1 // 2^(4)) integ = init(prob_ode_2Dlinear; dt = 1 // 2^(4)) @@ -36,12 +36,14 @@ sol = solve(prob_ode_2Dlinear; alg_hints = [:stiff], reltol = 1e-1) @test typeof(sol.alg) <: Rosenbrock23 const linear_bigα = parse(BigFloat, "1.01") -f = (du, u, p, t) -> begin for i in 1:length(u) - du[i] = linear_bigα * u[i] -end end +f = (du, u, p, t) -> begin + for i in 1:length(u) + du[i] = linear_bigα * u[i] + end +end (::typeof(f))(::Type{Val{:analytic}}, u0, p, t) = u0 * exp(linear_bigα * t) prob_ode_bigfloat2Dlinear = ODEProblem(f, map(BigFloat, rand(4, 2)) .* ones(4, 2) / 2, - (0.0, 1.0)) + (0.0, 1.0)) sol = solve(prob_ode_bigfloat2Dlinear; dt = 1 // 2^(4)) @test typeof(sol.alg.algs[1]) <: Vern9 @@ -60,12 +62,12 @@ sol = solve(prob_ode_bigfloat2Dlinear, nothing; alg_hints = [:stiff]) struct FooAlg end @test_throws DiffEqBase.NonSolverError solve(prob_ode_bigfloat2Dlinear, FooAlg(); - default_set = true) + default_set = true) struct FooAlg2 <: DiffEqBase.DEAlgorithm end @test_throws DiffEqBase.ProblemSolverPairingError solve(prob_ode_bigfloat2Dlinear, - FooAlg2(); default_set = true) + FooAlg2(); default_set = true) prob = ODEProblem(f, rand(4, 2) .* ones(4, 2) / 2, (0.0, 1.0)) diff --git a/test/runtests.jl b/test/runtests.jl index a3c19f51..3953b8b1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,12 +2,28 @@ using DifferentialEquations, Test, SafeTestsets @time begin - @time @safetestset "Default Discrete Algorithm" begin include("default_discrete_alg_test.jl") end - @time @safetestset "Default ODE Algorithm" begin include("default_ode_alg_test.jl") end - @time @safetestset "Default Steady State Algorithm" begin include("default_steady_state_alg_test.jl") end - @time @safetestset "Default SDE Algorithm" begin include("default_sde_alg_test.jl") end - @time @safetestset "Default RODE Algorithm" begin include("default_rode_alg_test.jl") end - @time @safetestset "Default DDE Algorithm" begin include("default_dde_alg_test.jl") end - @time @safetestset "Default DAE Algorithm" begin include("default_dae_alg_test.jl") end - @time @safetestset "Default BVP Algorithm" begin include("default_bvp_alg_test.jl") end + @time @safetestset "Default Discrete Algorithm" begin + include("default_discrete_alg_test.jl") + end + @time @safetestset "Default ODE Algorithm" begin + include("default_ode_alg_test.jl") + end + @time @safetestset "Default Steady State Algorithm" begin + include("default_steady_state_alg_test.jl") + end + @time @safetestset "Default SDE Algorithm" begin + include("default_sde_alg_test.jl") + end + @time @safetestset "Default RODE Algorithm" begin + include("default_rode_alg_test.jl") + end + @time @safetestset "Default DDE Algorithm" begin + include("default_dde_alg_test.jl") + end + @time @safetestset "Default DAE Algorithm" begin + include("default_dae_alg_test.jl") + end + @time @safetestset "Default BVP Algorithm" begin + include("default_bvp_alg_test.jl") + end end