Skip to content

Commit

Permalink
Update decision tree (#386)
Browse files Browse the repository at this point in the history
* add <complicated> node to DecisionTree

* change <fever> to <uncomplicated>

* add <cohort> and rename <complicated> to <severe>

* add latest Epiosde::State to DecisionTree CaseType when outside HS

* add uncomplicated optional 'memory' parameter, similar to 'healthSystemMemory'
  • Loading branch information
acavelan authored Jun 20, 2024
1 parent 87bed56 commit ef5b855
Show file tree
Hide file tree
Showing 12 changed files with 245 additions and 71 deletions.
5 changes: 1 addition & 4 deletions model/Clinical/CM5DayCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,13 @@ void CM5DayCommon::init(){
// ——— per-human, construction and destruction ———

CM5DayCommon::CM5DayCommon (double tSF) :
m_tLastTreatment (sim::never()),
m_treatmentSeekingFactor (tSF)
{}


// ——— per-human, update ———

void CM5DayCommon::doClinicalUpdate (Human& human, double ageYears) {
const bool isDoomed = doomed != NOT_DOOMED;
WithinHost::Pathogenesis::StatePair pg = human.withinHostModel->determineMorbidity( human, ageYears, isDoomed );
void CM5DayCommon::doClinicalUpdate (Human& human, double ageYears, WithinHost::Pathogenesis::StatePair &pg) {
Episode::State pgState = static_cast<Episode::State>( pg.state );

if (pgState & Episode::MALARIA) {
Expand Down
5 changes: 1 addition & 4 deletions model/Clinical/CM5DayCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,13 @@ class CM5DayCommon : public ClinicalModel
static double cureRateSevere;
static WithinHost::TreatmentId treatmentSevere;

virtual void doClinicalUpdate (Human& human, double ageYears);
virtual void doClinicalUpdate (Human& human, double ageYears, WithinHost::Pathogenesis::StatePair &pg);

virtual void checkpoint (istream& stream);
virtual void checkpoint (ostream& stream);

/** Called when a non-severe/complicated malaria sickness occurs. */
virtual void uncomplicatedEvent(Human& human, Episode::State pgState) =0;

/** Time of the last treatment (sim::never() if never treated). */
SimTime m_tLastTreatment = sim::never();

//! treatment seeking for heterogeneity
double m_treatmentSeekingFactor;
Expand Down
170 changes: 136 additions & 34 deletions model/Clinical/CMDecisionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "util/errors.h"
#include "util/UnitParse.h"
#include "interventions/Interfaces.hpp"
#include "interventions/InterventionManager.hpp"

#include <limits>
#include <list>
Expand Down Expand Up @@ -139,12 +140,13 @@ class CMDTDiagnostic : public CMDecisionTree {

virtual CMDTOut exec( CMHostData hostData ) const{
CMDTOut result;
if( hostData.withinHost().diagnosticResult( hostData.human.rng, diagnostic ) ){

if( hostData.withinHost().diagnosticResult( hostData.human.rng, diagnostic ) )
result = positive.exec( hostData );
}else{
else
result = negative.exec( hostData );
}
result.screened = true;

return result;
}

Expand All @@ -161,14 +163,14 @@ class CMDTDiagnostic : public CMDecisionTree {
const CMDecisionTree& negative;
};

class CMDTFever : public CMDecisionTree {
class CMDTUncomplicated : public CMDecisionTree {
public:
static const CMDecisionTree& create( const ::scnXml::DTFever& node, bool isUC );
static const CMDecisionTree& create( const ::scnXml::DTUncomplicated& node, bool isUC );

protected:
virtual bool operator==( const CMDecisionTree& that ) const{
if( this == &that ) return true; // short cut: same object thus equivalent
const CMDTFever* p = dynamic_cast<const CMDTFever*>( &that );
const CMDTUncomplicated* p = dynamic_cast<const CMDTUncomplicated*>( &that );
if( p == 0 ) return false; // different type of node
// if( diagnostic != p->diagnostic ) return false;
if( positive != p->positive ) return false;
Expand All @@ -177,35 +179,69 @@ class CMDTFever : public CMDecisionTree {
}

virtual CMDTOut exec( CMHostData hostData ) const{
CMDTOut result;

const Clinical::Episode &latest = hostData.human.clinicalModel->getLatestReport();

// Rely on the health system memory to not count the same episode twice
if( (latest.time + lookBack > sim::now() - sim::oneTS()) && ((latest.state & Episode::MALARIA ) || (latest.state & Episode::SICK)) )
if ( ((hostData.pgState & Episode::SICK) && !(hostData.pgState & Episode::COMPLICATED)) || (hostData.pgState & Episode::MALARIA) )
{
result = positive.exec( hostData );
if (latest.time + memory >= sim::ts0())
return positive.exec( hostData );
else
return negative.exec( hostData );
}
else{
else
return negative.exec( hostData );
}

private:
CMDTUncomplicated(
const SimTime memory,
const CMDecisionTree& positive,
const CMDecisionTree& negative ) :
memory(memory), positive(positive), negative(negative)
{
if(memory > ClinicalModel::hsMemory())
throw util::xml_scenario_error( "<uncomplicated> memory parameter must be less than or equal to the healthsystem memory (hsmemory parameter)" );
}

const SimTime memory;
const CMDecisionTree& positive;
const CMDecisionTree& negative;
};

class CMDTSevere : public CMDecisionTree {
public:
static const CMDecisionTree& create( const ::scnXml::DTSevere& node, bool isUC );

protected:
virtual bool operator==( const CMDecisionTree& that ) const{
if( this == &that ) return true; // short cut: same object thus equivalent
const CMDTSevere* p = dynamic_cast<const CMDTSevere*>( &that );
if( p == 0 ) return false; // different type of node
// if( diagnostic != p->diagnostic ) return false;
if( positive != p->positive ) return false;
if( negative != p->negative ) return false;
return true; // no tests failed; must be the same
}

virtual CMDTOut exec( CMHostData hostData ) const{
CMDTOut result;

if (hostData.pgState & Episode::COMPLICATED)
result = positive.exec( hostData );
else
result = negative.exec( hostData );
}

result.screened = true;
// result.screened = true;
return result;
}

private:
CMDTFever(
const SimTime lookBack,
CMDTSevere(
const CMDecisionTree& positive,
const CMDecisionTree& negative ) :
lookBack(lookBack), positive(positive), negative(negative)
{
if(lookBack > ClinicalModel::hsMemory())
throw util::xml_scenario_error( "<fever> lookBack parameter must be less than or equal to the healthsystem memory (hsmemory parameter)" );
}
positive(positive), negative(negative)
{}

const SimTime lookBack;
const CMDecisionTree& positive;
const CMDecisionTree& negative;
};
Expand Down Expand Up @@ -498,6 +534,47 @@ class CMDTDeploy : public CMDecisionTree, interventions::HumanIntervention {
}
};

class CMDTCohort : public CMDecisionTree {
public:
static const CMDecisionTree& create( const ::scnXml::DTCohort& node, bool isUC );

protected:
virtual bool operator==( const CMDecisionTree& that ) const{
if( this == &that ) return true; // short cut: same object thus equivalent
const CMDTCohort* p = dynamic_cast<const CMDTCohort*>( &that );
if( p == 0 ) return false; // different type of node
// if( diagnostic != p->diagnostic ) return false;
if( positive != p->positive ) return false;
if( negative != p->negative ) return false;
return true; // no tests failed; must be the same
}

virtual CMDTOut exec( CMHostData hostData ) const{
CMDTOut result;

// Rely on the health system memory to not count the same episode twice
if(hostData.human.isInSubPop(component))
result = positive.exec( hostData );
else
result = negative.exec( hostData );

// result.screened = true;
return result;
}

private:
CMDTCohort(
const std::string component,
const CMDecisionTree& positive,
const CMDecisionTree& negative ) :
component(interventions::InterventionManager::getComponentId(component)), positive(positive), negative(negative)
{
}

interventions::ComponentId component;
const CMDecisionTree& positive;
const CMDecisionTree& negative;
};

// ——— static functions ———

Expand All @@ -522,14 +599,17 @@ const CMDecisionTree& save_decision( CMDecisionTree* decision ){
return *decision_library.back();
}

const CMDecisionTree& CMDecisionTree::create( const scnXml::DecisionTree& node, bool isUC ){
if( node.getMultiple().present() ) return CMDTMultiple::create( node.getMultiple().get(), isUC );
const CMDecisionTree& CMDecisionTree::create( const scnXml::DecisionTree& node, bool isUC){
if( node.getMultiple().present() ) return CMDTMultiple::create( node.getMultiple().get(), isUC);
// branching nodes
if( node.getCaseType().present() ) return CMDTCaseType::create( node.getCaseType().get(), isUC );
if( node.getDiagnostic().present() ) return CMDTDiagnostic::create( node.getDiagnostic().get(), isUC );
if( node.getFever().present() ) return CMDTFever::create( node.getFever().get(), isUC );
if( node.getRandom().present() ) return CMDTRandom::create( node.getRandom().get(), isUC );
if( node.getAge().present() ) return CMDTAge::create( node.getAge().get(), isUC );
if( node.getCaseType().present() ) return CMDTCaseType::create( node.getCaseType().get(), isUC);
if( node.getDiagnostic().present() ) return CMDTDiagnostic::create( node.getDiagnostic().get(), isUC);
if( node.getUncomplicated().present() ) return CMDTUncomplicated::create( node.getUncomplicated().get(), isUC);
if( node.getSevere().present() ) return CMDTSevere::create( node.getSevere().get(), true);
if( node.getRandom().present() ) return CMDTRandom::create( node.getRandom().get(), isUC);
if( node.getAge().present() ) return CMDTAge::create( node.getAge().get(), isUC);
if( node.getCohort().present() ) return CMDTCohort::create( node.getCohort().get(), isUC);

// action nodes
if( node.getNoTreatment().present() ) return save_decision( new CMDTNoTreatment() );
if( node.getReport().size() ) return save_decision( new CMDTReport(node.getReport()) );
Expand All @@ -551,15 +631,21 @@ const CMDecisionTree& CMDTMultiple::create( const scnXml::DTMultiple& node, bool
for( const scnXml::DTDiagnostic& sn : node.getDiagnostic() ){
self->children.push_back( &CMDTDiagnostic::create(sn, isUC) );
}
for( const scnXml::DTFever& sn : node.getFever() ){
self->children.push_back( &CMDTFever::create(sn, isUC) );
for( const scnXml::DTUncomplicated& sn : node.getUncomplicated() ){
self->children.push_back( &CMDTUncomplicated::create(sn, isUC) );
}
for( const scnXml::DTSevere& sn : node.getSevere() ){
self->children.push_back( &CMDTSevere::create(sn, isUC) );
}
for( const scnXml::DTRandom& sn : node.getRandom() ){
self->children.push_back( &CMDTRandom::create(sn, isUC) );
}
for( const scnXml::DTAge& sn : node.getAge() ){
self->children.push_back( &CMDTAge::create(sn, isUC) );
}
for( const scnXml::DTCohort& sn : node.getCohort() ){
self->children.push_back( &CMDTCohort::create(sn, isUC) );
}
if( node.getTreatPKPD().size() ){
self->children.push_back( &save_decision(new CMDTTreatPKPD(node.getTreatPKPD())) );
}
Expand Down Expand Up @@ -593,14 +679,30 @@ const CMDecisionTree& CMDTDiagnostic::create( const scnXml::DTDiagnostic& node,
) );
}

const CMDecisionTree& CMDTFever::create( const scnXml::DTFever& node, bool isUC ){
return save_decision( new CMDTFever(
UnitParse::readShortDuration(node.getLookBack(), UnitParse::STEPS),
const CMDecisionTree& CMDTUncomplicated::create( const scnXml::DTUncomplicated& node, bool isUC ){
return save_decision( new CMDTUncomplicated(
UnitParse::readShortDuration(node.getMemory(), UnitParse::STEPS),
CMDecisionTree::create( node.getPositive(), isUC ),
CMDecisionTree::create( node.getNegative(), isUC )
) );
}

const CMDecisionTree& CMDTSevere::create( const scnXml::DTSevere& node, bool isUC ){
return save_decision( new CMDTSevere(
CMDecisionTree::create( node.getPositive(), isUC ),
CMDecisionTree::create( node.getNegative(), isUC )
) );
}

const CMDecisionTree& CMDTCohort::create( const scnXml::DTCohort& node, bool isUC ){
return save_decision( new CMDTCohort(
node.getComponent(),
CMDecisionTree::create( node.getPositive(), isUC ),
CMDecisionTree::create( node.getNegative(), isUC )
) );
}


const CMDecisionTree& CMDTRandom::create(
const scnXml::DTRandom& node, bool isUC )
{
Expand Down
6 changes: 5 additions & 1 deletion model/Clinical/ClinicalModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,11 @@ void ClinicalModel::update (Human& human, double ageYears, bool newBorn) {
}
}

doClinicalUpdate (human, ageYears);
const bool isDoomed = doomed != NOT_DOOMED;
WithinHost::Pathogenesis::StatePair pg = human.withinHostModel->determineMorbidity( human, ageYears, isDoomed );
latestState = static_cast<Episode::State>(pg.state);

doClinicalUpdate (human, ageYears, pg);
}

void ClinicalModel::updateInfantDeaths( SimTime age ){
Expand Down
11 changes: 9 additions & 2 deletions model/Clinical/ClinicalModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,18 @@ class ClinicalModel
return latestReport;
}

inline const Episode::State &getLatestState () const{
return latestState;
}

/// Checkpointing
template<class S>
void operator& (S& stream) {
checkpoint (stream);
}

/** Time of the last treatment (sim::never() if never treated). */
SimTime m_tLastTreatment = sim::never();

protected:
/// Constructor.
Expand All @@ -123,14 +129,15 @@ class ClinicalModel
* @param hostTransmission per-host transmission data of human.
* @param ageYears Age of human.
* @param ageGroup Survey age group of human. */
virtual void doClinicalUpdate (Human& human, double ageYears) =0;
virtual void doClinicalUpdate (Human& human, double ageYears, WithinHost::Pathogenesis::StatePair &pg) =0;

virtual void checkpoint (istream& stream);
virtual void checkpoint (ostream& stream);

/** Last episode; report to survey pending a new episode or human's death. */
Episode latestReport;

Episode::State latestState;

/** @brief Positive values of _doomed variable (codes). */
enum {
DOOMED_EXPIRED = -35, // codes less than or equal to this mean "dead now"
Expand Down
8 changes: 4 additions & 4 deletions model/Clinical/DecisionTree5Day.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,16 @@ void DecisionTree5Day::setHealthSystem(const scnXml::HSDT5Day& hsDescription){

// ——— per-human, update ———

void DecisionTree5Day::uncomplicatedEvent ( Human& human, Episode::State pgState ){
latestReport.update (human, Episode::State( pgState ) );

void DecisionTree5Day::uncomplicatedEvent ( Human& human, Episode::State pgState ){
// If last treatment prescribed was in recent memory, consider second line.
CaseType regimen = FirstLine;
if (m_tLastTreatment + healthSystemMemory > sim::ts0()){
pgState = Episode::State (pgState | Episode::SECOND_CASE);
regimen = SecondLine;
}


latestReport.update (human, Episode::State( pgState ) );

double x = human.rng.uniform_01();
if( x < accessUCAny[regimen] * m_treatmentSeekingFactor ){
CMHostData hostData( human, sim::inYears(human.age(sim::ts0())), pgState );
Expand Down
6 changes: 1 addition & 5 deletions model/Clinical/EventScheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,8 @@ bool ClinicalEventScheduler::isExistingCase() {
return sim::now() > timeLastTreatment && sim::now() <= timeLastTreatment + healthSystemMemory;
}

void ClinicalEventScheduler::doClinicalUpdate (Human& human, double ageYears){
void ClinicalEventScheduler::doClinicalUpdate (Human& human, double ageYears, WithinHost::Pathogenesis::StatePair &pg){
WHInterface& withinHostModel = *human.withinHostModel;
// Run pathogenesisModel
// Note: we use Episode::COMPLICATED instead of Episode::SEVERE.
const bool isDoomed = doomed != NOT_DOOMED;
WithinHost::Pathogenesis::StatePair pg = human.withinHostModel->determineMorbidity( human, ageYears, isDoomed );
Episode::State newState = static_cast<Episode::State>( pg.state );
util::streamValidate( (newState << 16) & pgState );

Expand Down
2 changes: 1 addition & 1 deletion model/Clinical/EventScheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class ClinicalEventScheduler : public ClinicalModel
virtual bool isExistingCase();

protected:
virtual void doClinicalUpdate (Human& human, double ageYears);
virtual void doClinicalUpdate (Human& human, double ageYears, WithinHost::Pathogenesis::StatePair &pg);

virtual void checkpoint (istream& stream);
virtual void checkpoint (ostream& stream);
Expand Down
Loading

0 comments on commit ef5b855

Please sign in to comment.