Skip to content

Commit

Permalink
Split up the 2Point BVP
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 6, 2023
1 parent ea702db commit 90bcabc
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SciMLBase"
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
authors = ["Chris Rackauckas <[email protected]> and contributors"]
version = "2.1.0"
version = "2.2.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
37 changes: 21 additions & 16 deletions src/problems/bvp_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,18 @@ time points, and for shooting type methods `u=sol` the ODE solution.
Note that all features of the `ODESolution` are present in this form.
In both cases, the size of the residual matches the size of the initial condition.
If the bvp is a TwoPointBVProblem it must define either of the following functions
If the bvp is a TwoPointBVProblem then `bc` must be a Tuple `(bca, bcb)` and each of them
must define either of the following functions:
```julia
bc!((resid_a, resid_b), (u_a, u_b), p)
resid_a, resid_b = bc((u_a, u_b), p)
begin
bca!(resid_a, u_a, p)
bcb!(resid_b, u_b, p)
end
begin
resid_a = bca(u_a, p)
resid_b = bcb(u_b, p)
end
```
where `resid_a` and `resid_b` are the residuals at the two endpoints, `u_a` and `u_b` are
Expand All @@ -98,17 +105,16 @@ every solve call.
* `p`: The parameters for the problem. Defaults to `NullParameters`
* `kwargs`: The keyword arguments passed onto the solves.
"""
struct BVProblem{uType, tType, isinplace, P, F, BF, PT, K} <:
struct BVProblem{uType, tType, isinplace, P, F, PT, K} <:
AbstractBVProblem{uType, tType, isinplace}
f::F
bc::BF
u0::uType
tspan::tType
p::P
problem_type::PT
kwargs::K

@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, bc, u0, tspan,
@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, u0, tspan,
p = NullParameters(); problem_type=nothing, kwargs...) where {iip, TP}
_tspan = promote_tspan(tspan)
warn_paramtype(p)
Expand All @@ -119,25 +125,24 @@ struct BVProblem{uType, tType, isinplace, P, F, BF, PT, K} <:
else
@assert prob_type === problem_type "This indicates incorrect problem type specification! Users should never pass in `problem_type` kwarg, this exists exclusively for internal use."
end
return new{typeof(u0), typeof(_tspan), iip, typeof(p), typeof(f), typeof(bc),
typeof(problem_type), typeof(kwargs)}(f, bc, u0, _tspan, p, problem_type,
kwargs)
return new{typeof(u0), typeof(_tspan), iip, typeof(p), typeof(f),
typeof(problem_type), typeof(kwargs)}(f, u0, _tspan, p, problem_type, kwargs)
end

function BVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip}
BVProblem(BVPFunction{iip}(f, bc), bc, u0, tspan, p; kwargs...)
BVProblem(BVPFunction{iip}(f, bc), u0, tspan, p; kwargs...)
end
end

TruncatedStacktraces.@truncate_stacktrace BVProblem 3 1 2

function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...)
iip = isinplace(f, 4)
return BVProblem{iip}(BVPFunction{iip}(f, bc), bc, u0, tspan, p; kwargs...)
return BVProblem{iip}(BVPFunction{iip}(f, bc), u0, tspan, p; kwargs...)
end

function BVProblem(f::AbstractBVPFunction, u0, tspan, p = NullParameters(); kwargs...)
return BVProblem{isinplace(f)}(f, f.bc, u0, tspan, p; kwargs...)
return BVProblem{isinplace(f)}(f, u0, tspan, p; kwargs...)
end

# This is mostly a fake stuct and isn't used anywhere
Expand All @@ -163,13 +168,13 @@ function TwoPointBVProblem(f, bc, u0, tspan, p = NullParameters();
end
function TwoPointBVProblem{iip}(f::AbstractBVPFunction{iip, twopoint}, u0, tspan,
p = NullParameters(); kwargs...) where {iip, twopoint}
@assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=true` during the construction of the `BVPFunction`."
return BVProblem{iip}(f, f.bc, u0, tspan, p; kwargs...)
@assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=Val(true)` during the construction of the `BVPFunction`."
return BVProblem{iip}(f, u0, tspan, p; kwargs...)
end
function TwoPointBVProblem(f::AbstractBVPFunction{iip, twopoint}, u0, tspan,
p = NullParameters(); kwargs...) where {iip, twopoint}
@assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=true` during the construction of the `BVPFunction`."
return BVProblem{iip}(f, f.bc, u0, tspan, p; kwargs...)
@assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=Val(true)` during the construction of the `BVPFunction`."
return BVProblem{iip}(f, u0, tspan, p; kwargs...)
end

# Allow previous timeseries solution
Expand Down
2 changes: 1 addition & 1 deletion src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ function remake(prob::BVProblem; f = missing, bc = missing, u0 = missing, tspan
twopoint = problem_type isa TwoPointBVProblem

if bc === missing
bc = prob.bc
bc = prob.f.bc
end

if f === missing
Expand Down
33 changes: 30 additions & 3 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4004,9 +4004,31 @@ function BVPFunction{iip, specialize, twopoint}(f, bc;
_bccolorvec = bccolorvec
end

bciip = !twopoint ? isinplace(bc, 4, "bc", iip) : isinplace(bc, 3, "bc", iip)
bciip = if !twopoint
isinplace(bc, 4, "bc", iip)
else
@assert length(bc) == 2
bc = Tuple(bc)
if isinplace(first(bc), 3, "bc", iip) != isinplace(last(bc), 3, "bc", iip)
throw(NonconformingFunctionsError(["bc[1]", "bc[2]"]))
end
isinplace(first(bc), 3, "bc", iip)
end
jaciip = jac !== nothing ? isinplace(jac, 4, "jac", iip) : iip
bcjaciip = bcjac !== nothing ? isinplace(bcjac, 4, "bcjac", bciip) : bciip
bcjaciip = if bcjac !== nothing
if twopoint
isinplace(bcjac, 4, "bcjac", bciip)
else
@assert length(bcjac) == 2
bcjac = Tuple(bcjac)
if isinplace(first(bcjac), 3, "bcjac", bciip) != isinplace(last(bcjac), 3, "bcjac", bciip)
throw(NonconformingFunctionsError(["bcjac[1]", "bcjac[2]"]))
end
isinplace(bcjac, 3, "bcjac", iip)
end
else
bciip
end
tgradiip = tgrad !== nothing ? isinplace(tgrad, 4, "tgrad", iip) : iip
jvpiip = jvp !== nothing ? isinplace(jvp, 5, "jvp", iip) : iip
vjpiip = vjp !== nothing ? isinplace(vjp, 5, "vjp", iip) : iip
Expand All @@ -4029,8 +4051,13 @@ function BVPFunction{iip, specialize, twopoint}(f, bc;
error("bcresid_prototype must be a tuple / indexable collection of length 2 for a inplace TwoPointBVPFunction")
end
if bcresid_prototype !== nothing && length(bcresid_prototype) == 2
bcresid_prototype = ArrayPartition(bcresid_prototype[1], bcresid_prototype[2])
bcresid_prototype = ArrayPartition(first(bcresid_prototype),
last(bcresid_prototype))
end

bccolorvec !== nothing && length(bccolorvec) == 2 && (bccolorvec = Tuple(bccolorvec))

bcjac_prototype !== nothing && length(bcjac_prototype) == 2 && (bcjac_prototype = Tuple(bcjac_prototype))
end

if any(bc_nonconforming)
Expand Down

0 comments on commit 90bcabc

Please sign in to comment.