From 3c9ddc069629a93952d0ae306318958204081126 Mon Sep 17 00:00:00 2001 From: MarIniOnz Date: Thu, 5 Dec 2024 12:30:47 +0100 Subject: [PATCH 1/3] feat: making static analysis as the default option for the TEE --- medmodels/treatment_effect/builder.py | 5 +- .../treatment_effect/continuous_estimators.py | 216 +++++++++++++----- .../tests/test_treatment_effect.py | 74 +++++- .../treatment_effect/treatment_effect.py | 81 +++++-- 4 files changed, 278 insertions(+), 98 deletions(-) diff --git a/medmodels/treatment_effect/builder.py b/medmodels/treatment_effect/builder.py index 15ceb4c..5d38fba 100644 --- a/medmodels/treatment_effect/builder.py +++ b/medmodels/treatment_effect/builder.py @@ -82,12 +82,13 @@ def with_patients_group(self, group: Group) -> TreatmentEffectBuilder: return self def with_time_attribute( - self, attribute: MedRecordAttribute + self, attribute: Optional[MedRecordAttribute] ) -> TreatmentEffectBuilder: """Sets the time attribute to be used in the treatment effect estimation. Args: - attribute (MedRecordAttribute): The time attribute. + attribute (Optional[MedRecordAttribute]): The time attribute. If None, + there is no temporal analysis, but only static one. 
Returns: TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder diff --git a/medmodels/treatment_effect/continuous_estimators.py b/medmodels/treatment_effect/continuous_estimators.py index 5aa89e0..fc05d1f 100644 --- a/medmodels/treatment_effect/continuous_estimators.py +++ b/medmodels/treatment_effect/continuous_estimators.py @@ -1,9 +1,10 @@ from math import sqrt -from typing import Literal, Set +from typing import Literal, Optional, Set import numpy as np from medmodels.medrecord.medrecord import MedRecord +from medmodels.medrecord.querying import EdgeOperand from medmodels.medrecord.types import Group, MedRecordAttribute, NodeIndex from medmodels.treatment_effect.temporal_analysis import find_reference_edge @@ -15,7 +16,7 @@ def average_treatment_effect( outcome_group: Group, outcome_variable: MedRecordAttribute, reference: Literal["first", "last"] = "last", - time_attribute: MedRecordAttribute = "time", + time_attribute: Optional[MedRecordAttribute] = "time", ) -> float: r"""Calculates the Average Treatment Effect (ATE) as the difference between the outcome means of the treated and control sets. @@ -50,8 +51,11 @@ def average_treatment_effect( exposure time. Options include "first" and "last". If "first", the function returns the earliest exposure edge. If "last", the function returns the latest exposure edge. Defaults to "last". - time_attribute (MedRecordAttribute, optional): The attribute in the edge that - contains the time information. Defaults to "time". + time_attribute (Optional[MedRecordAttribute], optional): The attribute in the + edge that contains the time information. If it is equal to None, there is + no time component in the data and all edges between the sets and the + outcomes are considered for the average treatment effect. Defaults to + "time". Returns: float: The average treatment effect. @@ -59,37 +63,64 @@ def average_treatment_effect( Raises: ValueError: If the outcome variable is not numeric. 
""" - treated_outcomes = np.array( - [ - medrecord.edge[ - find_reference_edge( - medrecord, - node_index, - outcome_group, - time_attribute=time_attribute, - reference=reference, - ) - ][outcome_variable] - for node_index in treatment_outcome_true_set - ] - ) + if time_attribute is not None: + treated_outcomes = np.array( + [ + medrecord.edge[ + find_reference_edge( + medrecord, + node_index, + outcome_group, + time_attribute=time_attribute, + reference=reference, + ) + ][outcome_variable] + for node_index in treatment_outcome_true_set + ] + ) + control_outcomes = np.array( + [ + medrecord.edge[ + find_reference_edge( + medrecord, + node_index, + outcome_group, + time_attribute="time", + reference=reference, + ) + ][outcome_variable] + for node_index in control_outcome_true_set + ] + ) + + else: + edges_treated_outcomes = medrecord.select_edges( + lambda edge: query_edges_set_outcome( + edge, treatment_outcome_true_set, outcome_group + ) + ) + treated_outcomes = treated_outcomes = np.array( + [ + medrecord.edge[edge_id][outcome_variable] + for edge_id in edges_treated_outcomes + ] + ) + + edges_control_outcomes = medrecord.select_edges( + lambda edge: query_edges_set_outcome( + edge, control_outcome_true_set, outcome_group + ) + ) + control_outcomes = treated_outcomes = np.array( + [ + medrecord.edge[edge_id][outcome_variable] + for edge_id in edges_control_outcomes + ] + ) + if not all(isinstance(i, (int, float)) for i in treated_outcomes): raise ValueError("Outcome variable must be numeric") - control_outcomes = np.array( - [ - medrecord.edge[ - find_reference_edge( - medrecord, - node_index, - outcome_group, - time_attribute="time", - reference=reference, - ) - ][outcome_variable] - for node_index in control_outcome_true_set - ] - ) if not all(isinstance(i, (int, float)) for i in control_outcomes): raise ValueError("Outcome variable must be numeric") @@ -103,7 +134,7 @@ def cohens_d( outcome_group: Group, outcome_variable: MedRecordAttribute, reference: 
Literal["first", "last"] = "last", - time_attribute: MedRecordAttribute = "time", + time_attribute: Optional[MedRecordAttribute] = "time", add_correction: bool = False, ) -> float: """Calculates Cohen's D, the standardized mean difference between two sets, measuring the effect size of the difference between two outcome means. @@ -141,8 +172,11 @@ def cohens_d( exposure time. Options include "first" and "last". If "first", the function returns the earliest exposure edge. If "last", the function returns the latest exposure edge. Defaults to "last". - time_attribute (MedRecordAttribute, optional): The attribute in the edge that - contains the time information. Defaults to "time". + time_attribute (Optional[MedRecordAttribute], optional): The attribute in the + edge that contains the time information. If it is equal to None, there is + no time component in the data and all edges between the sets and the + outcomes are considered for the average treatment effect. Defaults to + "time". add_correction (bool, optional): Whether to apply a correction factor for small sample sizes. When True, using Hedges' g formula instead of Cohens' D. Defaults to False. @@ -153,37 +187,62 @@ def cohens_d( Raises: ValueError: If the outcome variable is not numeric. 
""" - treated_outcomes = np.array( - [ - medrecord.edge[ - find_reference_edge( - medrecord, - node_index, - outcome_group, - time_attribute=time_attribute, - reference=reference, - ) - ][outcome_variable] - for node_index in treatment_outcome_true_set - ] - ) + if time_attribute is not None: + treated_outcomes = np.array( + [ + medrecord.edge[ + find_reference_edge( + medrecord, + node_index, + outcome_group, + time_attribute=time_attribute, + reference=reference, + ) + ][outcome_variable] + for node_index in treatment_outcome_true_set + ] + ) + control_outcomes = np.array( + [ + medrecord.edge[ + find_reference_edge( + medrecord, + node_index, + outcome_group, + time_attribute="time", + reference=reference, + ) + ][outcome_variable] + for node_index in control_outcome_true_set + ] + ) + else: + edges_treated_outcomes = medrecord.select_edges( + lambda edge: query_edges_set_outcome( + edge, treatment_outcome_true_set, outcome_group + ) + ) + treated_outcomes = treated_outcomes = np.array( + [ + medrecord.edge[edge_id][outcome_variable] + for edge_id in edges_treated_outcomes + ] + ) + + edges_control_outcomes = medrecord.select_edges( + lambda edge: query_edges_set_outcome( + edge, control_outcome_true_set, outcome_group + ) + ) + control_outcomes = treated_outcomes = np.array( + [ + medrecord.edge[edge_id][outcome_variable] + for edge_id in edges_control_outcomes + ] + ) if not all(isinstance(i, (int, float)) for i in treated_outcomes): raise ValueError("Outcome variable must be numeric") - control_outcomes = np.array( - [ - medrecord.edge[ - find_reference_edge( - medrecord, - node_index, - outcome_group, - time_attribute="time", - reference=reference, - ) - ][outcome_variable] - for node_index in control_outcome_true_set - ] - ) if not all(isinstance(i, (int, float)) for i in control_outcomes): raise ValueError("Outcome variable must be numeric") @@ -207,3 +266,34 @@ def cohens_d( "average_te": average_treatment_effect, "cohens_d": cohens_d, } + + +def 
query_edges_set_outcome( + edge: EdgeOperand, set: Set[NodeIndex], outcomes_group: Group +): + """Query edges that connect a set of nodes to the outcomes group. + + Args: + edge (EdgeOperand): The edge operand to query. + set (Set[NodeIndex]): A set of node indices representing the treated group that + also have the outcome. + outcomes_group (Group): The group of nodes that contain the outcome variable. + """ + list_nodes = list(set) + edge.either_or( + lambda edge: edge.source_node().index().is_in(list_nodes), + lambda edge: edge.target_node().index().is_in(list_nodes), + ) + edge.either_or( + lambda edge: edge.source_node().index().is_in(list_nodes), + lambda edge: edge.target_node().index().is_in(list_nodes), + ) + + edge.either_or( + lambda edge: edge.source_node().in_group(outcomes_group), + lambda edge: edge.target_node().in_group(outcomes_group), + ) + edge.either_or( + lambda edge: edge.source_node().in_group(outcomes_group), + lambda edge: edge.target_node().in_group(outcomes_group), + ) diff --git a/medmodels/treatment_effect/tests/test_treatment_effect.py b/medmodels/treatment_effect/tests/test_treatment_effect.py index 1d1d853..eaa45e8 100644 --- a/medmodels/treatment_effect/tests/test_treatment_effect.py +++ b/medmodels/treatment_effect/tests/test_treatment_effect.py @@ -304,7 +304,7 @@ def test_default_properties(self): TreatmentEffect.builder() .with_treatment("Rivaroxaban") .with_outcome("Stroke") - .with_time_attribute("time") + .with_time_attribute(None) .with_patients_group("patients") .with_washout_period(reference="first") .with_grace_period(days=0, reference="last") @@ -379,6 +379,7 @@ def test_query_node_within_time_window(self): TreatmentEffect.builder() .with_outcome("Stroke") .with_treatment("Rivaroxaban") + .with_time_attribute("time") .build() ) treated_set = tee._find_treated_patients(self.medrecord) @@ -434,6 +435,25 @@ def test_find_groups(self): self.assertEqual(control_outcome_true, set({"P1", "P4", "P7"})) 
self.assertEqual(control_outcome_false, set({"P5", "P8", "P9"})) + # for this scenario, it works the same in temporal and static analysis + tee = ( + TreatmentEffect.builder() + .with_treatment("Rivaroxaban") + .with_outcome("Stroke") + .with_time_attribute("time") + .build() + ) + ( + treatment_outcome_true, + treatment_outcome_false, + control_outcome_true, + control_outcome_false, + ) = tee._find_groups(self.medrecord) + self.assertEqual(treatment_outcome_true, set({"P2", "P3"})) + self.assertEqual(treatment_outcome_false, set({"P6"})) + self.assertEqual(control_outcome_true, set({"P1", "P4", "P7"})) + self.assertEqual(control_outcome_false, set({"P5", "P8", "P9"})) + def test_compute_subject_counts(self): tee = ( TreatmentEffect.builder() @@ -543,6 +563,7 @@ def test_follow_up_period(self): TreatmentEffect.builder() .with_treatment("Rivaroxaban") .with_outcome("Stroke") + .with_time_attribute("time") .with_follow_up_period(30) .build() ) @@ -559,6 +580,7 @@ def test_grace_period(self): .with_treatment("Rivaroxaban") .with_outcome("Stroke") .with_grace_period(10) + .with_time_attribute("time") .build() ) @@ -578,6 +600,7 @@ def test_invalid_grace_period(self): .with_treatment("Rivaroxaban") .with_outcome("Stroke") .with_grace_period(1000) + .with_time_attribute("time") .build() ) @@ -589,6 +612,7 @@ def test_washout_period(self): .with_treatment("Rivaroxaban") .with_outcome("Stroke") .with_washout_period(washout_dict) + .with_time_attribute("time") .build() ) @@ -610,6 +634,7 @@ def test_washout_period(self): .with_treatment("Rivaroxaban") .with_outcome("Stroke") .with_washout_period(washout_dict2) + .with_time_attribute("time") .build() ) @@ -629,6 +654,7 @@ def test_outcome_before_treatment(self): TreatmentEffect.builder() .with_treatment("Rivaroxaban") .with_outcome("Stroke") + .with_time_attribute("time") .build() ) treated_set = tee._find_treated_patients(self.medrecord) @@ -644,6 +670,7 @@ def test_outcome_before_treatment(self): TreatmentEffect.builder() 
.with_treatment("Rivaroxaban") .with_outcome("Stroke") + .with_time_attribute("time") .with_outcome_before_treatment_exclusion(30) .build() ) @@ -659,12 +686,13 @@ def test_outcome_before_treatment(self): self.assertEqual(outcome_before_treatment_nodes, set({"P3"})) # case 3 no outcome - self.medrecord.add_group("Headache") + self.medrecord.add_group("no_outcome") tee3 = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") - .with_outcome("Headache") + .with_outcome("no_outcome") + .with_time_attribute("time") .with_outcome_before_treatment_exclusion(30) .build() ) @@ -675,14 +703,15 @@ def test_outcome_before_treatment(self): tee3._find_outcomes(medrecord=self.medrecord, treated_set=treated_set) def test_filter_controls(self): - def query1(node: NodeOperand): + def query_neighbors_to_m2(node: NodeOperand): node.neighbors(EdgeDirection.BOTH).index().equal_to("M2") tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") .with_outcome("Stroke") - .filter_controls(query1) + .with_time_attribute("time") + .filter_controls(query_neighbors_to_m2) .build() ) counts_tee = tee.estimate._compute_subject_counts(self.medrecord) @@ -690,15 +719,15 @@ def query1(node: NodeOperand): self.assertEqual(counts_tee, (2, 1, 1, 2)) # filter females only - - def query2(node: NodeOperand): + def query_female_patients(node: NodeOperand): node.attribute("gender").equal_to("female") tee2 = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") .with_outcome("Stroke") - .filter_controls(query2) + .with_time_attribute("time") + .filter_controls(query_female_patients) .build() ) @@ -831,11 +860,38 @@ def test_full_report(self): self.assertDictEqual(report_test, full_report) def test_continuous_estimators_report(self): - """Test the continuous report of the TreatmentEffect class.""" + """Test the continuous report of the TreatmentEffect.""" + tee = ( + TreatmentEffect.builder() + .with_treatment("Rivaroxaban") + .with_outcome("Stroke") + .build() + ) + + report_test = { 
"average_treatment_effect": tee.estimate.average_treatment_effect( + self.medrecord, + outcome_variable="intensity", + ), + "cohens_d": tee.estimate.cohens_d( + self.medrecord, outcome_variable="intensity" + ), + } + + self.assertDictEqual( + report_test, + tee.report.continuous_estimators_report( + self.medrecord, outcome_variable="intensity" + ), + ) + + def test_continuous_estimators_report_with_time(self): + """Test the continuous report of the TreatmentEffect with time attribute.""" tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") .with_outcome("Stroke") + .with_time_attribute("time") .build() ) diff --git a/medmodels/treatment_effect/treatment_effect.py b/medmodels/treatment_effect/treatment_effect.py index 169dc7b..a2efef2 100644 --- a/medmodels/treatment_effect/treatment_effect.py +++ b/medmodels/treatment_effect/treatment_effect.py @@ -6,6 +6,10 @@ without undergoing the treatment. The class supports customizable criteria filtering, time constraints between treatment and outcome, and optional matching of control groups to treatment groups using a specified matching class. + +The default TreatmentEffect class performs a static analysis without considering time. +To perform a time-based analysis, users can specify a time attribute in the configuration +and set the washout period, grace period, and follow-up period. 
""" from __future__ import annotations @@ -36,7 +40,7 @@ class TreatmentEffect: _outcomes_group: Group _patients_group: Group - _time_attribute: MedRecordAttribute + _time_attribute: Optional[MedRecordAttribute] _washout_period_days: Dict[str, int] _washout_period_reference: Literal["first", "last"] @@ -83,7 +87,7 @@ def _set_configuration( treatment: Group, outcome: Group, patients_group: Group = "patients", - time_attribute: MedRecordAttribute = "time", + time_attribute: Optional[MedRecordAttribute] = None, washout_period_days: Dict[str, int] = dict(), washout_period_reference: Literal["first", "last"] = "first", grace_period_days: int = 0, @@ -110,10 +114,12 @@ def _set_configuration( outcome (Group): The group of outcomes to analyze. patients_group (Group, optional): The group of patients to analyze. Defaults to "patients". - time_attribute (MedRecordAttribute, optional): The time attribute to use for - time-based analysis. Defaults to "time". + time_attribute (Optional[MedRecordAttribute], optional): The time + attribute. If None, the treatment effect analysis is performed in a + static way (without considering time). Defaults to None. washout_period_days (Dict[str, int], optional): The washout period in days - for each treatment group. Defaults to dict(). + for each treatment group. In the case of no time attribute, it is not + applied. Defaults to dict(). washout_period_reference (Literal["first", "last"], optional): The reference point for the washout period. Defaults to "first". 
grace_period_days (int, optional): The grace period in days after the @@ -198,7 +204,14 @@ def _find_groups( """ # Find patients that underwent the treatment treated_set = self._find_treated_patients(medrecord) - treated_set, washout_nodes = self._apply_washout_period(medrecord, treated_set) + + if self._time_attribute: + treated_set, washout_nodes = self._apply_washout_period( + medrecord, treated_set + ) + else: + washout_nodes = set() + treated_set, treated_outcome_true, outcome_before_treatment_nodes = ( self._find_outcomes(medrecord, treated_set) ) @@ -285,7 +298,7 @@ def _find_outcomes( f"No outcomes found in the MedRecord for group {self._outcomes_group}" ) - if outcome_before_treatment_days: + if outcome_before_treatment_days and self._time_attribute: outcome_before_treatment_nodes = set( medrecord.select_nodes( lambda node: self._query_node_within_time_window( @@ -306,18 +319,25 @@ def _find_outcomes( f"dropped due to outcome before treatment." ) - treated_outcome_true = set( - medrecord.select_nodes( - lambda node: self._query_node_within_time_window( - node, - treated_set, - self._outcomes_group, - self._grace_period_days, - self._follow_up_period_days, - self._follow_up_period_reference, + if self._time_attribute: + treated_outcome_true = set( + medrecord.select_nodes( + lambda node: self._query_node_within_time_window( + node, + treated_set, + self._outcomes_group, + self._grace_period_days, + self._follow_up_period_days, + self._follow_up_period_reference, + ) + ) + ) + else: + treated_outcome_true = set( + medrecord.select_nodes( + lambda node: self._query_set_outcome_true(node, treated_set) ) ) - ) return treated_set, treated_outcome_true, outcome_before_treatment_nodes @@ -426,18 +446,26 @@ def _find_controls( f"No outcomes found in the MedRecord for group {self._outcomes_group}" ) - def query(node: NodeOperand): - node.index().is_in(list(control_set)) - node.neighbors(edge_direction=EdgeDirection.BOTH).in_group( - self._outcomes_group - ) - # 
Finding the patients that had the outcome in the control group - control_outcome_true = set(medrecord.select_nodes(query)) + control_outcome_true = set( + medrecord.select_nodes( + lambda node: self._query_set_outcome_true(node, control_set) + ) + ) control_outcome_false = control_set - control_outcome_true return control_outcome_true, control_outcome_false + def _query_set_outcome_true(self, node: NodeOperand, set: Set[NodeIndex]): + """Query for nodes that are in the given set and have the outcome. + + Args: + node (NodeOperand): The node to query. + set (Set[NodeIndex]): The set of nodes to query. + """ + node.index().is_in(list(set)) + node.neighbors(edge_direction=EdgeDirection.BOTH).in_group(self._outcomes_group) + def _query_node_within_time_window( self, node: NodeOperand, @@ -468,8 +496,13 @@ def _query_node_within_time_window( end_days (int): The end of the time window in days relative to the reference event. reference (Literal["first", "last"]): The reference point for the time window. + + Raises: + ValueError: If the time attribute is not set. 
""" node.index().is_in(list(treated_set)) + if self._time_attribute is None: + raise ValueError("Time attribute is not set.") edges_to_treatment = node.edges() edges_to_treatment.attribute(self._time_attribute).is_datetime() From f997191e40bcd1324c7593eeb794682040e0ff63 Mon Sep 17 00:00:00 2001 From: MarIniOnz Date: Fri, 6 Dec 2024 09:57:02 +0100 Subject: [PATCH 2/3] fix: PR comments --- .../treatment_effect/continuous_estimators.py | 35 ++++------- .../tests/test_treatment_effect.py | 59 ++++++++++++++++--- .../treatment_effect/treatment_effect.py | 28 +++++++-- 3 files changed, 87 insertions(+), 35 deletions(-) diff --git a/medmodels/treatment_effect/continuous_estimators.py b/medmodels/treatment_effect/continuous_estimators.py index fc05d1f..564af60 100644 --- a/medmodels/treatment_effect/continuous_estimators.py +++ b/medmodels/treatment_effect/continuous_estimators.py @@ -16,7 +16,7 @@ def average_treatment_effect( outcome_group: Group, outcome_variable: MedRecordAttribute, reference: Literal["first", "last"] = "last", - time_attribute: Optional[MedRecordAttribute] = "time", + time_attribute: Optional[MedRecordAttribute] = None, ) -> float: r"""Calculates the Average Treatment Effect (ATE) as the difference between the outcome means of the treated and control sets. @@ -55,7 +55,7 @@ def average_treatment_effect( edge that contains the time information. If it is equal to None, there is no time component in the data and all edges between the sets and the outcomes are considered for the average treatment effect. Defaults to - "time". + None. Returns: float: The average treatment effect. 
@@ -95,11 +95,11 @@ def average_treatment_effect( else: edges_treated_outcomes = medrecord.select_edges( - lambda edge: query_edges_set_outcome( + lambda edge: query_edges_between_set_outcome( edge, treatment_outcome_true_set, outcome_group ) ) - treated_outcomes = treated_outcomes = np.array( + treated_outcomes = np.array( [ medrecord.edge[edge_id][outcome_variable] for edge_id in edges_treated_outcomes @@ -107,11 +107,11 @@ def average_treatment_effect( ) edges_control_outcomes = medrecord.select_edges( - lambda edge: query_edges_set_outcome( + lambda edge: query_edges_between_set_outcome( edge, control_outcome_true_set, outcome_group ) ) - control_outcomes = treated_outcomes = np.array( + control_outcomes = np.array( [ medrecord.edge[edge_id][outcome_variable] for edge_id in edges_control_outcomes @@ -134,7 +134,7 @@ def cohens_d( outcome_group: Group, outcome_variable: MedRecordAttribute, reference: Literal["first", "last"] = "last", - time_attribute: Optional[MedRecordAttribute] = "time", + time_attribute: Optional[MedRecordAttribute] = None, add_correction: bool = False, ) -> float: """Calculates Cohen's D, the standardized mean difference between two sets, measuring the effect size of the difference between two outcome means. @@ -175,8 +175,7 @@ def cohens_d( time_attribute (Optional[MedRecordAttribute], optional): The attribute in the edge that contains the time information. If it is equal to None, there is no time component in the data and all edges between the sets and the - outcomes are considered for the average treatment effect. Defaults to - "time". + outcomes are considered for the Cohen's D calculation. Defaults to None. add_correction (bool, optional): Whether to apply a correction factor for small sample sizes. When True, using Hedges' g formula instead of Cohens' D. Defaults to False. 
@@ -218,11 +217,11 @@ def cohens_d( ) else: edges_treated_outcomes = medrecord.select_edges( - lambda edge: query_edges_set_outcome( + lambda edge: query_edges_between_set_outcome( edge, treatment_outcome_true_set, outcome_group ) ) - treated_outcomes = treated_outcomes = np.array( + treated_outcomes = np.array( [ medrecord.edge[edge_id][outcome_variable] for edge_id in edges_treated_outcomes @@ -230,11 +229,11 @@ def cohens_d( ) edges_control_outcomes = medrecord.select_edges( - lambda edge: query_edges_set_outcome( + lambda edge: query_edges_between_set_outcome( edge, control_outcome_true_set, outcome_group ) ) - control_outcomes = treated_outcomes = np.array( + control_outcomes = np.array( [ medrecord.edge[edge_id][outcome_variable] for edge_id in edges_control_outcomes @@ -268,7 +267,7 @@ def cohens_d( } -def query_edges_set_outcome( +def query_edges_between_set_outcome( edge: EdgeOperand, set: Set[NodeIndex], outcomes_group: Group ): """Query edges that connect a set of nodes to the outcomes group. 
@@ -284,16 +283,8 @@ def query_edges_set_outcome( lambda edge: edge.source_node().index().is_in(list_nodes), lambda edge: edge.target_node().index().is_in(list_nodes), ) - edge.either_or( - lambda edge: edge.source_node().index().is_in(list_nodes), - lambda edge: edge.target_node().index().is_in(list_nodes), - ) edge.either_or( lambda edge: edge.source_node().in_group(outcomes_group), lambda edge: edge.target_node().in_group(outcomes_group), ) - edge.either_or( - lambda edge: edge.source_node().in_group(outcomes_group), - lambda edge: edge.target_node().in_group(outcomes_group), - ) diff --git a/medmodels/treatment_effect/tests/test_treatment_effect.py b/medmodels/treatment_effect/tests/test_treatment_effect.py index eaa45e8..6d66661 100644 --- a/medmodels/treatment_effect/tests/test_treatment_effect.py +++ b/medmodels/treatment_effect/tests/test_treatment_effect.py @@ -308,12 +308,44 @@ def test_default_properties(self): .with_patients_group("patients") .with_washout_period(reference="first") .with_grace_period(days=0, reference="last") - .with_follow_up_period(365, reference="last") + .with_follow_up_period(365000, reference="last") .build() ) assert_treatment_effects_equal(self, tee, tee_builder) + def test_time_warnings(self): + """Test the warnings raised by the TreatmentEffect class with no time attribute.""" + with self.assertLogs(level="WARNING") as log_capture: + _ = ( + TreatmentEffect.builder() + .with_treatment("Rivaroxaban") + .with_outcome("Stroke") + .with_washout_period({"Warfarin": 30}) + .build() + ) + + self.assertIn( + "Washout period is not applied because the time attribute is not set.", + log_capture.output[0], + ) + + with self.assertLogs(level="WARNING") as log_capture: + _ = ( + TreatmentEffect.builder() + .with_treatment("Rivaroxaban") + .with_outcome("Stroke") + .with_follow_up_period(365) + .build() + ) + + self.assertIn( + "Time attribute is not set, thus the grace period, follow-up " + + "period, and outcome before treatment cannot be 
applied. The " + + "treatment effect analysis is performed in a static way.", + log_capture.output[0], + ) + def test_check_medrecord(self): tee = ( TreatmentEffect.builder() @@ -600,6 +632,7 @@ def test_invalid_grace_period(self): .with_treatment("Rivaroxaban") .with_outcome("Stroke") .with_grace_period(1000) + .with_follow_up_period(365) .with_time_attribute("time") .build() ) @@ -619,12 +652,17 @@ def test_washout_period(self): self.assertDictEqual(tee._washout_period_days, washout_dict) treated_set = tee._find_treated_patients(self.medrecord) - treated_set, washout_nodes = tee._apply_washout_period( - self.medrecord, treated_set - ) + with self.assertLogs(level="WARNING") as log_capture: + treated_set, washout_nodes = tee._apply_washout_period( + self.medrecord, treated_set + ) self.assertEqual(treated_set, set({"P3", "P6"})) self.assertEqual(washout_nodes, set({"P2"})) + self.assertIn( + "1 subject was dropped due to having a treatment in the washout period.", + log_capture.output[0], + ) # smaller washout period washout_dict2 = {"Warfarin": 10} @@ -678,13 +716,20 @@ def test_outcome_before_treatment(self): self.assertEqual(tee2._outcome_before_treatment_days, 30) treated_set = tee2._find_treated_patients(self.medrecord) - treated_set, treatment_outcome_true, outcome_before_treatment_nodes = ( - tee2._find_outcomes(self.medrecord, treated_set) - ) + with self.assertLogs(level="WARNING") as log_capture: + treated_set, treatment_outcome_true, outcome_before_treatment_nodes = ( + tee2._find_outcomes(self.medrecord, treated_set) + ) + self.assertEqual(treated_set, set({"P2", "P6"})) self.assertEqual(treatment_outcome_true, set({"P2"})) self.assertEqual(outcome_before_treatment_nodes, set({"P3"})) + self.assertIn( + "1 subject was dropped due to having an outcome before the treatment.", + log_capture.output[0], + ) + # case 3 no outcome self.medrecord.add_group("no_outcome") diff --git a/medmodels/treatment_effect/treatment_effect.py 
b/medmodels/treatment_effect/treatment_effect.py index a2efef2..6a230fb 100644 --- a/medmodels/treatment_effect/treatment_effect.py +++ b/medmodels/treatment_effect/treatment_effect.py @@ -92,7 +92,7 @@ def _set_configuration( washout_period_reference: Literal["first", "last"] = "first", grace_period_days: int = 0, grace_period_reference: Literal["first", "last"] = "last", - follow_up_period_days: int = 365, + follow_up_period_days: int = 1000 * 365, follow_up_period_reference: Literal["first", "last"] = "last", outcome_before_treatment_days: Optional[int] = None, filter_controls_query: Optional[NodeQuery] = None, @@ -120,14 +120,14 @@ def _set_configuration( washout_period_days (Dict[str, int], optional): The washout period in days for each treatment group. In the case of no time attribute, it is not applied. Defaults to dict(). - washout_period_reference (Literal["first", "last"], optional): The reference - point for the washout period. Defaults to "first". + washout_period_reference (Literal["first", "last"], optional): The + reference point for the washout period. Defaults to "first". grace_period_days (int, optional): The grace period in days after the treatment. Defaults to 0. grace_period_reference (Literal["first", "last"], optional): The reference point for the grace period. Defaults to "last". follow_up_period_days (int, optional): The follow-up period in days after - the treatment. Defaults to 365. + the treatment. Defaults to 365000. follow_up_period_reference (Literal["first", "last"], optional): The reference point for the follow-up period. Defaults to "last". 
outcome_before_treatment_days (Optional[int], optional): The number of days @@ -180,6 +180,22 @@ def _set_configuration( treatment_effect._matching_number_of_neighbors = matching_number_of_neighbors treatment_effect._matching_hyperparam = matching_hyperparam + if washout_period_days and not time_attribute: + logging.warning( + "Washout period is not applied because the time attribute is not set." + ) + + if ( + grace_period_days + or (follow_up_period_days != 1000 * 365) + or outcome_before_treatment_days + ) and not time_attribute: + logging.warning( + "Time attribute is not set, thus the grace period, follow-up " + + "period, and outcome before treatment cannot be applied. The " + + "treatment effect analysis is performed in a static way." + ) + def _find_groups( self, medrecord: MedRecord ) -> Tuple[Set[NodeIndex], Set[NodeIndex], Set[NodeIndex], Set[NodeIndex]]: @@ -316,7 +332,7 @@ def _find_outcomes( dropped_num = len(outcome_before_treatment_nodes) logging.warning( f"{dropped_num} subject{' was' if dropped_num == 1 else 's were'} " - f"dropped due to outcome before treatment." + f"dropped due to having an outcome before the treatment." ) if self._time_attribute: @@ -382,7 +398,7 @@ def _apply_washout_period( dropped_num = len(washout_nodes) logging.warning( f"{dropped_num} subject{' was' if dropped_num == 1 else 's were'} " - f"dropped due to outcome before treatment." + f"dropped due to having a treatment in the washout period." 
) return treated_set, washout_nodes From d8bc7338ba5f347ca04ef19742b85cd6dc750ed8 Mon Sep 17 00:00:00 2001 From: MarIniOnz Date: Wed, 15 Jan 2025 12:31:13 +0100 Subject: [PATCH 3/3] fix: comment on time_attribute --- medmodels/treatment_effect/builder.py | 11 ++++++++--- .../treatment_effect/tests/test_treatment_effect.py | 1 - 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/medmodels/treatment_effect/builder.py b/medmodels/treatment_effect/builder.py index c015e02..27c2191 100644 --- a/medmodels/treatment_effect/builder.py +++ b/medmodels/treatment_effect/builder.py @@ -23,6 +23,9 @@ class TreatmentEffectBuilder: The TreatmentEffectBuilder class is used to build a TreatmentEffect object with the desired configurations for the treatment effect estimation using a builder pattern. + + By default, it configures a static treatment effect estimation. To configure a + time-dependent treatment effect estimation, the time_attribute must be set. """ treatment: Group @@ -93,13 +96,15 @@ def with_patients_group(self, group: Group) -> TreatmentEffectBuilder: return self def with_time_attribute( - self, attribute: Optional[MedRecordAttribute] + self, attribute: MedRecordAttribute ) -> TreatmentEffectBuilder: """Sets the time attribute to be used in the treatment effect estimation. + It turns the treatment effect estimation from a static to a time-dependent + analysis. + Args: - attribute (Optional[MedRecordAttribute]): The time attribute. If None, - there is no temporal analysis, but only static one. + attribute (MedRecordAttribute): The time attribute. 
Returns: TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder diff --git a/medmodels/treatment_effect/tests/test_treatment_effect.py b/medmodels/treatment_effect/tests/test_treatment_effect.py index 25da99c..dca9372 100644 --- a/medmodels/treatment_effect/tests/test_treatment_effect.py +++ b/medmodels/treatment_effect/tests/test_treatment_effect.py @@ -304,7 +304,6 @@ def test_default_properties(self) -> None: TreatmentEffect.builder() .with_treatment("Rivaroxaban") .with_outcome("Stroke") - .with_time_attribute(None) .with_patients_group("patients") .with_washout_period(reference="first") .with_grace_period(days=0, reference="last")