Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed Oct 10, 2024
1 parent 205aef1 commit d3429ec
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ Reexport = "1.2"
RuntimeGeneratedFunctions = "0.5.12"
SafeTestsets = "0.1"
SciMLBase = "2.56"
Statistics = "1.11.0"
Statistics = "1.10.0, 1.11.0"
SymbolicUtils = "1.5, 2, 3"
Symbolics = "5.27.1, 6"
Test = "1.10"
Expand Down
13 changes: 6 additions & 7 deletions src/discretize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,8 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::D
dict_var_span_ = Dict([Symbol(d.variables) => bc for (d, bc) in zip(domains, bc_data)])

bcs_train_sets = map(bound_args) do bt
span = map(b -> get(dict_var_span, b, b), bt)
_set = hcat(vec(map(points -> collect(points), Iterators.product(span...)))...)
x = convert.(eltypeθ, adapt(eltypeθ, _set))
_set = adapt(eltypeθ,
hcat(vec(map(points -> collect(points), Iterators.product(span...)))...))
end

pde_vars = get_variables(eqs, dict_indvars, dict_depvars)
Expand All @@ -230,8 +229,8 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::D

pde_train_sets = map(pde_args) do bt
span = map(b -> get(dict_var_span_, b, b), bt)
_set = hcat(vec(map(points -> collect(points), Iterators.product(span...)))...)
x = convert.(eltypeθ, adapt(eltypeθ, _set))
_set = adapt(eltypeθ,
hcat(vec(map(points -> collect(points), Iterators.product(span...)))...))
end
[pde_train_sets, bcs_train_sets]
end
Expand Down Expand Up @@ -266,11 +265,11 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars,

pde_lower_bounds = map(pde_args) do pd
span = map(p -> get(dict_lower_bound, p, p), pd)
map(s -> convert(eltypeθ, adapt(eltypeθ, s)) + cbrt(eps(eltypeθ)), span)
map(s -> adapt(eltypeθ, s) + cbrt(eps(eltypeθ)), span)
end
pde_upper_bounds = map(pde_args) do pd
span = map(p -> get(dict_upper_bound, p, p), pd)
map(s -> convert(eltypeθ, adapt(eltypeθ, s)) - cbrt(eps(eltypeθ)), span)
map(s -> adapt(eltypeθ, s) + cbrt(eps(eltypeθ)), span)
end
pde_bounds = [pde_lower_bounds, pde_upper_bounds]

Expand Down
10 changes: 5 additions & 5 deletions src/neural_adapter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ function generate_training_sets(domains, dx, eqs, eltypeθ)
dxs = fill(dx, length(domains))
end
spans = [infimum(d.domain):dx:supremum(d.domain) for (d, dx) in zip(domains, dxs)]
train_set = hcat(vec(map(points -> collect(points), Iterators.product(spans...)))...)
x = convert.(eltypeθ, adapt(eltypeθ, train_set))
_set = adapt(eltypeθ,
hcat(vec(map(points -> collect(points), Iterators.product(span...)))...))
end

function get_loss_function_(loss, init_params, pde_system, strategy::GridTraining)
Expand All @@ -30,7 +30,7 @@ function get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strateg

bounds = first(map(args) do pd
span = map(p -> get(dict_span, p, p), pd)
map(s -> covert(eltypeθ, adapt(eltypeθ, s)), span)
map(s -> adapt(eltypeθ, s), span)
end)
bounds = [getindex.(bounds, 1), getindex.(bounds, 2)]
return bounds
Expand Down Expand Up @@ -75,11 +75,11 @@ function get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars,

lower_bounds = map(args) do pd
span = map(p -> get(dict_lower_bound, p, p), pd)
map(s -> convert(eltypeθ, adapt(eltypeθ, s)), span)
map(s -> adapt(eltypeθ, s), span)
end
upper_bounds = map(args) do pd
span = map(p -> get(dict_upper_bound, p, p), pd)
map(s -> convert(eltypeθ, adapt(eltypeθ, s)), span)
map(s -> adapt(eltypeθ, s), span)
end
bound = lower_bounds, upper_bounds
end
Expand Down
8 changes: 4 additions & 4 deletions src/ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ end

function (f::ODEPhi{C, T, U})(t::Number,
θ) where {C <: Lux.AbstractLuxLayer, T, U <: Number}
eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
eltypeθ, typeθ = eltype.depvar), parameterless_type(ComponentArrays.getdata.depvar))
t_ = convert.(eltypeθ, adapt(typeθ, [t]))
y, st = f.chain(t_, θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
Expand All @@ -128,15 +128,15 @@ end
function (f::ODEPhi{C, T, U})(t::AbstractVector,
θ) where {C <: Lux.AbstractLuxLayer, T, U <: Number}
# Batch via data as row vectors
eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
eltypeθ, typeθ = eltype.depvar), parameterless_type(ComponentArrays.getdata.depvar))
t_ = convert.(eltypeθ, adapt(typeθ, t'))
y, st = f.chain(t_, θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
f.u0 .+ (t' .- f.t0) .* y
end

function (f::ODEPhi{C, T, U})(t::Number, θ) where {C <: Lux.AbstractLuxLayer, T, U}
eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
eltypeθ, typeθ = eltype.depvar), parameterless_type(ComponentArrays.getdata.depvar))
t_ = convert.(eltypeθ, adapt(typeθ, [t]))
y, st = f.chain(t_, θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
Expand All @@ -146,7 +146,7 @@ end
function (f::ODEPhi{C, T, U})(t::AbstractVector,
θ) where {C <: Lux.AbstractLuxLayer, T, U}
# Batch via data as row vectors
eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
eltypeθ, typeθ = eltype.depvar), parameterless_type(ComponentArrays.getdata.depvar))
t_ = convert.(eltypeθ, adapt(typeθ, t'))
y, st = f.chain(t_, θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
Expand Down
6 changes: 3 additions & 3 deletions src/pinn_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -526,13 +526,13 @@ end

# the method to calculate the derivative
function numeric_derivative(phi, u, x, εs, order, θ)
eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
_type = parameterless_type(ComponentArrays.getdata(θ))

ε = εs[order]
_epsilon = inv(first(ε[ε .!= zero(ε)]))

ε = convert.(eltypeθ, adapt(typeθ, ε))
x = convert.(eltypeθ, adapt(typeθ, x))
ε = adapt(_type, ε)
x = adapt(_type, x)

# any(x->x!=εs[1],εs)
# εs is the epsilon for each order, if they are all the same then we use a fancy formula
Expand Down

0 comments on commit d3429ec

Please sign in to comment.