Better BPINN ode Solver #853

Merged: 20 commits, Sep 8, 2024
14 changes: 11 additions & 3 deletions src/BPINN_ode.jl
@@ -100,6 +100,8 @@ struct BNNODE{C, K, IT <: NamedTuple,
init_params::I
Adaptorkwargs::A
Integratorkwargs::IT
numensemble::Int64
estim_collocate::Bool
autodiff::Bool
progress::Bool
verbose::Bool
@@ -112,6 +114,8 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
Metric = DiagEuclideanMetric,
targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
numensemble = floor(Int, draw_samples / 3),
estim_collocate = false,
autodiff = false, progress = false, verbose = false)
!(chain isa Lux.AbstractExplicitLayer) &&
(chain = adapt(FromFluxAdaptor(false, false), chain))
@@ -120,6 +124,7 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
phystd, dataset, physdt, MCMCkwargs,
nchains, init_params,
Adaptorkwargs, Integratorkwargs,
numensemble, estim_collocate,
autodiff, progress, verbose)
end
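For reviewers, a minimal usage sketch of the two new keywords. The ODE, network, and synthetic dataset below are illustrative placeholders, not part of this diff:

```julia
using NeuralPDE, Lux, Distributions
# ODEProblem and solve are assumed in scope via the SciML re-exports.

f(u, p, t) = p * u                       # hypothetical scalar ODE, du/dt = p*u
prob = ODEProblem(f, 1.0, (0.0, 1.0), 1.5)

chain = Lux.Chain(Lux.Dense(1, 8, tanh), Lux.Dense(8, 1))

t_obs = collect(0.0:0.05:1.0)
u_obs = exp.(1.5 .* t_obs) .+ 0.05 .* randn(length(t_obs))

alg = BNNODE(chain;
    draw_samples = 2000,
    dataset = [u_obs, t_obs],            # [û, t] layout read by L2loss2
    param = [Normal(1.0, 2.0)],          # prior for the unknown ODE parameter
    numensemble = 500,                   # new kwarg: draws used for the ensemble solution
    estim_collocate = true)              # new kwarg: enables the extra collocation term

sol = solve(prob, alg)
```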

@@ -186,7 +191,8 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
@unpack chain, l2std, phystd, param, priorsNNw, Kernel, strategy,
draw_samples, dataset, init_params,
nchains, physdt, Adaptorkwargs, Integratorkwargs,
MCMCkwargs, autodiff, progress, verbose = alg
MCMCkwargs, numensemble, estim_collocate, autodiff, progress,
verbose = alg

# ahmc_bayesian_pinn_ode needs param=[] for easier vcat operation for full vector of parameters
param = param === nothing ? [] : param
@@ -211,7 +217,8 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
Integratorkwargs = Integratorkwargs,
MCMCkwargs = MCMCkwargs,
progress = progress,
verbose = verbose)
verbose = verbose,
estim_collocate = estim_collocate)

fullsolution = BPINNstats(mcmcchain, samples, statistics)
ninv = length(param)
@@ -220,7 +227,8 @@
if chain isa Lux.AbstractExplicitLayer
θinit, st = Lux.setup(Random.default_rng(), chain)
θ = [vector_to_parameters(samples[i][1:(end - ninv)], θinit)
for i in (draw_samples - numensemble):draw_samples]
for i in 1:max(draw_samples - draw_samples ÷ 10, draw_samples - 1000)]

luxar = [chain(t', θ[i], st)[1] for i in 1:numensemble]
# only need for size
θinit = collect(ComponentArrays.ComponentArray(θinit))
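A quick worked example of the new ensemble defaults (plain arithmetic, not part of the diff):

```julia
# Assuming the default draw_samples = 2000 from the BNNODE constructor:
draw_samples = 2000
numensemble = floor(Int, draw_samples / 3)      # 666 draws feed the ensemble solution
upper = max(draw_samples - draw_samples ÷ 10,   # max(1800, 1000) == 1800
            draw_samples - 1000)
# θ therefore collects samples 1:1800, and `luxar` evaluates the first
# `numensemble` (666) of them through the network.
```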
5 changes: 0 additions & 5 deletions src/PDE_BPINN.jl
@@ -69,11 +69,6 @@
# + L2loss2(Tar, θ)
end

# function L2loss2(Tar::PDELogTargetDensity, θ)
# return Tar.full_loglikelihood(setparameters(Tar, θ),
# Tar.allstd)
# end

function setparameters(Tar::PDELogTargetDensity, θ)
names = Tar.names
ps_new = θ[1:(end - Tar.extraparams)]
@@ -361,7 +356,7 @@
# append Ode params to all paramvector - initial_θ
if ninv > 0
# shift ode params(initialise ode params by prior means)
# check if means or user speified is better

initial_θ = vcat(initial_θ, [Distributions.params(param[i])[1] for i in 1:ninv])
priors = vcat(priors, param)
nparameters += ninv
89 changes: 76 additions & 13 deletions src/advancedHMC_MCMC.jl
@@ -16,11 +16,12 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
physdt::Float64
extraparams::Int
init_params::I
estim_collocate::Bool

function LogTargetDensity(dim, prob, chain::Optimisers.Restructure, st, strategy,
dataset,
priors, phystd, l2std, autodiff, physdt, extraparams,
init_params::AbstractVector)
init_params::AbstractVector, estim_collocate)
new{
typeof(chain),
Nothing,
@@ -39,12 +40,13 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
autodiff,
physdt,
extraparams,
init_params)
init_params,
estim_collocate)
end
function LogTargetDensity(dim, prob, chain::Lux.AbstractExplicitLayer, st, strategy,
dataset,
priors, phystd, l2std, autodiff, physdt, extraparams,
init_params::NamedTuple)
init_params::NamedTuple, estim_collocate)
new{
typeof(chain),
typeof(st),
@@ -60,7 +62,8 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
autodiff,
physdt,
extraparams,
init_params)
init_params,
estim_collocate)
end
end

@@ -83,7 +86,12 @@ end
vector_to_parameters(ps_new::AbstractVector, ps::AbstractVector) = ps_new

function LogDensityProblems.logdensity(Tar::LogTargetDensity, θ)
return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ)
if Tar.estim_collocate
return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ) +
L2loss2(Tar, θ)
else
return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ)
end
end
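In effect, with `estim_collocate = true` the sampled log target gains one additive term, so the posterior being explored is

$$\log \pi(\theta) \propto \underbrace{\texttt{physloglikelihood}(\theta)}_{\text{ODE residual}} + \underbrace{\texttt{priorweights}(\theta)}_{\text{prior}} + \underbrace{\texttt{L2LossData}(\theta)}_{\text{data fit}} + \underbrace{\texttt{L2loss2}(\theta)}_{\text{collocation}}$$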

LogDensityProblems.dimension(Tar::LogTargetDensity) = Tar.dim
@@ -92,6 +100,55 @@ function LogDensityProblems.capabilities(::LogTargetDensity)
LogDensityProblems.LogDensityOrder{1}()
end

"""
Suggested extra loss function for the ODE solver case.
"""
function L2loss2(Tar::LogTargetDensity, θ)
f = Tar.prob.f

# parameter estimation chosen or not
if Tar.extraparams > 0
autodiff = Tar.autodiff
# Timepoints to enforce Physics
t = Tar.dataset[end]
u1 = Tar.dataset[2]
û = Tar.dataset[1]

nnsol = NNodederi(Tar, t, θ[1:(length(θ) - Tar.extraparams)], autodiff)

ode_params = Tar.extraparams == 1 ?
θ[((length(θ) - Tar.extraparams) + 1):length(θ)][1] :
θ[((length(θ) - Tar.extraparams) + 1):length(θ)]

if length(Tar.prob.u0) == 1
physsol = [f(û[i],
ode_params,
t[i])
for i in 1:length(û[:, 1])]
else
physsol = [f([û[i], u1[i]],
ode_params,
t[i])
for i in 1:length(û)]
end
#form of NN output matrix output dim x n
deri_physsol = reduce(hcat, physsol)

physlogprob = 0
for i in 1:length(Tar.prob.u0)
# can add phystd[i] for u[i]
physlogprob += logpdf(MvNormal(deri_physsol[i, :],
LinearAlgebra.Diagonal(map(abs2,
(Tar.l2std[i] * 4.0) .*
ones(length(nnsol[i, :]))))),
nnsol[i, :])
end
return physlogprob
else
return 0
end
end
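For reviewers, a self-contained sketch of the likelihood that `L2loss2` computes in the 1-D branch. The function and variable names here are illustrative; the `4.0 * l2std` widening mirrors the code above:

```julia
using Distributions, LinearAlgebra

# Score the NN time-derivative at the data points against the ODE right-hand
# side evaluated on the observed states û (gradient matching / collocation).
function collocation_loglik(nnsol::Vector{Float64}, f, û, t, ode_params, l2std)
    physsol = [f(û[i], ode_params, t[i]) for i in eachindex(t)]
    σ = 4.0 * l2std                                  # same widened std as the PR code
    return logpdf(MvNormal(physsol, Diagonal(fill(abs2(σ), length(t)))), nnsol)
end

# e.g. for du/dt = p*u with noisy synthetic data:
t = collect(0.0:0.1:1.0)
û = exp.(1.5 .* t) .+ 0.01 .* randn(length(t))
nnsol = 1.5 .* û                                     # stand-in for NNodederi output
collocation_loglik(nnsol, (u, p, t) -> p * u, û, t, 1.5, 0.05)
```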

"""
L2 loss log-likelihood (needed for ODE parameter estimation).
"""
@@ -247,7 +304,7 @@ function innerdiff(Tar::LogTargetDensity, f, autodiff::Bool, t::AbstractVector,

vals = nnsol .- physsol

# N dimensional vector if N outputs for NN(each row has logpdf of i[i] where u is vector of dependant variables)
# N dimensional vector if N outputs for NN(each row has logpdf of u[i] where u is vector of dependant variables)
return [logpdf(
MvNormal(vals[i, :],
LinearAlgebra.Diagonal(abs2.(Tar.phystd[i] .*
@@ -442,7 +499,8 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
MCMCkwargs = (n_leapfrog = 30,),
progress = false, verbose = false)
progress = false, verbose = false,
estim_collocate = false)
!(chain isa Lux.AbstractExplicitLayer) &&
(chain = adapt(FromFluxAdaptor(false, false), chain))
# NN parameter prior mean and variance(PriorsNN must be a tuple)
@@ -467,7 +525,7 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
# Lux-Named Tuple
initial_nnθ, recon, st = generate_Tar(chain, init_params)
else
error("Only Lux.AbstractExplicitLayer neural networks are supported")
error("Only Lux.AbstractExplicitLayer Neural networks are supported")
end

if nchains > Threads.nthreads()
@@ -500,7 +558,7 @@
t0 = prob.tspan[1]
# dimensions would be total no of params,initial_nnθ for Lux namedTuples
ℓπ = LogTargetDensity(nparameters, prob, recon, st, strategy, dataset, priors,
phystd, l2std, autodiff, physdt, ninv, initial_nnθ)
phystd, l2std, autodiff, physdt, ninv, initial_nnθ, estim_collocate)

try
ℓπ(t0, initial_θ[1:(nparameters - ninv)])
@@ -515,6 +573,9 @@
@info("Current Physics Log-likelihood : ", physloglikelihood(ℓπ, initial_θ))
@info("Current Prior Log-likelihood : ", priorweights(ℓπ, initial_θ))
@info("Current MSE against dataset Log-likelihood : ", L2LossData(ℓπ, initial_θ))
if estim_collocate
@info("Current gradient loss against dataset Log-likelihood : ", L2loss2(ℓπ, initial_θ))
end

Adaptor, Metric, targetacceptancerate = Adaptorkwargs[:Adaptor],
Adaptorkwargs[:Metric], Adaptorkwargs[:targetacceptancerate]
@@ -565,12 +626,14 @@
@info("Sampling Complete.")
@info("Current Physics Log-likelihood : ", physloglikelihood(ℓπ, samples[end]))
@info("Current Prior Log-likelihood : ", priorweights(ℓπ, samples[end]))
@info("Current MSE against dataset Log-likelihood : ",
L2LossData(ℓπ, samples[end]))
@info("Current MSE against dataset Log-likelihood : ", L2LossData(ℓπ, samples[end]))
if estim_collocate
@info("Current gradient loss against dataset Log-likelihood : ", L2loss2(ℓπ, samples[end]))
end

# return a chain(basic chain),samples and stats
matrix_samples = hcat(samples...)
mcmc_chain = MCMCChains.Chains(matrix_samples')
matrix_samples = reshape(hcat(samples...), (length(samples[1]), length(samples), 1))
mcmc_chain = MCMCChains.Chains(matrix_samples)
return mcmc_chain, samples, stats
end
end
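And an illustrative direct call to the low-level sampler with the new keyword threaded through (setup names are hypothetical; the default priors and HMC kernel are used):

```julia
using NeuralPDE, Lux, Distributions

f(u, p, t) = p * u
prob = ODEProblem(f, 1.0, (0.0, 1.0), 1.5)
chain = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 1))

t = collect(0.0:0.05:1.0)
û = exp.(1.5 .* t) .+ 0.05 .* randn(length(t))

mcmc_chain, samples, stats = ahmc_bayesian_pinn_ode(prob, chain;
    dataset = [û, t],
    draw_samples = 1000,
    l2std = [0.05], phystd = [0.05],
    param = [Normal(1.0, 2.0)],      # makes extraparams > 0 so L2loss2 is active
    estim_collocate = true)          # new keyword added in this PR
```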
14 changes: 2 additions & 12 deletions test/BPINN_PDEinvsol_tests.jl
@@ -7,7 +7,7 @@ using ComponentArrays

Random.seed!(100)

@testset "Example 1: 2D Periodic System with parameter estimation" begin
@testset "Example 1: 1D Periodic System with parameter estimation" begin
# Cos(pi*t) periodic curve
@parameters t, p
@variables u(..)
@@ -59,17 +59,7 @@ Random.seed!(100)
saveats = [1 / 50.0],
param = [LogNormal(6.0, 0.5)])

discretization = BayesianPINN([chainl], QuadratureTraining(), param_estim = true,
dataset = [dataset, nothing])

ahmc_bayesian_pinn_pde(pde_system,
discretization;
draw_samples = 1500,
bcstd = [0.05],
phystd = [0.01], l2std = [0.01],
priorsNNw = (0.0, 1.0),
saveats = [1 / 50.0],
param = [LogNormal(6.0, 0.5)])
# alternative to QuadratureTraining [WIP]

discretization = BayesianPINN([chainl], GridTraining([0.02]), param_estim = true,
dataset = [dataset, nothing])