Skip to content

Commit

Permalink
Robustify the timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
Zinoex committed Sep 27, 2024
1 parent 81202e2 commit f4ec370
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 32 deletions.
94 changes: 64 additions & 30 deletions examples/compare_imdp_approaches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ function benchmark_direct(problem::ComparisonProblem)

return Dict(
"oom" => false,
"timeout" => true,
"abstraction_time" => NaN,
"certification_time" => NaN,
"prob_mem" => NaN,
Expand All @@ -119,6 +120,7 @@ function benchmark_direct(problem::ComparisonProblem)

return Dict(
"oom" => true,
"timeout" => false,
"abstraction_time" => NaN,
"certification_time" => NaN,
"prob_mem" => NaN,
Expand All @@ -140,10 +142,10 @@ function benchmark_direct(problem::ComparisonProblem)
if reach == ""
reach = []
else
reach_lines = split(chomp(reach), '\n')
reach = map(line -> parse(Int32, line), reach_lines) # Parse each line as an integer
reach = reach .- 1 # Subtract 1 to match 1-based indexing without the avoid state
reach = map(x -> cartesian_indices[x], reach) # Convert from a linear index to a CartesianIndex
reach_lines = split(chomp(reach), '\n')
reach = map(line -> parse(Int32, line), reach_lines) # Parse each line as an integer
reach = reach .- 1 # Subtract 1 to match 1-based indexing without the avoid state
reach = map(x -> cartesian_indices[x], reach) # Convert from a linear index to a CartesianIndex
end

# Read avoid states
Expand All @@ -163,6 +165,7 @@ function benchmark_direct(problem::ComparisonProblem)

return Dict(
"oom" => false,
"timeout" => false,
"abstraction_time" => abstraction_time,
"certification_time" => certification_time,
"prob_mem" => prob_mem,
Expand All @@ -174,6 +177,7 @@ function benchmark_direct(problem::ComparisonProblem)

return Dict(
"oom" => true,
"timeout" => false,
"abstraction_time" => NaN,
"certification_time" => NaN,
"prob_mem" => NaN,
Expand All @@ -196,6 +200,7 @@ function benchmark_decoupled(problem::ComparisonProblem)

return Dict(
"oom" => false,
"timeout" => true,
"abstraction_time" => NaN,
"certification_time" => NaN,
"prob_mem" => NaN,
Expand Down Expand Up @@ -230,14 +235,14 @@ function benchmark_decoupled(problem::ComparisonProblem)
if reach == ""
reach = []
else
reach_lines = split(chomp(reach), '\n')
reach = map(reach_lines) do line # Parse each line as a tuple of indices
indices = split(line[2:end - 1], ",")
reach_lines = split(chomp(reach), '\n')
reach = map(reach_lines) do line # Parse each line as a tuple of indices
indices = split(line[2:end - 1], ",")
println(indices)
indices = map(index -> parse(Int32, index), indices)
return Tuple(indices)
end
reach = map(x -> CartesianIndex(x .- 1), reach) # Subtract 1 to match 1-based indexing without the avoid states
indices = map(index -> parse(Int32, index), indices)
return Tuple(indices)
end
reach = map(x -> CartesianIndex(x .- 1), reach) # Subtract 1 to match 1-based indexing without the avoid states
end

# Read avoid states
Expand All @@ -260,6 +265,7 @@ function benchmark_decoupled(problem::ComparisonProblem)

return Dict(
"oom" => false,
"timeout" => false,
"abstraction_time" => abstraction_time,
"certification_time" => certification_time,
"prob_mem" => prob_mem,
Expand All @@ -271,6 +277,7 @@ function benchmark_decoupled(problem::ComparisonProblem)

return Dict(
"oom" => true,
"timeout" => false,
"abstraction_time" => NaN,
"certification_time" => NaN,
"prob_mem" => NaN,
Expand Down Expand Up @@ -301,7 +308,8 @@ function benchmark()
"input_split" => problem.input_split,
"direct" => direct,
"decoupled" => decoupled,
"impact" => impact
"impact" => impact,
"include_impact" => problem.include_impact
)

save_results(problem.name, res)
Expand Down Expand Up @@ -346,15 +354,28 @@ function to_dataframe(res)
"decoupled_max_prob" => maximum(data["decoupled"]["value_function"]),
)

if !data["direct"]["oom"] && n_decoupled == length(data["direct"]["value_function"])
row["direct_abstraction_time"] = data["direct"]["abstraction_time"]
row["direct_certification_time"] = data["direct"]["certification_time"]
row["direct_prob_mem"] = data["direct"]["prob_mem"]
row["direct_min_prob"] = minimum(data["direct"]["value_function"])
row["direct_max_prob"] = maximum(data["direct"]["value_function"])
row["direct_min_prob_diff"] = minimum(data["decoupled"]["value_function"] - data["direct"]["value_function"])
row["direct_max_prob_diff"] = maximum(data["decoupled"]["value_function"] - data["direct"]["value_function"])
row["direct_avg_prob_diff"] = mean(data["decoupled"]["value_function"] - data["direct"]["value_function"])
if !data["direct"]["oom"] && !data["direct"]["timeout"]
if n_decoupled != length(data["direct"]["value_function"])
@warn "Direct value function length mismatch" name n_decoupled n_direct=length(data["direct"]["value_function"])

row["direct_abstraction_time"] = NaN
row["direct_certification_time"] = NaN
row["direct_prob_mem"] = NaN
row["direct_min_prob"] = NaN
row["direct_max_prob"] = NaN
row["direct_min_prob_diff"] = NaN
row["direct_max_prob_diff"] = NaN
row["direct_avg_prob_diff"] = NaN
else
row["direct_abstraction_time"] = data["direct"]["abstraction_time"]
row["direct_certification_time"] = data["direct"]["certification_time"]
row["direct_prob_mem"] = data["direct"]["prob_mem"]
row["direct_min_prob"] = minimum(data["direct"]["value_function"])
row["direct_max_prob"] = maximum(data["direct"]["value_function"])
row["direct_min_prob_diff"] = minimum(data["decoupled"]["value_function"] - data["direct"]["value_function"])
row["direct_max_prob_diff"] = maximum(data["decoupled"]["value_function"] - data["direct"]["value_function"])
row["direct_avg_prob_diff"] = mean(data["decoupled"]["value_function"] - data["direct"]["value_function"])
end
else
row["direct_abstraction_time"] = NaN
row["direct_certification_time"] = NaN
Expand All @@ -366,15 +387,28 @@ function to_dataframe(res)
row["direct_avg_prob_diff"] = NaN
end

if !data["impact"]["oom"] && n_decoupled == length(data["impact"]["value_function"])
row["impact_abstraction_time"] = data["impact"]["abstraction_time"]
row["impact_certification_time"] = data["impact"]["certification_time"]
row["impact_prob_mem"] = data["impact"]["prob_mem"]
row["impact_min_prob"] = minimum(data["impact"]["value_function"])
row["impact_max_prob"] = maximum(data["impact"]["value_function"])
row["impact_min_prob_diff"] = minimum(data["decoupled"]["value_function"] - data["impact"]["value_function"])
row["impact_max_prob_diff"] = maximum(data["decoupled"]["value_function"] - data["impact"]["value_function"])
row["impact_avg_prob_diff"] = mean(data["decoupled"]["value_function"] - data["impact"]["value_function"])
if data["include_impact"] && !data["impact"]["oom"] && !data["impact"]["timeout"]
if n_decoupled != length(data["impact"]["value_function"])
@warn "Impact value function length mismatch" name n_decoupled n_impact=length(data["impact"]["value_function"])

row["impact_abstraction_time"] = NaN
row["impact_certification_time"] = NaN
row["impact_prob_mem"] = NaN
row["impact_min_prob"] = NaN
row["impact_max_prob"] = NaN
row["impact_min_prob_diff"] = NaN
row["impact_max_prob_diff"] = NaN
row["impact_avg_prob_diff"] = NaN
else
row["impact_abstraction_time"] = data["impact"]["abstraction_time"]
row["impact_certification_time"] = data["impact"]["certification_time"]
row["impact_prob_mem"] = data["impact"]["prob_mem"]
row["impact_min_prob"] = minimum(data["impact"]["value_function"])
row["impact_max_prob"] = maximum(data["impact"]["value_function"])
row["impact_min_prob_diff"] = minimum(data["decoupled"]["value_function"] - data["impact"]["value_function"])
row["impact_max_prob_diff"] = maximum(data["decoupled"]["value_function"] - data["impact"]["value_function"])
row["impact_avg_prob_diff"] = mean(data["decoupled"]["value_function"] - data["impact"]["value_function"])
end
else
row["impact_abstraction_time"] = NaN
row["impact_certification_time"] = NaN
Expand Down
4 changes: 4 additions & 0 deletions examples/systems/IMPaCT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ function run_impact(name; lower_bound=true, container=:apptainer)

return Dict(
"oom" => false,
"timeout" => true,
"abstraction_time" => NaN,
"certification_time" => NaN,
"prob_mem" => NaN,
Expand All @@ -29,6 +30,7 @@ function run_impact(name; lower_bound=true, container=:apptainer)
if !occursin("Finding control policy", output)
return Dict(
"oom" => true,
"timeout" => false,
"abstraction_time" => NaN,
"certification_time" => NaN,
"peak_mem" => NaN,
Expand Down Expand Up @@ -81,6 +83,7 @@ function run_impact(name; lower_bound=true, container=:apptainer)

return Dict(
"oom" => false,
"timeout" => false,
"abstraction_time" => abstraction_time,
"certification_time" => certification_time,
"prob_mem" => mem,
Expand All @@ -92,6 +95,7 @@ function run_impact(name; lower_bound=true, container=:apptainer)

return Dict(
"oom" => true,
"timeout" => false,
"abstraction_time" => NaN,
"certification_time" => NaN,
"peak_mem" => NaN,
Expand Down
129 changes: 128 additions & 1 deletion examples/systems/nndm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ function cartpole_sys()
return sys
end


function cartpole_decoupled(; sparse=false)
sys = cartpole_sys()

Expand All @@ -70,5 +69,133 @@ function cartpole_decoupled(; sparse=false)

mdp, reach, avoid = abstraction(sys, state_abs, input_abs, target_model)

return mdp, reach, avoid
end

function cartpole_direct(; sparse=false)
sys = cartpole_sys()

X = Hyperrectangle(; low=[-1.0, -0.5, deg2rad(-12.0), -0.5], high=[1.0, 0.5, deg2rad(12.0), 0.5])
state_split = (10, 4, 24, 4)
state_abs = StateUniformGridSplit(X, state_split)

input_abs = InputDiscrete([Universe(0)])

if sparse
target_model = SparseIMDPTarget()
else
target_model = IMDPTarget()
end

mdp, reach, avoid = abstraction(sys, state_abs, input_abs, target_model)

return mdp, reach, avoid
end

function husky4d_sys()
pwa_dyn = load_system("husky4d", 4800)
w = AdditiveDiagonalGaussianNoise([0.01, 0.01, 0.01, 0.01])
dyn = UncertainPWAAdditiveNoiseDynamics(4, 0, pwa_dyn, w)

initial = EmptySet(4)
reach = EmptySet(4)
avoid = EmptySet(4)

sys = System(dyn, initial, reach, avoid)

return sys
end

function husky4d_sys_decoupled(; sparse=false)
sys = husky4d_sys()

X = Hyperrectangle(; low=[-0.5, -1.0, deg2rad(-15.0), -0.5], high=[2.0, 1.0, deg2rad(15.0), 0.5])
state_split = (10, 8, 15, 4)
state_abs = StateUniformGridSplit(X, state_split)

input_abs = InputDiscrete([Universe(0)])

if sparse
target_model = SparseOrthogonalIMDPTarget()
else
target_model = OrthogonalIMDPTarget()
end

mdp, reach, avoid = abstraction(sys, state_abs, input_abs, target_model)

return mdp, reach, avoid
end

function husky4d_sys_direct(; sparse=false)
sys = husky4d_sys()

X = Hyperrectangle(; low=[-0.5, -1.0, deg2rad(-15.0), -0.5], high=[2.0, 1.0, deg2rad(15.0), 0.5])
state_split = (10, 8, 15, 4)
state_abs = StateUniformGridSplit(X, state_split)

input_abs = InputDiscrete([Universe(0)])

if sparse
target_model = SparseIMDPTarget()
else
target_model = IMDPTarget()
end

mdp, reach, avoid = abstraction(sys, state_abs, input_abs, target_model)

return mdp, reach, avoid
end

function husky5d_sys()
pwa_dyn = load_system("husky5d", 1728)
w = AdditiveDiagonalGaussianNoise([0.01, 0.01, 0.01, 0.01, 0.01])
dyn = UncertainPWAAdditiveNoiseDynamics(5, 0, pwa_dyn, w)

initial = EmptySet(5)
reach = EmptySet(5)
avoid = EmptySet(5)

sys = System(dyn, initial, reach, avoid)

return sys
end

function husky5d_sys_decoupled(; sparse=false)
sys = husky5d_sys()

X = Hyperrectangle(; low=[-0.5, -0.5, deg2rad(-10.0), -0.5, -0.5], high=[1.9, 0.3, deg2rad(8.0), 0.5, 0.5])
state_split = (6, 2, 9, 4, 4)
state_abs = StateUniformGridSplit(X, state_split)

input_abs = InputDiscrete([Universe(0)])

if sparse
target_model = SparseOrthogonalIMDPTarget()
else
target_model = OrthogonalIMDPTarget()
end

mdp, reach, avoid = abstraction(sys, state_abs, input_abs, target_model)

return mdp, reach, avoid
end

function husky5d_sys_direct(; sparse=false)
sys = husky5d_sys()

X = Hyperrectangle(; low=[-0.5, -0.5, deg2rad(-10.0), -0.5, -0.5], high=[1.9, 0.3, deg2rad(8.0), 0.5, 0.5])
state_split = (6, 2, 9, 4, 4)
state_abs = StateUniformGridSplit(X, state_split)

input_abs = InputDiscrete([Universe(0)])

if sparse
target_model = SparseIMDPTarget()
else
target_model = IMDPTarget()
end

mdp, reach, avoid = abstraction(sys, state_abs, input_abs, target_model)

return mdp, reach, avoid
end
1 change: 1 addition & 0 deletions examples/systems/systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ include("bas_7d.jl")

include("integrator_chain.jl")
include("van_der_pol.jl")
include("nndm.jl")

include("IMPaCT.jl")
3 changes: 2 additions & 1 deletion src/dynamics/UncertainPWAAdditiveNoiseDynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,5 @@ function nominal(dyn::UncertainPWAAdditiveNoiseDynamics, x::AbstractVector, u::A
end

throw(ArgumentError("The state is not in the domain of the dynamics"))
end
end
prepare_nominal(::UncertainPWAAdditiveNoiseDynamics, input_abstraction) = nothing

0 comments on commit f4ec370

Please sign in to comment.