Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating MTK interface #346

Merged
merged 8 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/SpellCheck.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ jobs:
uses: actions/checkout@v4
- name: Check spelling
uses: crate-ci/[email protected]
with:
files: ./src ./docs
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ PrecompileTools = "1.2"
Primes = "0.5"
Random = "1.6, 1.7"
SpecialFunctions = "2"
SymbolicUtils = "2"
Symbolics = "5.30.1"
SymbolicUtils = "2, 3"
Symbolics = "5.30.1, 6"
Test = "1.6, 1.7"
TestSetExtensions = "2"
TimerOutputs = "0.5"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/tutorials/discrete_time.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ eqs = [
R(k) ~ R(k - 1) + α * I(k - 1),
]

@mtkbuild sys = DiscreteSystem(eqs, t)
@named sys = DiscreteSystem(eqs, t)

assess_local_identifiability(sys, measured_quantities = [I])
```
Expand Down
111 changes: 47 additions & 64 deletions ext/ModelingToolkitSIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ function StructuralIdentifiability.eval_at_nemo(e::SymbolicUtils.BasicSymbolic,
return args[1]^args[2]
end
return 1 // args[1]^(-args[2])
# dirty way, assumes that all shifts should be just removed
elseif startswith(String(Symbol(Symbolics.operation(e))), "Shift")
return args[1]
end
throw(Base.ArgumentError("Function $(Symbolics.operation(e)) is not supported"))
elseif e isa Symbolics.Symbolic
Expand All @@ -71,12 +74,11 @@ function StructuralIdentifiability.eval_at_nemo(
end

function get_measured_quantities(ode::ModelingToolkit.ODESystem)
if any(ModelingToolkit.isoutput(eq.lhs) for eq in ModelingToolkit.equations(ode))
@info "Measured quantities are not provided, trying to find the outputs in input ODE."
return filter(
eq -> (ModelingToolkit.isoutput(eq.lhs)),
ModelingToolkit.equations(ode),
)
outputs = filter(eq -> ModelingToolkit.isoutput(eq.lhs), ModelingToolkit.equations(ode))
if !isempty(outputs)
return outputs
elseif !isempty(ModelingToolkit.observed(ode))
return ModelingToolkit.observed(ode)
else
throw(
error(
Expand All @@ -103,6 +105,9 @@ function StructuralIdentifiability.mtk_to_si(
de::ModelingToolkit.AbstractTimeDependentSystem,
measured_quantities::Array{ModelingToolkit.Equation},
)
if isempty(measured_quantities)
measured_quantities = get_measured_quantities(de)
end
return __mtk_to_si(
de,
[(replace(string(e.lhs), "(t)" => ""), e.rhs) for e in measured_quantities],
Expand Down Expand Up @@ -153,6 +158,20 @@ function preprocess_ode(
return mtk_to_si(de, measured_quantities)
end

#------------------------------------------------------------------------------
function clean_calls(funcs)
res = []
for f in funcs
if length(Symbolics.arguments(f)) == 1 &&
!Symbolics.iscall(first(Symbolics.arguments(f)))
push!(res, f)
else
push!(res, first(Symbolics.arguments(f)))
end
end
return res
end

#------------------------------------------------------------------------------
"""
function __mtk_to_si(de::ModelingToolkit.AbstractTimeDependentSystem, measured_quantities::Array{Tuple{String, SymbolicUtils.BasicSymbolic}})
Expand Down Expand Up @@ -186,11 +205,10 @@ function __mtk_to_si(
end

y_functions = [each[2] for each in measured_quantities]
inputs = filter(v -> ModelingToolkit.isinput(v), ModelingToolkit.unknowns(de))
state_vars = filter(
s -> !(ModelingToolkit.isinput(s) || ModelingToolkit.isoutput(s)),
ModelingToolkit.unknowns(de),
)
state_vars =
filter(s -> !ModelingToolkit.isoutput(s), clean_calls(map(e -> e.lhs, diff_eqs)))
all_funcs = collect(Set(clean_calls(ModelingToolkit.unknowns(de))))
inputs = filter(s -> !ModelingToolkit.isoutput(s), setdiff(all_funcs, state_vars))
params = ModelingToolkit.parameters(de)
t = ModelingToolkit.arguments(diff_eqs[1].lhs)[1]
params_from_measured_quantities = union(
Expand Down Expand Up @@ -240,7 +258,7 @@ function __mtk_to_si(
end
# -----------------------------------------------------------------------------
"""
function assess_local_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=Array{ModelingToolkit.Equation}[], funcs_to_check=Array{}[], prob_threshold::Float64=0.99, type=:SE, loglevel=Logging.Info)
function assess_local_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=ModelingToolkit.Equation[], funcs_to_check=Array{}[], prob_threshold::Float64=0.99, type=:SE, loglevel=Logging.Info)

Input:
- `ode` - the ODESystem object from ModelingToolkit
Expand All @@ -263,7 +281,7 @@ The return value is a tuple consisting of the array of bools and the number of e
"""
function StructuralIdentifiability.assess_local_identifiability(
ode::ModelingToolkit.ODESystem;
measured_quantities = Array{ModelingToolkit.Equation}[],
measured_quantities = ModelingToolkit.Equation[],
funcs_to_check = Array{}[],
prob_threshold::Float64 = 0.99,
type = :SE,
Expand All @@ -288,28 +306,13 @@ end
prob_threshold::Float64 = 0.99,
type = :SE,
)
if length(measured_quantities) == 0
if any(ModelingToolkit.isoutput(eq.lhs) for eq in ModelingToolkit.equations(ode))
@info "Measured quantities are not provided, trying to find the outputs in input ODE."
measured_quantities = filter(
eq -> (ModelingToolkit.isoutput(eq.lhs)),
ModelingToolkit.equations(ode),
)
else
throw(
error(
"Measured quantities (output functions) were not provided and no outputs were found.",
),
)
end
end
if length(funcs_to_check) == 0
funcs_to_check = vcat(
[e for e in ModelingToolkit.unknowns(ode) if !ModelingToolkit.isoutput(e)],
ModelingToolkit.parameters(ode),
)
end
ode, conversion = mtk_to_si(ode, measured_quantities)
@info "System parsed into $ode"
conversion_back = Dict(v => k for (k, v) in conversion)
if isempty(funcs_to_check)
funcs_to_check = [conversion_back[x] for x in [ode.x_vars..., ode.parameters...]]
end

funcs_to_check_ = [eval_at_nemo(x, conversion) for x in funcs_to_check]

if isequal(type, :SE)
Expand Down Expand Up @@ -340,7 +343,7 @@ end
# ------------------------------------------------------------------------------

"""
assess_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=Array{ModelingToolkit.Equation}[], funcs_to_check=[], known_ic=[], prob_threshold = 0.99, loglevel=Logging.Info)
assess_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=ModelingToolkit.Equation[], funcs_to_check=[], known_ic=[], prob_threshold = 0.99, loglevel=Logging.Info)

Input:
- `ode` - the ModelingToolkit.ODESystem object that defines the model
Expand All @@ -356,7 +359,7 @@ If known initial conditions are provided, the identifiability results for the st
"""
function StructuralIdentifiability.assess_identifiability(
ode::ModelingToolkit.ODESystem;
measured_quantities = Array{ModelingToolkit.Equation}[],
measured_quantities = ModelingToolkit.Equation[],
funcs_to_check = [],
known_ic = [],
prob_threshold = 0.99,
Expand All @@ -376,16 +379,13 @@ end

function _assess_identifiability(
ode::ModelingToolkit.ODESystem;
measured_quantities = Array{ModelingToolkit.Equation}[],
measured_quantities = ModelingToolkit.Equation[],
funcs_to_check = [],
known_ic = [],
prob_threshold = 0.99,
)
if isempty(measured_quantities)
measured_quantities = get_measured_quantities(ode)
end

ode, conversion = mtk_to_si(ode, measured_quantities)
@info "System parsed into $ode"
conversion_back = Dict(v => k for (k, v) in conversion)
if isempty(funcs_to_check)
funcs_to_check = [conversion_back[x] for x in [ode.x_vars..., ode.parameters...]]
Expand Down Expand Up @@ -470,43 +470,29 @@ function _assess_local_identifiability(
known_ic = Array{}[],
prob_threshold::Float64 = 0.99,
)
if length(measured_quantities) == 0
if any(ModelingToolkit.isoutput(eq.lhs) for eq in ModelingToolkit.equations(dds))
@info "Measured quantities are not provided, trying to find the outputs in input dynamical system."
measured_quantities = filter(
eq -> (ModelingToolkit.isoutput(eq.lhs)),
ModelingToolkit.equations(dds),
)
else
throw(
error(
"Measured quantities (output functions) were not provided and no outputs were found.",
),
)
end
end

# Converting the finite difference operator in the right-hand side to
# the corresponding shift operator
eqs = filter(eq -> !(ModelingToolkit.isoutput(eq.lhs)), ModelingToolkit.equations(dds))

dds_aux_ode, conversion = mtk_to_si(dds, measured_quantities)
dds_aux = StructuralIdentifiability.DDS{QQMPolyRingElem}(dds_aux_ode)
@info "Parsed into the following model: $dds_aux"
if length(funcs_to_check) == 0
params = parameters(dds)
params_from_measured_quantities = union(
[filter(s -> !iscall(s), get_variables(y)) for y in measured_quantities]...,
)
funcs_to_check = vcat(
[
x for x in unknowns(dds) if
x for x in clean_calls(unknowns(dds)) if
conversion[x] in StructuralIdentifiability.x_vars(dds_aux)
],
union(params, params_from_measured_quantities),
)
end
funcs_to_check_ = [eval_at_nemo(x, conversion) for x in funcs_to_check]
known_ic_ = [eval_at_nemo(x, conversion) for x in known_ic]
@info "Functions to check are $(["$f" for f in funcs_to_check_]) and initial conditions are known for $(["$f" for f in known_ic_])"

result = StructuralIdentifiability._assess_local_identifiability_discrete_aux(
dds_aux,
Expand Down Expand Up @@ -568,7 +554,7 @@ find_identifiable_functions(de, measured_quantities = [y1 ~ x0])
"""
function StructuralIdentifiability.find_identifiable_functions(
ode::ModelingToolkit.ODESystem;
measured_quantities = Array{ModelingToolkit.Equation}[],
measured_quantities = ModelingToolkit.Equation[],
known_ic = [],
prob_threshold::Float64 = 0.99,
seed = 42,
Expand All @@ -595,18 +581,15 @@ end

function _find_identifiable_functions(
ode::ModelingToolkit.ODESystem;
measured_quantities = Array{ModelingToolkit.Equation}[],
known_ic = Array{Symbolics.Num}[],
measured_quantities = ModelingToolkit.Equation[],
known_ic = Symbolics.Num[],
prob_threshold::Float64 = 0.99,
seed = 42,
with_states = false,
simplify = :standard,
rational_interpolator = :VanDerHoevenLecerf,
)
Random.seed!(seed)
if isempty(measured_quantities)
measured_quantities = get_measured_quantities(ode)
end
ode, conversion = mtk_to_si(ode, measured_quantities)
known_ic_ = [eval_at_nemo(each, conversion) for each in known_ic]
result = nothing
Expand Down
Loading
Loading