Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov authored and avik-pal committed Oct 13, 2024
1 parent 3274cb5 commit 052dd9a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
7 changes: 4 additions & 3 deletions src/discretize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,16 +216,17 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::D
dict_var_span_ = Dict([Symbol(d.variables) => bc for (d, bc) in zip(domains, bc_data)])

bcs_train_sets = map(bound_args) do bt
span = map(b -> get(dict_var_span, b, b), bt)
_set = adapt(eltypeθ,
hcat(vec(map(points -> collect(points), Iterators.product(span...)))...))
end

pde_vars = get_variables(eqs, dict_indvars, dict_depvars)
pde_args = get_argument(eqs, dict_indvars, dict_depvars)

# pde_train_set = adapt(eltypeθ,
# hcat(vec(map(points -> collect(points),
# Iterators.product(bc_data...)))...))
pde_train_set = adapt(eltypeθ,
hcat(vec(map(points -> collect(points),
Iterators.product(bc_data...)))...))

pde_train_sets = map(pde_args) do bt
span = map(b -> get(dict_var_span_, b, b), bt)
Expand Down
4 changes: 2 additions & 2 deletions src/neural_adapter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ function generate_training_sets(domains, dx, eqs, eltypeθ)
dxs = fill(dx, length(domains))
end
spans = [infimum(d.domain):dx:supremum(d.domain) for (d, dx) in zip(domains, dxs)]
_set = adapt(eltypeθ,
hcat(vec(map(points -> collect(points), Iterators.product(span...)))...))
train_set = adapt(eltypeθ,
hcat(vec(map(points -> collect(points), Iterators.product(spans...)))...))
end

function get_loss_function_(loss, init_params, pde_system, strategy::GridTraining)
Expand Down

0 comments on commit 052dd9a

Please sign in to comment.