Skip to content

Commit

Permalink
Add support for NPOD
Browse files Browse the repository at this point in the history
  • Loading branch information
Siel committed Nov 27, 2023
1 parent 62cb306 commit 1800d4f
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 32 deletions.
2 changes: 1 addition & 1 deletion examples/bimodal_ke/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ log_out = "log/bimodal_ke.log"

[config]
cycles = 1024
engine = "NPAG"
engine = "NPOD"
init_points = 2129
seed = 347
tui = true
Expand Down
3 changes: 1 addition & 2 deletions src/algorithms/npag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@ use crate::{
datafile::Scenario,
evaluation::sigma::{ErrorPoly, ErrorType},
ipm,
optimization::expansion::adaptative_grid,
output::NPResult,
output::{CycleLog, NPCycle},
prob, qr,
settings::run::Data,
simulation::predict::Engine,
simulation::predict::{sim_obs, Predict},
},
tui::ui::Comm,
tui::ui::Comm, routines::expansion::adaptative_grid::adaptative_grid,

};

Expand Down
37 changes: 16 additions & 21 deletions src/algorithms/npod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::prelude::{
use crate::{prelude::{
algorithms::Algorithm,
condensation::prune::prune,
datafile::Scenario,
Expand All @@ -11,7 +11,7 @@ use crate::prelude::{
settings::run::Data,
simulation::predict::Engine,
simulation::predict::{sim_obs, Predict},
};
}, tui::ui::Comm};
use ndarray::parallel::prelude::*;
use ndarray::{Array, Array1, Array2, Axis};
use ndarray_stats::{DeviationExt, QuantileExt};
Expand All @@ -22,7 +22,7 @@ const THETA_F: f64 = 1e-2;

pub struct NPOD<S>
where
S: Predict + std::marker::Sync + Clone,
S: Predict<'static> + std::marker::Sync + Clone,
{
engine: Engine<S>,
ranges: Vec<(f64, f64)>,
Expand All @@ -41,13 +41,13 @@ where
cache: bool,
scenarios: Vec<Scenario>,
c: (f64, f64, f64, f64),
tx: UnboundedSender<NPCycle>,
tx: UnboundedSender<Comm>,
settings: Data,
}

impl<S> Algorithm for NPOD<S>
where
S: Predict + std::marker::Sync + Clone,
S: Predict<'static> + std::marker::Sync + Clone,
{
fn fit(&mut self) -> NPResult {
self.run()
Expand All @@ -68,7 +68,7 @@ where

impl<S> NPOD<S>
where
S: Predict + std::marker::Sync + Clone,
S: Predict<'static> + std::marker::Sync + Clone,
{
/// Creates a new NPOD instance.
///
Expand All @@ -91,11 +91,11 @@ where
theta: Array2<f64>,
scenarios: Vec<Scenario>,
c: (f64, f64, f64, f64),
tx: UnboundedSender<NPCycle>,
tx: UnboundedSender<Comm>,
settings: Data,
) -> Self
where
S: Predict + std::marker::Sync,
S: Predict<'static> + std::marker::Sync,
{
Self {
engine: sim_eng,
Expand Down Expand Up @@ -230,7 +230,7 @@ where
keep.push(*perm.get(i).unwrap());
}
}
log::info!(
tracing::info!(
"QR decomp, cycle {}, kept: {}, thrown {}",
self.cycle,
keep.len(),
Expand All @@ -249,21 +249,20 @@ where

self.optim_gamma();

let mut state = NPCycle {
let state = NPCycle {
cycle: self.cycle,
objf: -2. * self.objf,
delta_objf: (self.last_objf - self.objf).abs(),
nspp: self.theta.shape()[0],
stop_text: "".to_string(),
theta: self.theta.clone(),
gamlam: self.gamma,
};
self.tx.send(state.clone()).unwrap();
self.tx.send(Comm::NPCycle(state.clone())).unwrap();

// If the objective function decreased, log an error.
// Increasing objf signals instability of model misspecification.
if self.last_objf > self.objf {
log::error!("Objective function decreased");
tracing::error!("Objective function decreased");
}

self.w = self.lambda.clone();
Expand Down Expand Up @@ -293,19 +292,15 @@ where
prune(&mut self.theta, cp, &self.ranges, THETA_D);
}

// Stop if we have reached maximum number of cycles
if self.cycle >= self.settings.parsed.config.cycles {
log::info!("Maximum number of cycles reached");
state.stop_text = "No (max cycle)".to_string();
self.tx.send(state).unwrap();
// Stop if we have reached maximum number of cycles
if self.cycle >= self.settings.parsed.config.cycles {
tracing::warn!("Maximum number of cycles reached");
break;
}

// Stop if stopfile exists
if std::path::Path::new("stop").exists() {
log::info!("Stopfile detected - breaking");
state.stop_text = "No (stopped)".to_string();
self.tx.send(state).unwrap();
tracing::warn!("Stopfile detected - breaking");
break;
}
//TODO: the cycle migh break before reaching this point
Expand Down
11 changes: 5 additions & 6 deletions src/routines/optimization/d_optimizer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use argmin::{
core::{
observers::{ObserverMode, SlogLogger},
CostFunction, Error, Executor,
},
solver::neldermead::NelderMead,
Expand All @@ -17,7 +16,7 @@ use crate::prelude::{prob, sigma::Sigma};
pub struct SppOptimizer<'a, S, P>
where
S: Sigma + Sync,
P: Predict + Sync + Clone,
P: Predict<'static> + Sync + Clone,
{
engine: &'a Engine<P>,
scenarios: &'a Vec<Scenario>,
Expand All @@ -28,7 +27,7 @@ where
impl<'a, S, P> CostFunction for SppOptimizer<'a, S, P>
where
S: Sigma + Sync,
P: Predict + Sync + Clone,
P: Predict<'static> + Sync + Clone,
{
type Param = Array1<f64>;
type Output = f64;
Expand All @@ -37,10 +36,10 @@ where
let ypred = sim_obs(&self.engine, &self.scenarios, &theta, true);
let psi = prob::calculate_psi(&ypred, self.scenarios, self.sig);
if psi.ncols() > 1 {
log::error!("Psi in SppOptimizer has more than one column");
tracing::error!("Psi in SppOptimizer has more than one column");
}
if psi.nrows() != self.pyl.len() {
log::error!(
tracing::error!(
"Psi in SppOptimizer has {} rows, but spp has {}",
psi.nrows(),
self.pyl.len()
Expand All @@ -58,7 +57,7 @@ where
impl<'a, S, P> SppOptimizer<'a, S, P>
where
S: Sigma + Sync,
P: Predict + Sync + Clone,
P: Predict<'static> + Sync + Clone,
{
pub fn new(
engine: &'a Engine<P>,
Expand Down
4 changes: 2 additions & 2 deletions src/routines/simulation/predict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ const CACHE_SIZE: usize = 1000000;
pub struct Model {
params: HashMap<String, f64>,
_scenario: Scenario,
infusions: Vec<Infusion>,
cov: Option<HashMap<String, CovLine>>,
_infusions: Vec<Infusion>,
_cov: Option<HashMap<String, CovLine>>,
}
impl Model {
pub fn get_param(&self, str: &str) -> f64 {
Expand Down

0 comments on commit 1800d4f

Please sign in to comment.