Commit

THE PDE SOLVER WORKS COMPLETELY
AstitvaAggarwal committed Nov 17, 2023
1 parent 805380c commit a2a2292
Showing 3 changed files with 163 additions and 293 deletions.
222 changes: 92 additions & 130 deletions src/PDE_BPINN.jl
@@ -19,8 +19,8 @@ mutable struct PDELogTargetDensity{
Phi::PH

function PDELogTargetDensity(dim, strategy, dataset,
priors, allstd, names, physdt, extraparams,
init_params::AbstractVector, full_loglikelihood, Phi)
priors, allstd, names, physdt, extraparams,
init_params::AbstractVector, full_loglikelihood, Phi)
new{
typeof(strategy),
typeof(dataset),
@@ -41,8 +41,8 @@ mutable struct PDELogTargetDensity{
Phi)
end
function PDELogTargetDensity(dim, strategy, dataset,
priors, allstd, names, physdt, extraparams,
init_params::NamedTuple, full_loglikelihood, Phi)
priors, allstd, names, physdt, extraparams,
init_params::NamedTuple, full_loglikelihood, Phi)
new{
typeof(strategy),
typeof(dataset),
@@ -65,76 +65,57 @@ end
end

function LogDensityProblems.logdensity(Tar::PDELogTargetDensity, θ)
# forward solving
# Tar.full_loglikelihood(vector_to_parameters(θ, Tar.init_params), Tar.allstd)
# println("1 : ",
# length(Tar.full_loglikelihood(vector_to_parameters(θ,
# Tar.init_params),
# Tar.allstd).partials))
# println("2 : ", L2LossData(Tar, θ).value)
# println("2 : ", L2LossData(Tar, θ).partials)

# # println("3 : ", length(priorlogpdf(Tar, θ).partials))

# # println("sum : ",
# # (Tar.full_loglikelihood(vcat(vector_to_parameters(θ[1:(end - 1)],
# # Tar.init_params[1:(end - 1)]), θ[end]),
# # Tar.allstd) +
# # L2LossData(Tar, θ) + priorlogpdf(Tar, θ)).value)
# println(typeof(θ) <: AbstractVector)
# println(length(θ))

# println("1 : ",
# length(Tar.full_loglikelihood(vcat(vector_to_parameters(θ[1:(end - Tar.extraparams)],
# Tar.init_params[1:(end - Tar.extraparams)]), θ[(end - Tar.extraparams + 1):end]),
# Tar.allstd).partials))
# println("2 : ", length(L2LossData(Tar, θ).partials))
# println("3 : ", length(priorlogpdf(Tar, θ).partials))

# println(length(initial_nnθ))
# println(length(pinnrep.flat_init_params))
# println(initial_nnθ)
# println(pinnrep.flat_init_params)
# println(typeof(θ) <: AbstractVector)
# println(length(θ))
# println(typeof(θ[1:(end - Tar.extraparams)]) <: AbstractVector)
# println(length(θ[1:(end - Tar.extraparams)]))
# println(length(vector_to_parameters(θ[1:(end - Tar.extraparams)],
# Tar.init_params[1:(end - Tar.extraparams)])))

# Tar.full_loglikelihood(vcat(vector_to_parameters(θ[1:(end - Tar.extraparams)],
# Tar.init_params), θ[(end - Tar.extraparams + 1):end]),
# Tar.allstd)

# θ = reduce(vcat, θ)
# yuh = vcat(vector_to_parameters(θ[1:(end - Tar.extraparams)],
# Tar.init_params),
# adapt(typeof(vector_to_parameters(θ[1:(end - Tar.extraparams)],
# Tar.init_params)), θ[(end - Tar.extraparams + 1):end]))

# yuh = ComponentArrays.ComponentArray(;
# # u = vector_to_parameters(θ[1:(end - Tar.extraparams)], Tar.init_params),
# depvar = vector_to_parameters(θ[1:(end - Tar.extraparams)], Tar.init_params),
# p = θ[(end - Tar.extraparams + 1):end])

return Tar.full_loglikelihood(setLuxparameters(Tar, θ),
Tar.allstd) + priorlogpdf(Tar, θ)
# +L2LossData(Tar, θ)
# for parameter estimation it is necessary to use the multioutput case
return Tar.full_loglikelihood(setparameters(Tar, θ),
Tar.allstd) + priorlogpdf(Tar, θ) + L2LossData(Tar, θ)
# + L2loss2(Tar, θ)
end
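The log-density returned here is the sum of three terms: the physics log-likelihood over the PDE/BC residuals, the prior log-pdf over the sampled parameters, and the data log-likelihood against the dataset. A minimal standalone sketch of that decomposition, with hypothetical stand-in arguments (not the package API):

```julia
using Distributions

# Sketch: log-posterior ∝ physics log-likelihood + prior log-pdf + data log-likelihood.
# `residuals`, `prior`, `pred`, `obs`, and `σ` are hypothetical stand-ins.
function sketch_logdensity(θ, residuals, prior, pred, obs, σ)
    phys_ll = sum(logpdf.(Normal(0.0, 0.05), residuals))   # PDE/BC residuals should be ≈ 0
    prior_ll = logpdf(prior, θ)                            # e.g. MvNormal over flattened NN weights
    data_ll = logpdf(MvNormal(pred, ones(length(pred)) .* σ), obs)
    return phys_ll + prior_ll + data_ll
end
```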

function setLuxparameters(Tar::PDELogTargetDensity, θ)
a = ComponentArrays.ComponentArray(NamedTuple{Tar.names}(i for i in [
vector_to_parameters(θ[1:(end - Tar.extraparams)],
Tar.init_params),
]))
function L2loss2(Tar::PDELogTargetDensity, θ)
return Tar.full_loglikelihood(setparameters(Tar, θ),
Tar.allstd)
end
function setparameters(Tar::PDELogTargetDensity, θ)
names = Tar.names
ps_new = θ[1:(end - Tar.extraparams)]
ps = Tar.init_params

if (ps[names[1]] isa ComponentArrays.ComponentVector)
# multioutput case for Lux chains: for each depvar, ps contains a Lux ComponentVector,
# which we use to map the current AHMC-sampled parameter vector onto the NNs
i = 0
Luxparams = []
for x in names
endind = length(ps[x])
push!(Luxparams, vector_to_parameters(ps_new[(i + 1):(i + endind)], ps[x]))
i += endind
end
Luxparams
else
# multioutput Flux
Luxparams = θ
end

if (Luxparams isa AbstractVector) && (Luxparams[1] isa ComponentArrays.ComponentVector)
# multioutput Lux
a = ComponentArrays.ComponentArray(NamedTuple{Tar.names}(i for i in Luxparams))

b = θ[(end - Tar.extraparams + 1):end]
if Tar.extraparams > 0
b = θ[(end - Tar.extraparams + 1):end]

ComponentArrays.ComponentArray(;
depvar = a,
p = b)
return ComponentArrays.ComponentArray(;
depvar = a,
p = b)
else
return ComponentArrays.ComponentArray(;
depvar = a)
end
else
# multioutput Flux case
return vector_to_parameters(Luxparams, ps)
end
end
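Concretely, `setparameters` splits the flat AHMC vector into one structured parameter set per dependent variable, with any trailing inverse-problem parameters collected under `p`. A self-contained sketch with made-up layer sizes (not the package API):

```julia
using ComponentArrays

# Stand-in Lux parameter containers for two dependent variables (hypothetical sizes).
ps_u1 = ComponentArray(weight = zeros(3), bias = zeros(1))   # 4 parameters
ps_u2 = ComponentArray(weight = zeros(2), bias = zeros(1))   # 3 parameters
θ = collect(1.0:9.0)                                         # 7 NN params + 2 inverse params

# Split the flat vector the same way `setparameters` does.
n1, n2 = length(ps_u1), length(ps_u2)
u1 = ComponentArray(θ[1:n1], getaxes(ps_u1))                 # reshape onto ps_u1's structure
u2 = ComponentArray(θ[(n1 + 1):(n1 + n2)], getaxes(ps_u2))
result = ComponentArray(depvar = (; u1, u2), p = θ[(n1 + n2 + 1):end])

result.depvar.u1.weight   # [1.0, 2.0, 3.0]
result.p                  # [8.0, 9.0]
```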

LogDensityProblems.dimension(Tar::PDELogTargetDensity) = Tar.dim

function LogDensityProblems.capabilities(::PDELogTargetDensity)
@@ -146,28 +127,24 @@ function L2loss2(Tar::PDELogTargetDensity, θ)
end
# L2 loss log-likelihood (needed mainly for ODE parameter estimation)
function L2LossData(Tar::PDELogTargetDensity, θ)
return logpdf(MvNormal(Tar.Phi[1](Tar.dataset[end]',
vector_to_parameters(θ[1:(end - Tar.extraparams)],
Tar.init_params))[1,
:], ones(length(Tar.dataset[end])) .* Tar.allstd[3][1]), zeros(length(Tar.dataset[end])))
# matrix(each row corresponds to vector u's rows)
# if Tar.dataset isa Vector{Nothing} || Tar.extraparams == 0
# return 0
# else
# nn = [phi(Tar.dataset[end]', θ[1:(length(θ) - Tar.extraparams)])
# for phi in Tar.Phi]

# L2logprob = 0
# for i in 1:(length(Tar.dataset) - 1)
# # for u[i] ith vector must be added to dataset,nn[1,:] is the dx in lotka_volterra
# L2logprob += logpdf(MvNormal(nn[i][:],
# ones(length(Tar.dataset[end])) .* Tar.allstd[3]),
# Tar.dataset[i])
# end

# return L2logprob
# end
return 0
if Tar.extraparams > 0
if Tar.init_params isa ComponentArrays.ComponentVector
return sum([logpdf(MvNormal(Tar.Phi[i](Tar.dataset[end]',
vector_to_parameters(θ[1:(end - Tar.extraparams)],
Tar.init_params)[Tar.names[i]])[1,
:], ones(length(Tar.dataset[end])) .* Tar.allstd[3][i]), Tar.dataset[i])
for i in eachindex(Tar.Phi)])
else
# the Flux case needs subindexing w.r.t. the index ranges stored in Tar.names
return sum([logpdf(MvNormal(Tar.Phi[i](Tar.dataset[end]',
vector_to_parameters(θ[1:(end - Tar.extraparams)],
Tar.init_params)[Tar.names[2][i]])[1,
:], ones(length(Tar.dataset[end])) .* Tar.allstd[3][i]), Tar.dataset[i])
for i in eachindex(Tar.Phi)])
end
else
return 0
end
end
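The data term is an independent Gaussian log-likelihood per dependent variable, evaluated at the dataset coordinates with standard deviation `l2std`. A minimal single-variable sketch (hypothetical names, not the package API):

```julia
using Distributions

# Sketch: Gaussian data log-likelihood for one dependent variable.
function sketch_data_loglik(predict, coords, observations, l2std)
    pred = predict(coords)   # NN output at the data coordinates
    return logpdf(MvNormal(pred, ones(length(pred)) .* l2std), observations)
end

ts = collect(0.0:0.1:1.0)
noisy_obs = sin.(ts) .+ 0.05 .* randn(length(ts))
sketch_data_loglik(t -> sin.(t), ts, noisy_obs, 0.05)
```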

# priors for NN parameters + ODE constants
@@ -230,16 +207,16 @@ end
# priors: pdf for W,b + pdf for ODE params
# lotka specific kwargs here
function ahmc_bayesian_pinn_pde(pde_system, discretization;
strategy = GridTraining, dataset = [nothing],
init_params = nothing, draw_samples = 1000,
physdt = 1 / 20.0, bcstd = [0.01], l2std = [0.05],
phystd = [0.05], priorsNNw = (0.0, 2.0),
param = [], nchains = 1, Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
MCMCkwargs = (n_leapfrog = 30,),
progress = false, verbose = false)
strategy = GridTraining, dataset = [nothing],
init_params = nothing, draw_samples = 1000,
physdt = 1 / 20.0, bcstd = [0.01], l2std = [0.05],
phystd = [0.05], priorsNNw = (0.0, 2.0),
param = [], nchains = 1, Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
MCMCkwargs = (n_leapfrog = 30,),
progress = false, verbose = false)
pinnrep = symbolic_discretize(pde_system, discretization, bayesian = true)

# for physics loglikelihood
@@ -252,37 +229,30 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
# for new L2 loss
# discretization.additional_loss =

# converting the parameter vector to a ComponentArray for RuntimeGeneratedFunctions
names = ntuple(i -> pinnrep.depvars[i], length(discretization.chain))

if nchains > Threads.nthreads()
throw(error("number of chains is greater than available threads"))
elseif nchains < 1
throw(error("number of chains must be greater than 1"))
end

# remove inverse-problem params, take only NN params; AHMC uses Float64
initial_nnθ = pinnrep.flat_init_params[1:(end - length(param))]
if discretization.multioutput
if chain[1] isa Lux.AbstractExplicitLayer
# Lux chain (ComponentArray used later since vector_to_parameters needs a NamedTuple; AHMC uses Float64)
initial_θ = collect(Float64, vcat(ComponentArrays.ComponentArray(initial_nnθ)))
# namedtuple form of Lux params required for RuntimeGeneratedFunctions
initial_nnθ, st = Lux.setup(Random.default_rng(), chain[1])
else
# remove inverse-problem params, take only NN params
initial_θ = collect(Float64, initial_nnθ)
end
initial_θ = collect(Float64, initial_nnθ)
initial_nnθ = pinnrep.init_params

if (discretization.multioutput && chain[1] isa Lux.AbstractExplicitLayer)
# converting the parameter vector to a ComponentArray for RuntimeGeneratedFunctions
names = ntuple(i -> pinnrep.depvars[i], length(chain))
else
if chain isa Lux.AbstractExplicitLayer
# Lux chain (ComponentArray used later since vector_to_parameters needs a NamedTuple; AHMC uses Float64)
initial_θ = collect(Float64, vcat(ComponentArrays.ComponentArray(initial_nnθ)))
# namedtuple form of Lux params required for RuntimeGeneratedFunctions
initial_nnθ, st = Lux.setup(Random.default_rng(), chain)
else
# remove inverse-problem params, take only NN params
initial_nnθ = pinnrep.flat_init_params[1:(end - length(param))]
initial_θ = collect(Float64, initial_nnθ)
# this case is for Flux multioutput
i = 0
temp = []
for j in eachindex(initial_nnθ)
len = length(initial_nnθ[j])
push!(temp, (i + 1):(i + len))
i += len
end
names = tuple(1, temp)
end
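For the Flux path, `names` therefore carries index ranges into the flattened parameter vector rather than depvar symbols. A small sketch of the same bookkeeping (hypothetical sizes):

```julia
# Compute the range of flat-vector indices owned by each network,
# mirroring the loop above.
function flat_index_ranges(params_per_net)
    i = 0
    ranges = UnitRange{Int}[]
    for p in params_per_net
        push!(ranges, (i + 1):(i + length(p)))
        i += length(p)
    end
    return ranges
end

ranges = flat_index_ranges([zeros(5), zeros(3)])   # [1:5, 6:8]
names = (1, ranges)                                # θ[ranges[2]] selects the 2nd net's params
```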

# ODE parameter estimation
@@ -314,10 +284,6 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
Expand Down Expand Up @@ -314,10 +284,6 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
full_weighted_loglikelihood,
Phi)

println(ℓπ.full_loglikelihood(setLuxparameters(ℓπ, initial_θ), ℓπ.allstd))
println(priorlogpdf(ℓπ, initial_θ))
println(L2LossData(ℓπ, initial_θ))

Adaptor, Metric, targetacceptancerate = Adaptorkwargs[:Adaptor],
Adaptorkwargs[:Metric], Adaptorkwargs[:targetacceptancerate]

@@ -364,10 +330,6 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples,
adaptor; progress = progress, verbose = verbose)

println(ℓπ.full_loglikelihood(setLuxparameters(ℓπ, samples[end]),
ℓπ.allstd))
println(priorlogpdf(ℓπ, samples[end]))
println(L2LossData(ℓπ, samples[end]))
# return a chain(basic chain),samples and stats
matrix_samples = hcat(samples...)
mcmc_chain = MCMCChains.Chains(matrix_samples')
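For context, a hypothetical end-to-end call of the sampler defined above, assuming the NeuralPDE/ModelingToolkit API surrounding this commit; the system, chain, and keyword values are illustrative only:

```julia
using NeuralPDE, Lux, ModelingToolkit
import ModelingToolkit: Interval

@parameters t
@variables u(..)
Dt = Differential(t)

# A toy 1D system keeps the example small.
eq = Dt(u(t)) ~ -u(t)
bcs = [u(0.0) ~ 1.0]
domains = [t ∈ Interval(0.0, 1.0)]
@named pde_system = PDESystem(eq, bcs, domains, [t], [u(t)])

chain = Lux.Chain(Lux.Dense(1, 8, tanh), Lux.Dense(8, 1))
discretization = PhysicsInformedNN(chain, GridTraining(0.05))

# Sample the Bayesian PINN posterior with the defaults from the signature above
# (assuming the function returns the chain, samples, and stats built at its end).
mcmc_chain, samples, stats = ahmc_bayesian_pinn_pde(pde_system, discretization;
    draw_samples = 200, bcstd = [0.01], phystd = [0.05],
    priorsNNw = (0.0, 2.0))
```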
36 changes: 22 additions & 14 deletions src/advancedHMC_MCMC.jl
@@ -65,22 +65,23 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
end

"""
Cool function needed for converting a vector of sampled parameters into NamedTuples in the case of Lux chain output; during derivative computation,
Cool function needed for converting a vector of sampled parameters into a ComponentVector in the case of Lux chain output; during derivative computation,
the sampled parameters are of the exotic type `Dual` due to ForwardDiff's autodiff tagging.
"""
function vector_to_parameters(ps_new::AbstractVector,
ps::Union{NamedTuple, <:AbstractVector})
if typeof(ps) <: AbstractVector
ps::Union{ComponentArrays.ComponentVector, AbstractVector})
if ps isa ComponentArrays.ComponentVector
@assert length(ps_new) == Lux.parameterlength(ps)
i = 1
function get_ps(x)
z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
i += length(x)
return z
end
return Functors.fmap(get_ps, ps)
else
return ps_new
end
@assert length(ps_new) == Lux.parameterlength(ps)
i = 1
function get_ps(x)
z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
i += length(x)
return z
end
return Functors.fmap(get_ps, ps)
end
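A usage sketch for the committed `vector_to_parameters`, assuming a small Lux layer and that Lux's ComponentArrays/Functors integration is loaded (as it is inside the package): a flat Float64 sample from AHMC is reshaped back into the structured parameter container.

```julia
using Lux, Random, ComponentArrays, Functors

layer = Lux.Dense(2 => 3)
ps, _ = Lux.setup(Random.default_rng(), layer)
ps_ca = ComponentArray(ps)                       # structured container, 9 parameters

flat = randn(Lux.parameterlength(layer))         # stand-in for a raw AHMC sample
structured = vector_to_parameters(flat, ps_ca)   # flat values arranged like ps_ca
structured.weight                                # 3×2 matrix filled from `flat`
```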

function LogDensityProblems.logdensity(Tar::LogTargetDensity, θ)
Expand Down Expand Up @@ -559,9 +560,10 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
end
end

println(physloglikelihood(ℓπ, initial_θ))
println(priorweights(ℓπ, initial_θ))
# println(L2LossData(ℓπ, initial_nnθ))
println("Current Physics Log-likelihood : ", physloglikelihood(ℓπ, initial_θ))
println("Current Prior Log-likelihood : ", priorweights(ℓπ, initial_θ))
println("Current MSE against dataset Log-likelihood : ", L2LossData(ℓπ, initial_θ))
println("Current custom loss Log-likelihood : ", L2loss2(ℓπ, initial_θ))

Adaptor, Metric, targetacceptancerate = Adaptorkwargs[:Adaptor],
Adaptorkwargs[:Metric], Adaptorkwargs[:targetacceptancerate]
@@ -609,6 +611,12 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples,
adaptor; progress = progress, verbose = verbose)

println("Current Physics Log-likelihood : ", physloglikelihood(ℓπ, samples[end]))
println("Current Prior Log-likelihood : ", priorweights(ℓπ, samples[end]))
println("Current MSE against dataset Log-likelihood : ",
L2LossData(ℓπ, samples[end]))
println("Current custom loss Log-likelihood : ", L2loss2(ℓπ, samples[end]))

# return a chain(basic chain),samples and stats
matrix_samples = hcat(samples...)
mcmc_chain = MCMCChains.Chains(matrix_samples')
