Skip to content

Commit

Permalink
Merge pull request #917 from AayushSabharwal/as/initdata-promote
Browse files Browse the repository at this point in the history
fix: call `remake_initialization_data` when explicit `f` provided to `remake`
  • Loading branch information
ChrisRackauckas authored Jan 22, 2025
2 parents 957dd58 + 204987b commit 2eded1b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 6 deletions.
18 changes: 12 additions & 6 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ function remake(prob::ODEProblem; f = missing,

if build_initializeprob
if f !== missing && has_initialization_data(f)
initialization_data = f.initialization_data
initialization_data = remake_initialization_data(
prob.f.sys, f, u0, tspan[1], p, newu0, newp)
else
initialization_data = remake_initialization_data(
prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp)
Expand Down Expand Up @@ -413,7 +414,8 @@ function remake(prob::SDEProblem;

if build_initializeprob
if f !== missing && has_initialization_data(f)
initialization_data = f.initialization_data
initialization_data = remake_initialization_data(
prob.f.sys, f, u0, tspan[1], p, newu0, newp)
else
initialization_data = remake_initialization_data(
prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp)
Expand Down Expand Up @@ -481,7 +483,8 @@ function remake(prob::DDEProblem; f = missing, h = missing, u0 = missing,

if build_initializeprob
if f !== missing && has_initialization_data(f)
initialization_data = f.initialization_data
initialization_data = remake_initialization_data(
prob.f.sys, f, u0, tspan[1], p, newu0, newp)
else
initialization_data = remake_initialization_data(
prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp)
Expand Down Expand Up @@ -561,7 +564,8 @@ function remake(prob::SDDEProblem;

if build_initializeprob
if f !== missing && has_initialization_data(f)
initialization_data = f.initialization_data
initialization_data = remake_initialization_data(
prob.f.sys, f, u0, tspan[1], p, newu0, newp)
else
initialization_data = remake_initialization_data(
prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp)
Expand Down Expand Up @@ -711,7 +715,8 @@ function remake(prob::NonlinearProblem;

if build_initializeprob
if f !== missing && has_initialization_data(f)
initialization_data = f.initialization_data
initialization_data = remake_initialization_data(
prob.f.sys, f, u0, nothing, p, newu0, newp)
else
initialization_data = remake_initialization_data(
prob.f.sys, prob.f, u0, nothing, p, newu0, newp)
Expand Down Expand Up @@ -765,7 +770,8 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p

if build_initializeprob
if f !== missing && has_initialization_data(f)
initialization_data = f.initialization_data
initialization_data = remake_initialization_data(
prob.f.sys, f, u0, nothing, p, newu0, newp)
else
initialization_data = remake_initialization_data(
prob.f.sys, prob.f, u0, nothing, p, newu0, newp)
Expand Down
17 changes: 17 additions & 0 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
using OrdinaryDiffEq
using Optimization
using OptimizationOptimJL
using ForwardDiff
using SciMLStructures

probs = []
syss = []
Expand Down Expand Up @@ -406,3 +408,18 @@ end
prob = ODEProblem(sys, [:x => 1.0], (0.0, 1.0), [p => 1.0])
@test_nowarn remake(prob; u0 = [:y => 1.0, :x => nothing])
end

@testset "`initialization_data` u0 and p are promoted with explicit `f`" begin
@variables x(t) [guess = 1.0] y(t) [guess = 1.0]
@parameters p q
@mtkbuild sys = ODESystem([D(x) ~ x, (x - p) ^ 2 + (y - q) ^ 3 ~ 0], t)
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0, q => 2.0])
@test prob.f.initialization_data !== nothing
buf, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), prob.p)
newps = repack(ForwardDiff.Dual.(buf))
prob2 = @test_nowarn remake(prob; f = prob.f, u0 = ForwardDiff.Dual.(prob.u0), p = newps)
@test prob2.f.initialization_data !== nothing
initprob = prob2.f.initialization_data.initializeprob
@test eltype(initprob.u0) <: ForwardDiff.Dual
@test eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), initprob.p)[1]) <: ForwardDiff.Dual
end

0 comments on commit 2eded1b

Please sign in to comment.