From 5ba496fa7b4c322e2eeb309078b4b12850b9df72 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Thu, 27 Oct 2022 17:03:34 -0500 Subject: [PATCH] fix autodiff of remake Fixes https://github.com/SciML/SciMLBase.jl/issues/292 --- src/remake.jl | 19 ++++++++++++------- test/downstream/Project.toml | 2 ++ test/downstream/remake_autodiff.jl | 26 ++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 7 deletions(-) create mode 100644 test/downstream/remake_autodiff.jl diff --git a/src/remake.jl b/src/remake.jl index ea7d9009d..96b970711 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -51,14 +51,19 @@ function remake(prob::ODEProblem; f = missing, if tspan === missing tspan = prob.tspan end - defs = Dict() - if hasproperty(prob.f, :sys) - if hasfield(typeof(prob.f.sys), :ps) - defs = mergedefaults(defs, prob.p, getfield(prob.f.sys, :ps)) - end - if hasfield(typeof(prob.f.sys), :u0) - defs = mergedefaults(defs, prob.u0, getfield(prob.f.sys, :u0)) + + if (p !== missing && eltype(p) <: Pair) || (u0 !== missing && eltype(u0) <: Pair) + defs = Dict{Any,Any}() + if hasproperty(prob.f, :sys) + if hasfield(typeof(prob.f.sys), :ps) + defs = mergedefaults(defs, prob.p, getfield(prob.f.sys, :ps)) + end + if hasfield(typeof(prob.f.sys), :u0) + defs = mergedefaults(defs, prob.u0, getfield(prob.f.sys, :u0)) + end end + else + defs = nothing end if p === missing diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 98d9efee6..0672364bf 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -2,5 +2,7 @@ BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" \ No newline at end of file diff --git a/test/downstream/remake_autodiff.jl b/test/downstream/remake_autodiff.jl new file mode 100644 index 000000000..0e98cad9e --- /dev/null +++ b/test/downstream/remake_autodiff.jl @@ -0,0 +1,26 @@ +using OrdinaryDiffEq, ModelingToolkit, Zygote, SciMLSensitivity + +@variables t +D = Differential(t) +function lotka_volterra(;name=name) + states = @variables x(t)=1.0 y(t)=1.0 + params = @parameters p1=1.5 p2=1.0 p3=3.0 p4=1.0 + eqs = [ + D(x) ~ p1 * x - p2 * x * y, + D(y) ~ -p3 * y + p4 * x * y + ] + return ODESystem(eqs, t, states, params; name = name) +end + +@named lotka_volterra_sys = lotka_volterra() +prob = ODEProblem(lotka_volterra_sys, [], (0.0, 10.0), []) +sol = solve(prob,Tsit5(),reltol=1e-6,abstol=1e-6) + +function sum_of_solution(u0,p) + _prob = remake(prob,u0=u0,p=p) + sum(solve(_prob,Tsit5(),reltol=1e-6,abstol=1e-6,saveat=0.1, sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP()))) +end + +u0 = [1.0 1.0] +p = [1.5 1. 1. 1.] +du01,dp1 = Zygote.gradient(sum_of_solution,u0,p) \ No newline at end of file