Skip to content

Commit

Permalink
fix autodiff of remake (#293)
Browse files Browse the repository at this point in the history
* fix autodiff of remake

Fixes #292

* format
  • Loading branch information
ChrisRackauckas authored Oct 30, 2022
1 parent 0057b23 commit 82ff444
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 7 deletions.
19 changes: 12 additions & 7 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,19 @@ function remake(prob::ODEProblem; f = missing,
if tspan === missing
tspan = prob.tspan
end
defs = Dict()
if hasproperty(prob.f, :sys)
if hasfield(typeof(prob.f.sys), :ps)
defs = mergedefaults(defs, prob.p, getfield(prob.f.sys, :ps))
end
if hasfield(typeof(prob.f.sys), :u0)
defs = mergedefaults(defs, prob.u0, getfield(prob.f.sys, :u0))

if (p !== missing && eltype(p) <: Pair) || (u0 !== missing && eltype(u0) <: Pair)
defs = Dict{Any, Any}()
if hasproperty(prob.f, :sys)
if hasfield(typeof(prob.f.sys), :ps)
defs = mergedefaults(defs, prob.p, getfield(prob.f.sys, :ps))
end
if hasfield(typeof(prob.f.sys), :u0)
defs = mergedefaults(defs, prob.u0, getfield(prob.f.sys, :u0))
end
end
else
defs = nothing
end

if p === missing
Expand Down
2 changes: 2 additions & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,7 @@
BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
27 changes: 27 additions & 0 deletions test/downstream/remake_autodiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using OrdinaryDiffEq, ModelingToolkit, Zygote, SciMLSensitivity

@variables t
D = Differential(t)
function lotka_volterra(; name = name)
states = @variables x(t)=1.0 y(t)=1.0
params = @parameters p1=1.5 p2=1.0 p3=3.0 p4=1.0
eqs = [
D(x) ~ p1 * x - p2 * x * y,
D(y) ~ -p3 * y + p4 * x * y,
]
return ODESystem(eqs, t, states, params; name = name)
end

@named lotka_volterra_sys = lotka_volterra()
prob = ODEProblem(lotka_volterra_sys, [], (0.0, 10.0), [])
sol = solve(prob, Tsit5(), reltol = 1e-6, abstol = 1e-6)

function sum_of_solution(u0, p)
_prob = remake(prob, u0 = u0, p = p)
sum(solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1,
sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP())))
end

u0 = [1.0 1.0]
p = [1.5 1.0 1.0 1.0]
du01, dp1 = Zygote.gradient(sum_of_solution, u0, p)

0 comments on commit 82ff444

Please sign in to comment.