diff --git a/medmodels/treatment_effect/builder.py b/medmodels/treatment_effect/builder.py index d9789f7..27c2191 100644 --- a/medmodels/treatment_effect/builder.py +++ b/medmodels/treatment_effect/builder.py @@ -23,6 +23,9 @@ class TreatmentEffectBuilder: The TreatmentEffectBuilder class is used to build a TreatmentEffect object with the desired configurations for the treatment effect estimation using a builder pattern. + + By default, it configures a static treatment effect estimation. To configure a + time-dependent treatment effect estimation, the time_attribute must be set. """ treatment: Group @@ -97,6 +100,9 @@ def with_time_attribute( ) -> TreatmentEffectBuilder: """Sets the time attribute to be used in the treatment effect estimation. + It turs the treatment effect estimation from a static to a time-dependent + analysis. + Args: attribute (MedRecordAttribute): The time attribute. diff --git a/medmodels/treatment_effect/continuous_estimators.py b/medmodels/treatment_effect/continuous_estimators.py index 66fc0d4..c0df212 100644 --- a/medmodels/treatment_effect/continuous_estimators.py +++ b/medmodels/treatment_effect/continuous_estimators.py @@ -1,15 +1,21 @@ """Functions to estimate the treatment effect for continuous outcomes.""" +from __future__ import annotations + import logging from math import sqrt -from typing import Literal, Set +from typing import TYPE_CHECKING, Literal, Optional, Set import numpy as np -from medmodels.medrecord.medrecord import MedRecord -from medmodels.medrecord.types import Group, MedRecordAttribute, NodeIndex from medmodels.treatment_effect.temporal_analysis import find_reference_edge +if TYPE_CHECKING: + from medmodels.medrecord.medrecord import MedRecord + from medmodels.medrecord.querying import EdgeOperand + from medmodels.medrecord.types import Group, MedRecordAttribute, NodeIndex + + logger = logging.getLogger(__name__) @@ -20,7 +26,7 @@ def average_treatment_effect( outcome_group: Group, outcome_variable: MedRecordAttribute, reference: Literal["first", "last"] = "last", - time_attribute: MedRecordAttribute = "time", + time_attribute: Optional[MedRecordAttribute] = None, ) -> float: r"""Calculates the Average Treatment Effect (ATE). @@ -57,8 +63,11 @@ def average_treatment_effect( exposure time. Options include "first" and "last". If "first", the function returns the earliest exposure edge. If "last", the function returns the latest exposure edge. Defaults to "last". - time_attribute (MedRecordAttribute, optional): The attribute in the edge that - contains the time information. Defaults to "time". + time_attribute (Optional[MedRecordAttribute], optional): The attribute in the + edge that contains the time information. If it is equal to None, there is + no time component in the data and all edges between the sets and the + outcomes are considered for the average treatment effect. Defaults to + None. Returns: float: The average treatment effect. @@ -66,39 +75,64 @@ def average_treatment_effect( Raises: ValueError: If the outcome variable is not numeric. """ - treated_outcomes = np.array( - [ - medrecord.edge[ - find_reference_edge( - medrecord, - node_index, - outcome_group, - time_attribute=time_attribute, - reference=reference, - ) - ][outcome_variable] - for node_index in treatment_outcome_true_set - ] - ) - if not all(isinstance(i, (int, float)) for i in treated_outcomes): - msg = "Outcome variable must be numeric" - raise ValueError(msg) + if time_attribute is not None: + treated_outcomes = np.array( + [ + medrecord.edge[ + find_reference_edge( + medrecord, + node_index, + outcome_group, + time_attribute=time_attribute, + reference=reference, + ) + ][outcome_variable] + for node_index in treatment_outcome_true_set + ] + ) + control_outcomes = np.array( + [ + medrecord.edge[ + find_reference_edge( + medrecord, + node_index, + outcome_group, + time_attribute="time", + reference=reference, + ) + ][outcome_variable] + for node_index in control_outcome_true_set + ] + ) - control_outcomes = np.array( - [ - medrecord.edge[ - find_reference_edge( - medrecord, - node_index, - outcome_group, - time_attribute=time_attribute, - reference=reference, - ) - ][outcome_variable] - for node_index in control_outcome_true_set - ] - ) - if not all(isinstance(i, (int, float)) for i in control_outcomes): + else: + edges_treated_outcomes = medrecord.select_edges( + lambda edge: query_edges_between_set_outcome( + edge, treatment_outcome_true_set, outcome_group + ) + ) + treated_outcomes = np.array( + [ + medrecord.edge[edge_id][outcome_variable] + for edge_id in edges_treated_outcomes + ] + ) + + edges_control_outcomes = medrecord.select_edges( + lambda edge: query_edges_between_set_outcome( + edge, control_outcome_true_set, outcome_group + ) + ) + control_outcomes = np.array( + [ + medrecord.edge[edge_id][outcome_variable] + for edge_id in edges_control_outcomes + ] + ) + + if not all(isinstance(i, (int, float)) for i in treated_outcomes) or not all( + isinstance(i, (int, float)) for i in control_outcomes + ): msg = "Outcome variable must be numeric" raise ValueError(msg) @@ -112,7 +146,7 @@ def cohens_d( outcome_group: Group, outcome_variable: MedRecordAttribute, reference: Literal["first", "last"] = "last", - time_attribute: MedRecordAttribute = "time", + time_attribute: Optional[MedRecordAttribute] = None, ) -> float: """Calculates Cohen's D, the standardized mean difference between two sets. @@ -147,7 +181,7 @@ def cohens_d( returns the earliest exposure edge. If "last", the function returns the latest exposure edge. Defaults to "last". time_attribute (MedRecordAttribute, optional): The attribute in the edge that - contains the time information. Defaults to "time". + contains the time information. Defaults to None. Returns: float: The Cohen's D coefficient, representing the effect size. @@ -158,39 +192,63 @@ def cohens_d( Warning: If the sample size is small (less than 50), the function advises Hedges' g use. """ - treated_outcomes = np.array( - [ - medrecord.edge[ - find_reference_edge( - medrecord, - node_index, - outcome_group, - time_attribute=time_attribute, - reference=reference, - ) - ][outcome_variable] - for node_index in treatment_outcome_true_set - ] - ) - if not all(isinstance(i, (int, float)) for i in treated_outcomes): - msg = "Outcome variable must be numeric" - raise ValueError(msg) + if time_attribute is not None: + treated_outcomes = np.array( + [ + medrecord.edge[ + find_reference_edge( + medrecord, + node_index, + outcome_group, + time_attribute=time_attribute, + reference=reference, + ) + ][outcome_variable] + for node_index in treatment_outcome_true_set + ] + ) + control_outcomes = np.array( + [ + medrecord.edge[ + find_reference_edge( + medrecord, + node_index, + outcome_group, + time_attribute="time", + reference=reference, + ) + ][outcome_variable] + for node_index in control_outcome_true_set + ] + ) + else: + edges_treated_outcomes = medrecord.select_edges( + lambda edge: query_edges_between_set_outcome( + edge, treatment_outcome_true_set, outcome_group + ) + ) + treated_outcomes = np.array( + [ + medrecord.edge[edge_id][outcome_variable] + for edge_id in edges_treated_outcomes + ] + ) - control_outcomes = np.array( - [ - medrecord.edge[ - find_reference_edge( - medrecord, - node_index, - outcome_group, - time_attribute=time_attribute, - reference=reference, - ) - ][outcome_variable] - for node_index in control_outcome_true_set - ] - ) - if not all(isinstance(i, (int, float)) for i in control_outcomes): + edges_control_outcomes = medrecord.select_edges( + lambda edge: query_edges_between_set_outcome( + edge, control_outcome_true_set, outcome_group + ) + ) + control_outcomes = np.array( + [ + medrecord.edge[edge_id][outcome_variable] + for edge_id in edges_control_outcomes + ] + ) + + if not all(isinstance(i, (int, float)) for i in treated_outcomes) or not all( + isinstance(i, (int, float)) for i in control_outcomes + ): msg = "Outcome variable must be numeric" raise ValueError(msg) @@ -213,7 +271,7 @@ def hedges_g( outcome_group: Group, outcome_variable: MedRecordAttribute, reference: Literal["first", "last"] = "last", - time_attribute: MedRecordAttribute = "time", + time_attribute: Optional[MedRecordAttribute] = None, ) -> float: """Calculates Hedges' g, the unbiased effect size estimate. @@ -236,8 +294,8 @@ def hedges_g( exposure time. Options include "first" and "last". If "first", the function returns the earliest exposure edge. If "last", the function returns the latest exposure edge. Defaults to "last". - time_attribute (MedRecordAttribute, optional): The attribute in the edge that - contains the time information. Defaults to "time". + time_attribute (Optional[MedRecordAttribute], optional): The attribute in the + edge that contains the time information. Defaults to None. Returns: float: The Hedges' g coefficient, representing the effect size. @@ -245,39 +303,62 @@ def hedges_g( Raises: ValueError: If the outcome variable is not numeric. """ - treated_outcomes = np.array( - [ - medrecord.edge[ - find_reference_edge( - medrecord, - node_index, - outcome_group, - time_attribute=time_attribute, - reference=reference, - ) - ][outcome_variable] - for node_index in treatment_outcome_true_set - ] - ) - if not all(isinstance(i, (int, float)) for i in treated_outcomes): - msg = "Outcome variable must be numeric" - raise ValueError(msg) + if time_attribute is not None: + treated_outcomes = np.array( + [ + medrecord.edge[ + find_reference_edge( + medrecord, + node_index, + outcome_group, + time_attribute=time_attribute, + reference=reference, + ) + ][outcome_variable] + for node_index in treatment_outcome_true_set + ] + ) + control_outcomes = np.array( + [ + medrecord.edge[ + find_reference_edge( + medrecord, + node_index, + outcome_group, + time_attribute="time", + reference=reference, + ) + ][outcome_variable] + for node_index in control_outcome_true_set + ] + ) + else: + edges_treated_outcomes = medrecord.select_edges( + lambda edge: query_edges_between_set_outcome( + edge, treatment_outcome_true_set, outcome_group + ) + ) + treated_outcomes = np.array( + [ + medrecord.edge[edge_id][outcome_variable] + for edge_id in edges_treated_outcomes + ] + ) - control_outcomes = np.array( - [ - medrecord.edge[ - find_reference_edge( - medrecord, - node_index, - outcome_group, - time_attribute=time_attribute, - reference=reference, - ) - ][outcome_variable] - for node_index in control_outcome_true_set - ] - ) - if not all(isinstance(i, (int, float)) for i in control_outcomes): + edges_control_outcomes = medrecord.select_edges( + lambda edge: query_edges_between_set_outcome( + edge, control_outcome_true_set, outcome_group + ) + ) + control_outcomes = np.array( + [ + medrecord.edge[edge_id][outcome_variable] + for edge_id in edges_control_outcomes + ] + ) + if not all(isinstance(i, (int, float)) for i in treated_outcomes) or not all( + isinstance(i, (int, float)) for i in control_outcomes + ): msg = "Outcome variable must be numeric" raise ValueError(msg) @@ -285,14 +366,8 @@ def hedges_g( number_control = len(control_outcomes) degrees_of_freedom = number_treated + number_control - 2 - cohen_d = cohens_d( - medrecord, - treatment_outcome_true_set, - control_outcome_true_set, - outcome_group, - outcome_variable, - reference, - time_attribute, + cohen_d = (treated_outcomes.mean() - control_outcomes.mean()) / sqrt( + (treated_outcomes.std(ddof=1) ** 2 + control_outcomes.std(ddof=1) ** 2) / 2 ) # Correction factor J @@ -306,3 +381,26 @@ def hedges_g( "cohens_d": cohens_d, "hedges_g": hedges_g, } + + +def query_edges_between_set_outcome( + edge: EdgeOperand, set: Set[NodeIndex], outcomes_group: Group +) -> None: + """Query edges that connect a set of nodes to the outcomes group. + + Args: + edge (EdgeOperand): The edge operand to query. + set (Set[NodeIndex]): A set of node indices representing the treated group that + also have the outcome. + outcomes_group (Group): The group of nodes that contain the outcome variable. + """ + list_nodes = list(set) + edge.either_or( + lambda edge: edge.source_node().index().is_in(list_nodes), + lambda edge: edge.target_node().index().is_in(list_nodes), + ) + + edge.either_or( + lambda edge: edge.source_node().in_group(outcomes_group), + lambda edge: edge.target_node().in_group(outcomes_group), + ) diff --git a/medmodels/treatment_effect/tests/test_treatment_effect.py b/medmodels/treatment_effect/tests/test_treatment_effect.py index d365844..dca9372 100644 --- a/medmodels/treatment_effect/tests/test_treatment_effect.py +++ b/medmodels/treatment_effect/tests/test_treatment_effect.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +import logging from datetime import datetime from typing import TYPE_CHECKING, List, Optional @@ -192,8 +192,12 @@ def create_medrecord( return medrecord +@pytest.fixture +def medrecord() -> MedRecord: + return create_medrecord() + + def assert_treatment_effects_equal( - test_case: unittest.TestCase, treatment_effect1: TreatmentEffect, treatment_effect2: TreatmentEffect, ) -> None: @@ -271,12 +275,9 @@ def assert_treatment_effects_equal( ) -class TestTreatmentEffect(unittest.TestCase): +class TestTreatmentEffect: """Class to test the TreatmentEffect class in the treatment_effect module.""" - def setUp(self) -> None: - self.medrecord = create_medrecord() - def test_init(self) -> None: # Initialize TreatmentEffect object tee = TreatmentEffect( @@ -291,7 +292,7 @@ def test_init(self) -> None: .build() ) - assert_treatment_effects_equal(self, tee, tee_builder) + assert_treatment_effects_equal(tee, tee_builder) def test_default_properties(self) -> None: tee = TreatmentEffect( @@ -303,32 +304,62 @@ def test_default_properties(self) -> None: TreatmentEffect.builder() .with_treatment("Rivaroxaban") .with_outcome("Stroke") - .with_time_attribute("time") .with_patients_group("patients") .with_washout_period(reference="first") .with_grace_period(days=0, reference="last") - .with_follow_up_period(365, reference="last") + .with_follow_up_period(365000, reference="last") .build() ) - assert_treatment_effects_equal(self, tee, tee_builder) + assert_treatment_effects_equal(tee, tee_builder) + + def test_time_warnings_washout(self, caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.WARNING): + _ = ( + TreatmentEffect.builder() + .with_treatment("Rivaroxaban") + .with_outcome("Stroke") + .with_washout_period({"Warfarin": 30}) + .build() + ) + + assert ( + "Washout period is not applied because the time attribute is not set." + in caplog.records[0].message + ) - def test_check_medrecord(self) -> None: + def test_time_warnings_follow_up(self, caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.WARNING): + _ = ( + TreatmentEffect.builder() + .with_treatment("Rivaroxaban") + .with_outcome("Stroke") + .with_follow_up_period(365) + .build() + ) + + assert ( + "Time attribute is not set, thus the grace period, follow-up " + + "period, and outcome before treatment cannot be applied. The " + + "treatment effect analysis is performed in a static way." + ) in caplog.records[0].message + + def test_check_medrecord(self, medrecord: MedRecord) -> None: tee = ( TreatmentEffect.builder() .with_outcome("Stroke") - .with_treatment("Aspirin") + .with_treatment("no_treatment") .build() ) with pytest.raises( ValueError, match="Treatment group not found in the MedRecord" ): - tee.estimate._check_medrecord(medrecord=self.medrecord) + tee.estimate._check_medrecord(medrecord=medrecord) tee2 = ( TreatmentEffect.builder() - .with_outcome("Headache") + .with_outcome("no_outcome") .with_treatment("Rivaroxaban") .build() ) @@ -336,7 +367,7 @@ def test_check_medrecord(self) -> None: with pytest.raises( ValueError, match="Outcome group not found in the MedRecord" ): - tee2.estimate._check_medrecord(medrecord=self.medrecord) + tee2.estimate._check_medrecord(medrecord=medrecord) patient_group = "subjects" tee3 = ( @@ -351,9 +382,9 @@ def test_check_medrecord(self) -> None: ValueError, match=f"Patient group {patient_group} not found in the MedRecord", ): - tee3.estimate._check_medrecord(medrecord=self.medrecord) + tee3.estimate._check_medrecord(medrecord=medrecord) - def test_find_treated_patients(self) -> None: + def test_find_treated_patients(self, medrecord: MedRecord) -> None: tee = ( TreatmentEffect.builder() .with_outcome("Stroke") @@ -361,11 +392,11 @@ def test_find_treated_patients(self) -> None: .build() ) - treated_set = tee._find_treated_patients(self.medrecord) + treated_set = tee._find_treated_patients(medrecord) assert treated_set == set({"P2", "P3", "P6"}) # no treatment_group - patients = set(self.medrecord.nodes_in_group("patients")) + patients = set(medrecord.nodes_in_group("patients")) medrecord2 = create_medrecord(list(patients - treated_set)) with pytest.raises( @@ -374,17 +405,18 @@ def test_find_treated_patients(self) -> None: ): tee.estimate._compute_subject_counts(medrecord=medrecord2) - def test_query_node_within_time_window(self) -> None: + def test_query_node_within_time_window(self, medrecord: MedRecord) -> None: # check if patient has outcome a year after treatment tee = ( TreatmentEffect.builder() .with_outcome("Stroke") .with_treatment("Rivaroxaban") + .with_time_attribute("time") .build() ) - treated_set = tee._find_treated_patients(self.medrecord) + treated_set = tee._find_treated_patients(medrecord) - nodes = self.medrecord.select_nodes( + nodes = medrecord.select_nodes( lambda node: tee._query_node_within_time_window( node, treated_set, "Stroke", 0, 365, "last" ) @@ -397,7 +429,7 @@ def test_query_node_within_time_window(self) -> None: assert "P6" in treated_set # check which patients have outcome within 30 days after treatment - nodes = self.medrecord.select_nodes( + nodes = medrecord.select_nodes( lambda node: tee._query_node_within_time_window( node, treated_set, "Stroke", 0, 30, "last" ) @@ -408,7 +440,7 @@ def test_query_node_within_time_window(self) -> None: ) # P2 has no outcome in the 30 days window after treatment # If we reduce the window to 3 days, no patients with outcome in that window - nodes = self.medrecord.select_nodes( + nodes = medrecord.select_nodes( lambda node: tee._query_node_within_time_window( node, treated_set, "Stroke", 0, 3, "last" ) @@ -416,7 +448,7 @@ def test_query_node_within_time_window(self) -> None: assert "P3" not in nodes assert "P2" not in nodes - def test_find_groups(self) -> None: + def test_find_groups(self, medrecord: MedRecord) -> None: tee = ( TreatmentEffect.builder() .with_outcome("Stroke") @@ -429,24 +461,65 @@ def test_find_groups(self) -> None: treatment_outcome_false, control_outcome_true, control_outcome_false, - ) = tee._find_groups(self.medrecord) + ) = tee._find_groups(medrecord) + + assert treatment_outcome_true == set({"P2", "P3"}) + assert treatment_outcome_false == set({"P6"}) + assert control_outcome_true == set({"P1", "P4", "P7"}) + assert control_outcome_false == set({"P5", "P8", "P9"}) + + # for this scenario, it works the same in temporal and static analysis + tee = ( + TreatmentEffect.builder() + .with_treatment("Rivaroxaban") + .with_outcome("Stroke") + .with_time_attribute("time") + .build() + ) + ( + treatment_outcome_true, + treatment_outcome_false, + control_outcome_true, + control_outcome_false, + ) = tee._find_groups(medrecord) + + assert treatment_outcome_true == set({"P2", "P3"}) + assert treatment_outcome_false == set({"P6"}) + assert control_outcome_true == set({"P1", "P4", "P7"}) + assert control_outcome_false == set({"P5", "P8", "P9"}) + + # for this scenario, it works the same in temporal and static analysis + tee = ( + TreatmentEffect.builder() + .with_treatment("Rivaroxaban") + .with_outcome("Stroke") + .with_time_attribute("time") + .build() + ) + ( + treatment_outcome_true, + treatment_outcome_false, + control_outcome_true, + control_outcome_false, + ) = tee._find_groups(medrecord) + assert treatment_outcome_true == set({"P2", "P3"}) assert treatment_outcome_false == set({"P6"}) assert control_outcome_true == set({"P1", "P4", "P7"}) assert control_outcome_false == set({"P5", "P8", "P9"}) - def test_compute_subject_counts(self) -> None: + def test_compute_subject_counts(self, medrecord: MedRecord) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") .with_outcome("Stroke") .build() ) - counts = tee.estimate._compute_subject_counts(self.medrecord) + counts = tee.estimate._compute_subject_counts(medrecord) assert counts == (2, 1, 3, 3) - def test_invalid_compute_subject_counts(self) -> None: + def test_invalid_compute_subject_counts(self, medrecord: MedRecord) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -458,7 +531,7 @@ def test_invalid_compute_subject_counts(self) -> None: treatment_outcome_false, control_outcome_true, control_outcome_false, - ) = tee._find_groups(self.medrecord) + ) = tee._find_groups(medrecord) all_patients = set().union( *[ treatment_outcome_true, @@ -494,22 +567,22 @@ def test_invalid_compute_subject_counts(self) -> None: ): tee.estimate._compute_subject_counts(medrecord=medrecord4) - def test_subject_counts(self) -> None: + def test_subject_counts(self, medrecord: MedRecord) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") .with_outcome("Stroke") .build() ) + subjects_tee = tee.estimate.subject_counts(medrecord) - subjects_tee = tee.estimate.subject_counts(self.medrecord) assert isinstance(subjects_tee, ContingencyTable) assert subjects_tee["control_outcome_false"] == 3 assert subjects_tee["control_outcome_true"] == 3 assert subjects_tee["treated_outcome_false"] == 1 assert subjects_tee["treated_outcome_true"] == 2 - def test_subjects_indices(self) -> None: + def test_subjects_indices(self, medrecord: MedRecord) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -523,36 +596,38 @@ def test_subjects_indices(self) -> None: control_outcome_true={"P1", "P4", "P7"}, control_outcome_false={"P5", "P8", "P9"}, ) - subjects_tee = tee.estimate.subject_indices(self.medrecord) + subjects_tee = tee.estimate.subject_indices(medrecord) assert subjects_tee == subjects_test - def test_follow_up_period(self) -> None: + def test_follow_up_period(self, medrecord: MedRecord) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") .with_outcome("Stroke") + .with_time_attribute("time") .with_follow_up_period(30) .build() ) assert tee._follow_up_period_days == 30 - counts_tee = tee.estimate._compute_subject_counts(self.medrecord) + counts_tee = tee.estimate._compute_subject_counts(medrecord) assert counts_tee == (1, 2, 3, 3) - def test_grace_period(self) -> None: + def test_grace_period(self, medrecord: MedRecord) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") .with_outcome("Stroke") .with_grace_period(10) + .with_time_attribute("time") .build() ) assert tee._grace_period_days == 10 - counts_tee = tee.estimate._compute_subject_counts(self.medrecord) + counts_tee = tee.estimate._compute_subject_counts(medrecord) assert counts_tee == (1, 2, 3, 3) @@ -566,10 +641,14 @@ def test_invalid_grace_period(self) -> None: .with_treatment("Rivaroxaban") .with_outcome("Stroke") .with_grace_period(1000) + .with_follow_up_period(365) + .with_time_attribute("time") .build() ) - def test_washout_period(self) -> None: + def test_washout_period( + self, caplog: pytest.LogCaptureFixture, medrecord: MedRecord + ) -> None: washout_dict = {"Warfarin": 30} tee = ( @@ -577,18 +656,24 @@ def test_washout_period(self) -> None: .with_treatment("Rivaroxaban") .with_outcome("Stroke") .with_washout_period(washout_dict) + .with_time_attribute("time") .build() ) assert tee._washout_period_days == washout_dict - treated_set = tee._find_treated_patients(self.medrecord) - treated_set, washout_nodes = tee._apply_washout_period( - self.medrecord, treated_set - ) + treated_set = tee._find_treated_patients(medrecord) + with caplog.at_level(logging.WARNING): + treated_set, washout_nodes = tee._apply_washout_period( + medrecord, treated_set + ) assert treated_set == set({"P3", "P6"}) assert washout_nodes == set({"P2"}) + assert ( + "1 subject was dropped due to having a treatment in the washout period." + in caplog.records[0].message + ) # smaller washout period washout_dict2 = {"Warfarin": 10} @@ -598,30 +683,32 @@ def test_washout_period(self) -> None: .with_treatment("Rivaroxaban") .with_outcome("Stroke") .with_washout_period(washout_dict2) + .with_time_attribute("time") .build() ) assert tee2._washout_period_days == washout_dict2 - treated_set = tee2._find_treated_patients(self.medrecord) - treated_set, washout_nodes = tee2._apply_washout_period( - self.medrecord, treated_set - ) + treated_set = tee2._find_treated_patients(medrecord) + treated_set, washout_nodes = tee2._apply_washout_period(medrecord, treated_set) assert treated_set == set({"P2", "P3", "P6"}) assert washout_nodes == set({}) - def test_outcome_before_treatment(self) -> None: + def test_outcome_before_treatment( + self, caplog: pytest.LogCaptureFixture, medrecord: MedRecord + ) -> None: # case 1 find outcomes for default tee tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") .with_outcome("Stroke") + .with_time_attribute("time") .build() ) - treated_set = tee._find_treated_patients(self.medrecord) + treated_set = tee._find_treated_patients(medrecord) treated_set, treatment_outcome_true, outcome_before_treatment_nodes = ( - tee._find_outcomes(self.medrecord, treated_set) + tee._find_outcomes(medrecord, treated_set) ) assert treated_set == set({"P2", "P3", "P6"}) @@ -633,69 +720,79 @@ def test_outcome_before_treatment(self) -> None: TreatmentEffect.builder() .with_treatment("Rivaroxaban") .with_outcome("Stroke") + .with_time_attribute("time") .with_outcome_before_treatment_exclusion(30) .build() ) assert tee2._outcome_before_treatment_days == 30 - treated_set = tee2._find_treated_patients(self.medrecord) - treated_set, treatment_outcome_true, outcome_before_treatment_nodes = ( - tee2._find_outcomes(self.medrecord, treated_set) - ) + treated_set = tee2._find_treated_patients(medrecord) + with caplog.at_level(logging.WARNING): + treated_set, treatment_outcome_true, outcome_before_treatment_nodes = ( + tee2._find_outcomes(medrecord, treated_set) + ) assert treated_set == set({"P2", "P6"}) assert treatment_outcome_true == set({"P2"}) assert outcome_before_treatment_nodes == set({"P3"}) + assert ( + "1 subject was dropped due to having an outcome before the treatment." + in caplog.records[0].message + ) + # case 3 no outcome - self.medrecord.add_group("Headache") + medrecord.add_group("no_outcome") tee3 = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") - .with_outcome("Headache") + .with_outcome("no_outcome") + .with_time_attribute("time") .with_outcome_before_treatment_exclusion(30) .build() ) with pytest.raises( - ValueError, match="No outcomes found in the MedRecord for group Headache" + ValueError, match="No outcomes found in the MedRecord for group no_outcome" ): - tee3._find_outcomes(medrecord=self.medrecord, treated_set=treated_set) + tee3._find_outcomes(medrecord=medrecord, treated_set=treated_set) - def test_filter_controls(self) -> None: - def query1(node: NodeOperand) -> None: + def test_filter_controls(self, medrecord: MedRecord) -> None: + def query_neighbors_to_m2(node: NodeOperand) -> None: node.neighbors(EdgeDirection.BOTH).index().equal_to("M2") tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") .with_outcome("Stroke") - .filter_controls(query1) + .with_time_attribute("time") + .filter_controls(query_neighbors_to_m2) .build() ) - counts_tee = tee.estimate._compute_subject_counts(self.medrecord) + counts_tee = tee.estimate._compute_subject_counts(medrecord) assert counts_tee == (2, 1, 1, 2) # filter females only - def query2(node: NodeOperand) -> None: + def query_female_patients(node: NodeOperand) -> None: node.attribute("gender").equal_to("female") tee2 = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") .with_outcome("Stroke") - .filter_controls(query2) + .with_time_attribute("time") + .filter_controls(query_female_patients) .build() ) - counts_tee2 = tee2.estimate._compute_subject_counts(self.medrecord) + counts_tee2 = tee2.estimate._compute_subject_counts(medrecord) assert counts_tee2 == (2, 1, 1, 1) - def test_nearest_neighbors(self) -> None: + def test_nearest_neighbors(self, medrecord: MedRecord) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -704,7 +801,7 @@ def test_nearest_neighbors(self) -> None: .build() ) - subjects = tee.estimate.subject_indices(self.medrecord) + subjects = tee.estimate.subject_indices(medrecord) # Multiple patients are equally similar to the treatment group # These are exact macthes and should always be included @@ -712,7 +809,7 @@ def test_nearest_neighbors(self) -> None: assert "P5" in subjects["control_outcome_false"] assert "P8" in subjects["control_outcome_false"] - def test_propensity_matching(self) -> None: + def test_propensity_matching(self, medrecord: MedRecord) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -721,13 +818,13 @@ def test_propensity_matching(self) -> None: .build() ) - subjects = tee.estimate.subject_indices(self.medrecord) + subjects = tee.estimate.subject_indices(medrecord) assert "P4" in subjects["control_outcome_true"] assert "P5" in subjects["control_outcome_false"] assert "P1" in subjects["control_outcome_true"] - def test_find_controls(self) -> None: + def test_find_controls(self, medrecord: MedRecord) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -735,11 +832,11 @@ def test_find_controls(self) -> None: .build() ) - patients = set(self.medrecord.nodes_in_group("patients")) + patients = set(medrecord.nodes_in_group("patients")) treated_set = {"P2", "P3", "P6"} control_outcome_true, control_outcome_false = tee._find_controls( - self.medrecord, + medrecord, control_set=patients - treated_set, treated_set=patients.intersection(treated_set), ) @@ -751,7 +848,7 @@ def test_find_controls(self) -> None: ValueError, match="No patients found for control groups in this MedRecord" ): tee._find_controls( - self.medrecord, + medrecord, control_set=patients - treated_set, treated_set=patients.intersection(treated_set), rejected_nodes=patients - treated_set, @@ -760,22 +857,22 @@ def test_find_controls(self) -> None: tee2 = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") - .with_outcome("Headache") + .with_outcome("no_outcome") .build() ) - self.medrecord.add_group("Headache") + medrecord.add_group("no_outcome") with pytest.raises( - ValueError, match="No outcomes found in the MedRecord for group Headache" + ValueError, match="No outcomes found in the MedRecord for group no_outcome" ): tee2._find_controls( - self.medrecord, + medrecord, control_set=patients - treated_set, treated_set=patients.intersection(treated_set), ) - def test_metrics(self) -> None: + def test_metrics(self, medrecord: MedRecord) -> None: """Test the metrics of the TreatmentEffect class.""" tee = ( TreatmentEffect.builder() @@ -785,16 +882,14 @@ def test_metrics(self) -> None: ) # Calculate metrics - assert tee.estimate.absolute_risk_reduction(self.medrecord) == pytest.approx( - -1 / 6 - ) - assert tee.estimate.relative_risk(self.medrecord) == pytest.approx(4 / 3) - assert tee.estimate.odds_ratio(self.medrecord) == pytest.approx(2) - assert tee.estimate.confounding_bias(self.medrecord) == pytest.approx(22 / 21) - assert tee.estimate.hazard_ratio(self.medrecord) == pytest.approx(4 / 3) - assert tee.estimate.number_needed_to_treat(self.medrecord) == pytest.approx(-6) - - def test_full_report(self) -> None: + assert tee.estimate.absolute_risk_reduction(medrecord) == pytest.approx(-1 / 6) + assert tee.estimate.relative_risk(medrecord) == pytest.approx(4 / 3) + assert tee.estimate.odds_ratio(medrecord) == pytest.approx(2) + assert tee.estimate.confounding_bias(medrecord) == pytest.approx(22 / 21) + assert tee.estimate.hazard_ratio(medrecord) == pytest.approx(4 / 3) + assert tee.estimate.number_needed_to_treat(medrecord) == pytest.approx(-6) + + def test_full_report(self, medrecord: MedRecord) -> None: """Test the full reporting of the TreatmentEffect class.""" tee = ( TreatmentEffect.builder() @@ -804,23 +899,19 @@ def test_full_report(self) -> None: ) # Calculate metrics - full_report = tee.report.full_report(self.medrecord) + full_report = tee.report.full_report(medrecord) report_test = { - "absolute_risk_reduction": tee.estimate.absolute_risk_reduction( - self.medrecord - ), - "relative_risk": tee.estimate.relative_risk(self.medrecord), - "odds_ratio": tee.estimate.odds_ratio(self.medrecord), - "confounding_bias": tee.estimate.confounding_bias(self.medrecord), - "hazard_ratio": tee.estimate.hazard_ratio(self.medrecord), - "number_needed_to_treat": tee.estimate.number_needed_to_treat( - self.medrecord - ), + "absolute_risk_reduction": tee.estimate.absolute_risk_reduction(medrecord), + "relative_risk": tee.estimate.relative_risk(medrecord), + "odds_ratio": tee.estimate.odds_ratio(medrecord), + "confounding_bias": tee.estimate.confounding_bias(medrecord), + "hazard_ratio": tee.estimate.hazard_ratio(medrecord), + "number_needed_to_treat": tee.estimate.number_needed_to_treat(medrecord), } assert full_report == report_test - def test_continuous_estimators_report(self) -> None: + def test_continuous_estimators_report(self, medrecord: MedRecord) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -830,22 +921,40 @@ def test_continuous_estimators_report(self) -> None: report_test = { "average_treatment_effect": tee.estimate.average_treatment_effect( - self.medrecord, + medrecord, outcome_variable="intensity", ), - "cohens_d": tee.estimate.cohens_d( - self.medrecord, outcome_variable="intensity" - ), - "hedges_g": tee.estimate.hedges_g( - self.medrecord, outcome_variable="intensity" + "cohens_d": tee.estimate.cohens_d(medrecord, outcome_variable="intensity"), + "hedges_g": tee.estimate.hedges_g(medrecord, outcome_variable="intensity"), + } + + assert report_test == tee.report.continuous_estimators_report( + medrecord, outcome_variable="intensity" + ) + + def test_continuous_estimators_report_with_time(self, medrecord: MedRecord) -> None: + """Test the continuous report of the TreatmentEffect with time attribute.""" + tee = ( + TreatmentEffect.builder() + .with_treatment("Rivaroxaban") + .with_outcome("Stroke") + .with_time_attribute("time") + .build() + ) + + report_test = { + "average_treatment_effect": tee.estimate.average_treatment_effect( + medrecord, + outcome_variable="intensity", ), + "cohens_d": tee.estimate.cohens_d(medrecord, outcome_variable="intensity"), + "hedges_g": tee.estimate.hedges_g(medrecord, outcome_variable="intensity"), } assert report_test == tee.report.continuous_estimators_report( - self.medrecord, outcome_variable="intensity" + medrecord, outcome_variable="intensity" ) if __name__ == "__main__": - run_test = unittest.TestLoader().loadTestsFromTestCase(TestTreatmentEffect) - unittest.TextTestRunner(verbosity=2).run(run_test) + pytest.main(["-v", __file__]) diff --git a/medmodels/treatment_effect/treatment_effect.py b/medmodels/treatment_effect/treatment_effect.py index b28cbe3..13fca49 100644 --- a/medmodels/treatment_effect/treatment_effect.py +++ b/medmodels/treatment_effect/treatment_effect.py @@ -6,6 +6,10 @@ without undergoing the treatment. The class supports customizable criteria filtering, time constraints between treatment and outcome, and optional matching of control groups to treatment groups using a specified matching class. + +The default TreatmentEffect class performs an static analysis without considering time. +To perform a time-based analysis, users can specify a time attribute in the +configuration and set the washout period, grace period, and follow-up period. """ from __future__ import annotations @@ -58,7 +62,7 @@ class TreatmentEffect: _outcomes_group: Group _patients_group: Group - _time_attribute: MedRecordAttribute + _time_attribute: Optional[MedRecordAttribute] _washout_period_days: Dict[Group, int] _washout_period_reference: Literal["first", "last"] @@ -113,12 +117,12 @@ def _set_configuration( treatment: Group, outcome: Group, patients_group: Group = "patients", - time_attribute: MedRecordAttribute = "time", + time_attribute: Optional[MedRecordAttribute] = None, washout_period_days: Optional[Dict[Group, int]] = None, washout_period_reference: Literal["first", "last"] = "first", grace_period_days: int = 0, grace_period_reference: Literal["first", "last"] = "last", - follow_up_period_days: int = 365, + follow_up_period_days: int = 1000 * 365, follow_up_period_reference: Literal["first", "last"] = "last", outcome_before_treatment_days: Optional[int] = None, filter_controls_query: Optional[NodeQuery] = None, @@ -142,18 +146,20 @@ def _set_configuration( outcome (Group): The group of outcomes to analyze. patients_group (Group, optional): The group of patients to analyze. Defaults to "patients". - time_attribute (MedRecordAttribute, optional): The time attribute to use for - time-based analysis. Defaults to "time". - washout_period_days (Dict[Group, int], optional): The washout period in days - for each treatment group. Defaults to dict(). - washout_period_reference (Literal["first", "last"], optional): The reference - point for the washout period. Defaults to "first". + time_attribute (Optional[MedRecordAttribute], optional): The time + attribute. If None, the treatment effect analysis is performed in an + static way (without considering time). Defaults to None. + washout_period_days (Dict[str, int], optional): The washout period in days + for each treatment group. In the case of no time attribute, it is not + applied. Defaults to dict(). + washout_period_reference (Literal["first", "last"], optional): The + reference point for the washout period. Defaults to "first". grace_period_days (int, optional): The grace period in days after the treatment. Defaults to 0. grace_period_reference (Literal["first", "last"], optional): The reference point for the grace period. Defaults to "last". follow_up_period_days (int, optional): The follow-up period in days after - the treatment. Defaults to 365. + the treatment. Defaults to 365000. follow_up_period_reference (Literal["first", "last"], optional): The reference point for the follow-up period. Defaults to "last". outcome_before_treatment_days (Optional[int], optional): The number of days @@ -214,6 +220,23 @@ def _set_configuration( treatment_effect._matching_number_of_neighbors = matching_number_of_neighbors treatment_effect._matching_hyperparameters = matching_hyperparameters + if washout_period_days and not time_attribute: + logger.warning( + "Washout period is not applied because the time attribute is not set." + ) + + if ( + grace_period_days + or (follow_up_period_days != 1000 * 365) + or outcome_before_treatment_days + ) and not time_attribute: + msg = ( + "Time attribute is not set, thus the grace period, follow-up " + + "period, and outcome before treatment cannot be applied. The " + + "treatment effect analysis is performed in a static way." + ) + logger.warning(msg) + def _find_groups( self, medrecord: MedRecord ) -> Tuple[Set[NodeIndex], Set[NodeIndex], Set[NodeIndex], Set[NodeIndex]]: @@ -239,7 +262,14 @@ def _find_groups( """ # Find patients that underwent the treatment treated_set = self._find_treated_patients(medrecord) - treated_set, washout_nodes = self._apply_washout_period(medrecord, treated_set) + + if self._time_attribute: + treated_set, washout_nodes = self._apply_washout_period( + medrecord, treated_set + ) + else: + washout_nodes = set() + treated_set, treated_outcome_true, outcome_before_treatment_nodes = ( self._find_outcomes(medrecord, treated_set) ) @@ -325,7 +355,7 @@ def _find_outcomes( msg = f"No outcomes found in the MedRecord for group {self._outcomes_group}" raise ValueError(msg) - if outcome_before_treatment_days: + if outcome_before_treatment_days and self._time_attribute: outcome_before_treatment_nodes = set( medrecord.select_nodes( lambda node: self._query_node_within_time_window( @@ -341,24 +371,31 @@ def _find_outcomes( treated_set -= outcome_before_treatment_nodes dropped_num = len(outcome_before_treatment_nodes) - logger.warning( - "%d subject%s dropped due to outcome before treatment.", - dropped_num, - " was" if dropped_num == 1 else "s were", + msg = ( + f"{dropped_num} subject{' was' if dropped_num == 1 else 's were'} " + f"dropped due to having an outcome before the treatment." ) + logger.warning(msg) - treated_outcome_true = set( - medrecord.select_nodes( - lambda node: self._query_node_within_time_window( - node, - treated_set, - self._outcomes_group, - self._grace_period_days, - self._follow_up_period_days, - self._follow_up_period_reference, + if self._time_attribute: + treated_outcome_true = set( + medrecord.select_nodes( + lambda node: self._query_node_within_time_window( + node, + treated_set, + self._outcomes_group, + self._grace_period_days, + self._follow_up_period_days, + self._follow_up_period_reference, + ) + ) + ) + else: + treated_outcome_true = set( + medrecord.select_nodes( + lambda node: self._query_set_outcome_true(node, treated_set) ) ) - ) return treated_set, treated_outcome_true, outcome_before_treatment_nodes @@ -404,11 +441,11 @@ def _apply_washout_period( if washout_nodes: dropped_num = len(washout_nodes) - logger.warning( - "%d subject%s dropped due to outcome before treatment.", - dropped_num, - " was" if dropped_num == 1 else "s were", + msg = ( + f"{dropped_num} subject{' was' if dropped_num == 1 else 's were'} " + f"dropped due to having a treatment in the washout period." ) + logger.warning(msg) return treated_set, washout_nodes @@ -474,18 +511,26 @@ def _find_controls( msg = f"No outcomes found in the MedRecord for group {self._outcomes_group}" raise ValueError(msg) - def query(node: NodeOperand) -> None: - node.index().is_in(list(control_set)) - node.neighbors(edge_direction=EdgeDirection.BOTH).in_group( - self._outcomes_group - ) - # Finding the patients that had the outcome in the control group - control_outcome_true = set(medrecord.select_nodes(query)) + control_outcome_true = set( + medrecord.select_nodes( + lambda node: self._query_set_outcome_true(node, control_set) + ) + ) control_outcome_false = control_set - control_outcome_true return control_outcome_true, control_outcome_false + def _query_set_outcome_true(self, node: NodeOperand, set: Set[NodeIndex]) -> None: + """Query for nodes that are in the given set and have the outcome. + + Args: + node (NodeOperand): The node to query. + set (Set[NodeIndex]): The set of nodes to query. + """ + node.index().is_in(list(set)) + node.neighbors(edge_direction=EdgeDirection.BOTH).in_group(self._outcomes_group) + def _query_node_within_time_window( self, node: NodeOperand, @@ -518,8 +563,14 @@ def _query_node_within_time_window( event. reference (Literal["first", "last"]): The reference point for the time window. + + Raises: + ValueError: If the time attribute is not set. """ node.index().is_in(list(treated_set)) + if self._time_attribute is None: + msg = "Time attribute is not set." + raise ValueError(msg) edges_to_treatment = node.edges() edges_to_treatment.attribute(self._time_attribute).is_datetime()