From c14182876477221c9d4b9b5a3108dd1aab77a554 Mon Sep 17 00:00:00 2001 From: MarIniOnz Date: Mon, 18 Nov 2024 17:50:52 +0100 Subject: [PATCH 1/5] fix: changing some incorrect variable naming and deleting how default values are assigned in essential and hot-encoded covariates for matching --- medmodels/treatment_effect/estimate.py | 3 +- .../matching/algorithms/propensity_score.py | 14 +- .../treatment_effect/matching/matching.py | 123 ++++++++- .../treatment_effect/matching/neighbors.py | 9 +- .../treatment_effect/matching/propensity.py | 22 +- .../matching/tests/test_matching.py | 247 ++++++++++++++++++ .../matching/tests/test_propensity_score.py | 33 ++- .../treatment_effect/treatment_effect.py | 4 +- 8 files changed, 411 insertions(+), 44 deletions(-) create mode 100644 medmodels/treatment_effect/matching/tests/test_matching.py diff --git a/medmodels/treatment_effect/estimate.py b/medmodels/treatment_effect/estimate.py index 0b6f7b6..f91a1a0 100644 --- a/medmodels/treatment_effect/estimate.py +++ b/medmodels/treatment_effect/estimate.py @@ -192,7 +192,7 @@ def _sort_subjects_in_groups( else PropensityMatching( number_of_neighbors=self._treatment_effect._matching_number_of_neighbors, model=self._treatment_effect._matching_model, - hyperparam=self._treatment_effect._matching_hyperparameters, + hyperparameters=self._treatment_effect._matching_hyperparameters, ) ) @@ -202,6 +202,7 @@ def _sort_subjects_in_groups( medrecord=medrecord, treated_set=treated_set, control_set=control_set, + patients_group=self._treatment_effect._patients_group, essential_covariates=self._treatment_effect._matching_essential_covariates, one_hot_covariates=self._treatment_effect._matching_one_hot_covariates, ) diff --git a/medmodels/treatment_effect/matching/algorithms/propensity_score.py b/medmodels/treatment_effect/matching/algorithms/propensity_score.py index 19c8271..c30034e 100644 --- a/medmodels/treatment_effect/matching/algorithms/propensity_score.py +++ b/medmodels/treatment_effect/matching/algorithms/propensity_score.py @@ -35,7 +35,7 @@ def calculate_propensity( treated_test: NDArray[Union[np.int64, np.float64]], control_test: NDArray[Union[np.int64, np.float64]], model: Model = "logit", - hyperparam: Optional[Dict[str, Any]] = None, + hyperparameters: Optional[Dict[str, Any]] = None, ) -> Tuple[NDArray[np.float64], NDArray[np.float64]]: """Calculates the propensity/probabilities of a subject being in the treated group. @@ -55,8 +55,8 @@ def calculate_propensity( control group to predict probabilities. model (Model, optional): Classification algorithm to use. Options: "logit", "dec_tree", "forest". - hyperparam (Optional[Dict[str, Any]], optional): Manual hyperparameter settings. - Uses default if None. + hyperparameters (Optional[Dict[str, Any]], optional): Manual hyperparameter + settings. Uses default hyperparameters if None. Returns: Tuple[NDArray[np.float64], NDArray[np.float64]: Probabilities of the positive @@ -67,7 +67,7 @@ class for treated and control groups. last class for treated and control sets, e.g., ([0.], [0.]). """ propensity_model = PROP_MODEL[model] - pm = propensity_model(**hyperparam) if hyperparam else propensity_model() + pm = propensity_model(**hyperparameters) if hyperparameters else propensity_model() pm.fit(x_train, y_train) # Predict the probability of the treated and control groups @@ -82,7 +82,7 @@ def run_propensity_score( control_set: pl.DataFrame, model: Model = "logit", number_of_neighbors: int = 1, - hyperparam: Optional[Dict[str, Any]] = None, + hyperparameters: Optional[Dict[str, Any]] = None, covariates: Optional[MedRecordAttributeInputList] = None, ) -> pl.DataFrame: """Executes Propensity Score matching using a specified classification algorithm. @@ -101,7 +101,7 @@ def run_propensity_score( Options include "logit", "dec_tree", "forest". number_of_neighbors (int, optional): Number of nearest neighbors to find for each treated unit. Defaults to 1. - hyperparam (Optional[Dict[str, Any]], optional): Hyperparameters for model + hyperparameters (Optional[Dict[str, Any]], optional): Hyperparameters for model tuning. Increases computation time if set. Uses default if None. covariates (Optional[MedRecordAttributeInputList], optional): Features for matching. Uses all if None. @@ -125,7 +125,7 @@ def run_propensity_score( y_train, treated_array, control_array, - hyperparam=hyperparam, + hyperparameters=hyperparameters, model=model, ) diff --git a/medmodels/treatment_effect/matching/matching.py b/medmodels/treatment_effect/matching/matching.py index e879076..f4f1c30 100644 --- a/medmodels/treatment_effect/matching/matching.py +++ b/medmodels/treatment_effect/matching/matching.py @@ -12,9 +12,13 @@ import polars as pl +from medmodels.medrecord._overview import extract_attribute_summary +from medmodels.medrecord.medrecord import MedRecord + if TYPE_CHECKING: from medmodels.medrecord.medrecord import MedRecord - from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex + from medmodels.medrecord.querying import NodeOperand + from medmodels.medrecord.types import Group, MedRecordAttributeInputList, NodeIndex MatchingMethod: TypeAlias = Literal["propensity", "nearest_neighbors"] @@ -22,14 +26,26 @@ class Matching(ABC): """The Abstract Class for matching.""" + number_of_neighbors: int + + def __init__(self, number_of_neighbors: int) -> None: + """Initializes the matching class. + + Args: + number_of_neighbors (int): Number of nearest neighbors to find for each + treated patient. + """ + self.number_of_neighbors = number_of_neighbors + def _preprocess_data( self, *, medrecord: MedRecord, control_set: Set[NodeIndex], treated_set: Set[NodeIndex], - essential_covariates: MedRecordAttributeInputList, - one_hot_covariates: MedRecordAttributeInputList, + patients_group: Group, + essential_covariates: Optional[MedRecordAttributeInputList] = None, + one_hot_covariates: Optional[MedRecordAttributeInputList] = None, ) -> Tuple[pl.DataFrame, pl.DataFrame]: """Prepared the data for the matching algorithms. @@ -37,21 +53,41 @@ def _preprocess_data( medrecord (MedRecord): MedRecord object containing the data. control_set (Set[NodeIndex]): Set of treated subjects. treated_set (Set[NodeIndex]): Set of control subjects. - essential_covariates (MedRecordAttributeInputList): Covariates - that are essential for matching - one_hot_covariates (MedRecordAttributeInputList): Covariates that - are one-hot encoded for matching + patients_group (Group): The group of patients. + essential_covariates (Optional[MedRecordAttributeInputList]): + Covariates that are essential for matching. Defaults to None. + one_hot_covariates (Optional[MedRecordAttributeInputList]): + Covariates that are one-hot encoded for matching. Defaults to None. Returns: Tuple[pl.DataFrame, pl.DataFrame]: Treated and control groups with their preprocessed covariates + + Raises: + AssertionError: If the one-hot covariates are not in the essential + covariates. """ - essential_covariates = [str(covariate) for covariate in essential_covariates] + if essential_covariates is None: + # If no essential covariates provided, use all attributes of patients group + essential_covariates = list( + extract_attribute_summary( + medrecord.node[medrecord.nodes_in_group(patients_group)] + ) + ) + else: + essential_covariates = list(essential_covariates) + + control_set = self._check_nodes( + medrecord=medrecord, + treated_set=treated_set, + control_set=control_set, + essential_covariates=essential_covariates, + ) if "id" not in essential_covariates: essential_covariates.append("id") - # Dataframe + # Dataframe wth the essential covariates data = pl.DataFrame( data=[ {"id": k, **v} @@ -60,6 +96,23 @@ def _preprocess_data( ) original_columns = data.columns + # If no one-hot covariates provided, use all categorical attributes of patients + if one_hot_covariates is None: + attributes = extract_attribute_summary( + medrecord.node[medrecord.nodes_in_group(patients_group)] + ) + one_hot_covariates = [ + covariate + for covariate, values in attributes.items() + if "values" in values + ] + + if not all( + covariate in essential_covariates for covariate in one_hot_covariates + ): + msg = "One-hot covariates must be in the essential covariates" + raise AssertionError(msg) + # One-hot encode the categorical variables data = data.to_dummies( columns=[str(covariate) for covariate in one_hot_covariates], @@ -79,6 +132,58 @@ def _preprocess_data( return data_treated, data_control + def _check_nodes( + self, + medrecord: MedRecord, + treated_set: Set[NodeIndex], + control_set: Set[NodeIndex], + essential_covariates: MedRecordAttributeInputList, + ) -> Set[NodeIndex]: + """Check if the treated and control sets are disjoint. + + Args: + medrecord (MedRecord): MedRecord object containing the data. + treated_set (Set[NodeIndex]): Set of treated subjects. + control_set (Set[NodeIndex]): Set of control subjects. + essential_covariates (MedRecordAttributeInputList): Covariates that are + essential for matching. + + Returns: + Set[NodeIndex]: The control set. + + Raises: + ValueError: If not enough control subjects to match the treated subjects. + ValueError: If some treated nodes do not have all the essential covariates. + """ + + def query_essential_covariates( + node: NodeOperand, patients_set: Set[NodeIndex] + ) -> None: + """Query the nodes that have all the essential covariates.""" + for attribute in essential_covariates: + node.has_attribute(attribute) + + node.index().is_in(list(patients_set)) + + control_set = set( + medrecord.select_nodes( + lambda node: query_essential_covariates(node, control_set) + ) + ) + if len(control_set) < self.number_of_neighbors * len(treated_set): + msg = "Not enough control subjects to match the treated subjects" + raise ValueError(msg) + + if len(treated_set) != len( + medrecord.select_nodes( + lambda node: query_essential_covariates(node, treated_set) + ) + ): + msg = "Some treated nodes do not have all the essential covariates" + raise ValueError(msg) + + return control_set + @abstractmethod def match_controls( self, diff --git a/medmodels/treatment_effect/matching/neighbors.py b/medmodels/treatment_effect/matching/neighbors.py index 8948bf0..8df2c0a 100644 --- a/medmodels/treatment_effect/matching/neighbors.py +++ b/medmodels/treatment_effect/matching/neighbors.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from medmodels import MedRecord - from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex + from medmodels.medrecord.types import Group, MedRecordAttributeInputList, NodeIndex class NeighborsMatching(Matching): @@ -36,7 +36,7 @@ def __init__( number_of_neighbors (int, optional): Number of nearest neighbors to find for each treated unit. Defaults to 1. """ - self.number_of_neighbors = number_of_neighbors + super().__init__(number_of_neighbors) def match_controls( self, @@ -44,6 +44,7 @@ def match_controls( medrecord: MedRecord, control_set: Set[NodeIndex], treated_set: Set[NodeIndex], + patients_group: Group, essential_covariates: Optional[MedRecordAttributeInputList] = None, one_hot_covariates: Optional[MedRecordAttributeInputList] = None, ) -> Set[NodeIndex]: @@ -51,8 +52,9 @@ def match_controls( Args: medrecord (MedRecord): MedRecord object containing the data. - treated_set (Set[NodeIndex]): Set of treated subjects. control_set (Set[NodeIndex]): Set of control subjects. + treated_set (Set[NodeIndex]): Set of treated subjects. + patients_group (Group): Group of patients in the MedRecord. essential_covariates (Optional[MedRecordAttributeInputList], optional): Covariates that are essential for matching. Defaults to ["gender", "age"]. @@ -72,6 +74,7 @@ def match_controls( medrecord=medrecord, control_set=control_set, treated_set=treated_set, + patients_group=patients_group, essential_covariates=essential_covariates, one_hot_covariates=one_hot_covariates, ) diff --git a/medmodels/treatment_effect/matching/propensity.py b/medmodels/treatment_effect/matching/propensity.py index ac74d58..c8b949a 100644 --- a/medmodels/treatment_effect/matching/propensity.py +++ b/medmodels/treatment_effect/matching/propensity.py @@ -7,6 +7,8 @@ import numpy as np import polars as pl +from medmodels import MedRecord +from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex from medmodels.treatment_effect.matching.algorithms.classic_distance_models import ( nearest_neighbor, ) @@ -18,7 +20,7 @@ if TYPE_CHECKING: from medmodels import MedRecord - from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex + from medmodels.medrecord.types import Group, MedRecordAttributeInputList, NodeIndex class PropensityMatching(Matching): @@ -40,23 +42,22 @@ def __init__( *, model: Model = "logit", number_of_neighbors: int = 1, - hyperparam: Optional[Dict[str, Any]] = None, + hyperparameters: Optional[Dict[str, Any]] = None, ) -> None: """Initializes the propensity score class. Args: model (Model, optional): classification method to be used, default: "logit". Can be chosen from ["logit", "dec_tree", "forest"]. - nearest_neighbors_algorithm (NNAlgorithm, optional): algorithm used to - compute nearest neighbors. Defaults to "auto". number_of_neighbors (int, optional): number of neighbors to be matched per treated subject. Defaults to 1. - hyperparam (Optional[Dict[str, Any]], optional): hyperparameters for the - classification model, default: None. + hyperparameters (Optional[Dict[str, Any]], optional): hyperparameters for + the classification model. Defaults to None. """ + super().__init__(number_of_neighbors) self.model = model self.number_of_neighbors = number_of_neighbors - self.hyperparam = hyperparam + self.hyperparameters = hyperparameters def match_controls( self, @@ -64,6 +65,7 @@ def match_controls( medrecord: MedRecord, control_set: Set[NodeIndex], treated_set: Set[NodeIndex], + patients_group: Group, essential_covariates: Optional[MedRecordAttributeInputList] = None, one_hot_covariates: Optional[MedRecordAttributeInputList] = None, ) -> Set[NodeIndex]: @@ -71,8 +73,9 @@ def match_controls( Args: medrecord (MedRecord): medrecord object containing the data. - treated_set (Set[NodeIndex]): Set of treated subjects. control_set (Set[NodeIndex]): Set of control subjects. + treated_set (Set[NodeIndex]): Set of treated subjects. + patients_group (Group): Group of patients in MedRecord. essential_covariates (Optional[MedRecordAttributeInputList], optional): Covariates that are essential for matching. Defaults to ["gender", "age"]. @@ -93,6 +96,7 @@ def match_controls( medrecord=medrecord, treated_set=treated_set, control_set=control_set, + patients_group=patients_group, essential_covariates=essential_covariates, one_hot_covariates=one_hot_covariates, ) @@ -116,7 +120,7 @@ def match_controls( y_train=y_train, treated_test=treated_array, control_test=control_array, - hyperparam=self.hyperparam, + hyperparameters=self.hyperparameters, model=self.model, ) diff --git a/medmodels/treatment_effect/matching/tests/test_matching.py b/medmodels/treatment_effect/matching/tests/test_matching.py new file mode 100644 index 0000000..10ac49b --- /dev/null +++ b/medmodels/treatment_effect/matching/tests/test_matching.py @@ -0,0 +1,247 @@ +"""Tests for the NeighborsMatching class in the matching module.""" + +from __future__ import annotations + +import unittest +from typing import TYPE_CHECKING, List, Optional, Set + +import pandas as pd +import pytest + +from medmodels import MedRecord +from medmodels.treatment_effect.matching.neighbors import NeighborsMatching + +if TYPE_CHECKING: + from medmodels.medrecord.types import NodeIndex + + +def create_patients(patients_list: List[NodeIndex]) -> pd.DataFrame: + """Creates a patients dataframe. + + Args: + patients_list (List[NodeIndex]): List of patients to include in the dataframe. + + Returns: + pd.DataFrame: A patients dataframe. + """ + patients = pd.DataFrame( + { + "index": ["P1", "P2", "P3", "P4", "P5", "P6", "P7", "P8", "P9"], + "age": [20, 30, 40, 30, 40, 50, 60, 70, 80], + "gender": [ + "male", + "female", + "male", + "female", + "male", + "female", + "male", + "female", + "male", + ], + } + ) + + return patients.loc[patients["index"].isin(patients_list)] + + +def create_medrecord(patients_list: Optional[List[NodeIndex]] = None) -> MedRecord: + """Creates a MedRecord object. + + Args: + patients_list (Optional[List[NodeIndex]], optional): List of patients to include + in the MedRecord. Defaults to None. + + Returns: + MedRecord: A MedRecord object. + """ + if patients_list is None: + patients_list = [ + "P1", + "P2", + "P3", + "P4", + "P5", + "P6", + "P7", + "P8", + "P9", + ] + patients = create_patients(patients_list=patients_list) + medrecord = MedRecord.from_pandas(nodes=[(patients, "index")]) + medrecord.add_group(group="patients", nodes=patients["index"].to_list()) + return medrecord + + +class TestNeighborsMatching(unittest.TestCase): + """Class to test the NeighborsMatching class in the matching module.""" + + def setUp(self) -> None: + self.medrecord = create_medrecord() + + def test_preprocess_data(self) -> None: + neighbors_matching = NeighborsMatching(number_of_neighbors=1) + + control_set: Set[NodeIndex] = {"P1", "P3", "P5", "P7", "P9"} + treated_set: Set[NodeIndex] = {"P2", "P4", "P6"} + + data_treated, data_control = neighbors_matching._preprocess_data( + medrecord=self.medrecord, + control_set=control_set, + treated_set=treated_set, + patients_group="patients", + essential_covariates=["age", "gender"], + one_hot_covariates=["gender"], + ) + + # Assert that the treated and control dataframes have the correct columns + assert "age" in data_treated.columns + assert "age" in data_control.columns + assert ( + "gender_female" in data_treated.columns + or "gender_male" in data_treated.columns + ) + assert ( + "gender_female" in data_control.columns + or "gender_male" in data_control.columns + ) + + # Assert that the treated and control dataframes have the correct number of rows + assert len(data_treated) == len(treated_set) + assert len(data_control) == len(control_set) + + # Try automatic detection of attributes + data_treated, data_control = neighbors_matching._preprocess_data( + medrecord=self.medrecord, + control_set=control_set, + treated_set=treated_set, + patients_group="patients", + ) + + # Assert that the treated and control dataframes have the correct columns + assert "age" in data_treated.columns + assert "age" in data_control.columns + assert ( + "gender_female" in data_treated.columns + or "gender_male" in data_treated.columns + ) + assert ( + "gender_female" in data_control.columns + or "gender_male" in data_control.columns + ) + + # Assert that the treated and control dataframes have the correct number of rows + assert len(data_treated) == len(treated_set) + assert len(data_control) == len(control_set) + + def test_match_controls(self) -> None: + neighbors_matching = NeighborsMatching(number_of_neighbors=1) + + control_set: Set[NodeIndex] = {"P1", "P3", "P5", "P7", "P9"} + treated_set: Set[NodeIndex] = {"P2", "P4", "P6"} + + matched_controls = neighbors_matching.match_controls( + medrecord=self.medrecord, + control_set=control_set, + treated_set=treated_set, + patients_group="patients", + essential_covariates=["age", "gender"], + one_hot_covariates=["gender"], + ) + + # Assert that the matched controls are a subset of the control set + assert matched_controls.issubset(control_set) + + # Assert that the correct number of controls were matched + assert len(matched_controls) == len(treated_set) + + # Assert it works equally if no covariates are given (automatically assigned) + matched_controls_no_covariates_specified = neighbors_matching.match_controls( + medrecord=self.medrecord, + control_set=control_set, + treated_set=treated_set, + patients_group="patients", + ) + + assert matched_controls_no_covariates_specified.issubset(control_set) + assert len(matched_controls_no_covariates_specified) == len(treated_set) + + def test_check_nodes(self) -> None: + neighbors_matching = NeighborsMatching(number_of_neighbors=1) + + control_set: Set[NodeIndex] = {"P1", "P3", "P5", "P7", "P9"} + treated_set: Set[NodeIndex] = {"P2", "P4", "P6", "P8"} + + # Test valid case + valid_control_set = neighbors_matching._check_nodes( + medrecord=self.medrecord, + treated_set=treated_set, + control_set=control_set, + essential_covariates=["age", "gender"], + ) + assert valid_control_set == control_set + + def test_invalid_check_nodes(self) -> None: + neighbors_matching = NeighborsMatching(number_of_neighbors=1) + + control_set: Set[NodeIndex] = {"P1", "P3", "P5", "P7", "P9"} + treated_set: Set[NodeIndex] = {"P2", "P4", "P6"} + + # Test insufficient control subjects + with pytest.raises( + ValueError, + match="Not enough control subjects to match the treated subjects", + ): + neighbors_matching._check_nodes( + medrecord=self.medrecord, + treated_set=treated_set, + control_set={"P1"}, + essential_covariates=["age", "gender"], + ) + + neighbors_matching_two_neighbors = NeighborsMatching(number_of_neighbors=2) + with pytest.raises( + ValueError, + match="Not enough control subjects to match the treated subjects", + ): + neighbors_matching_two_neighbors._check_nodes( + medrecord=self.medrecord, + treated_set=treated_set, + control_set=control_set, + essential_covariates=["age", "gender"], + ) + + # Test missing essential covariates in treated set + with pytest.raises( + ValueError, + match="Some treated nodes do not have all the essential covariates", + ): + neighbors_matching._check_nodes( + medrecord=self.medrecord, + treated_set={"P2", "P10"}, + control_set=control_set, + essential_covariates=["age", "gender"], + ) + + def test_invalid_match_controls(self) -> None: + neighbors_matching = NeighborsMatching(number_of_neighbors=1) + + control_set: Set[NodeIndex] = {"P1", "P3", "P5", "P7", "P9"} + treated_set: Set[NodeIndex] = {"P2", "P4", "P6"} + + with pytest.raises( + AssertionError, + match="One-hot covariates must be in the essential covariates", + ): + neighbors_matching.match_controls( + medrecord=self.medrecord, + control_set=control_set, + treated_set=treated_set, + patients_group="patients", + essential_covariates=["age"], + one_hot_covariates=["gender"], + ) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/medmodels/treatment_effect/matching/tests/test_propensity_score.py b/medmodels/treatment_effect/matching/tests/test_propensity_score.py index b511c01..5963909 100644 --- a/medmodels/treatment_effect/matching/tests/test_propensity_score.py +++ b/medmodels/treatment_effect/matching/tests/test_propensity_score.py @@ -13,14 +13,18 @@ def test_calculate_propensity(self) -> None: x, y = load_iris(return_X_y=True) # Set random state by each propensity estimator: - hyperparam = {"random_state": 1} - hyperparam_logit = {"random_state": 1, "max_iter": 200} + hyperparameters = {"random_state": 1} + hyperparameters_logit = {"random_state": 1, "max_iter": 200} x = np.array(x) y = np.array(y) # Logistic Regression model: result_1, result_2 = ps.calculate_propensity( - x, y, np.array([x[0, :]]), np.array([x[1, :]]), hyperparam=hyperparam_logit + x, + y, + np.array([x[0, :]]), + np.array([x[1, :]]), + hyperparameters=hyperparameters_logit, ) assert result_1[0] == pytest.approx(1.4e-08, 9) assert result_2[0] == pytest.approx(3e-08, 9) @@ -32,7 +36,7 @@ def test_calculate_propensity(self) -> None: np.array([x[0, :]]), np.array([x[1, :]]), model="dec_tree", - hyperparam=hyperparam, + hyperparameters=hyperparameters, ) assert result_1[0] == pytest.approx(0, 2) assert result_2[0] == pytest.approx(0, 2) @@ -44,15 +48,15 @@ def test_calculate_propensity(self) -> None: np.array([x[0, :]]), np.array([x[1, :]]), model="forest", - hyperparam=hyperparam, + hyperparameters=hyperparameters, ) assert result_1[0] == pytest.approx(0, 2) assert result_2[0] == pytest.approx(0, 2) def test_run_propensity_score(self) -> None: # Set random state by each propensity estimator: - hyperparam = {"random_state": 1} - hyperparam_logit = {"random_state": 1, "max_iter": 200} + hyperparameters = {"random_state": 1} + hyperparameters_logit = {"random_state": 1, "max_iter": 200} ########################################### # 1D example @@ -62,21 +66,21 @@ def test_run_propensity_score(self) -> None: # logit model expected_logit = pl.DataFrame({"a": [1.0, 3.0]}) result_logit = ps.run_propensity_score( - treated_set, control_set, hyperparam=hyperparam_logit + treated_set, control_set, hyperparameters=hyperparameters_logit ) assert result_logit.equals(expected_logit) # dec_tree metric expected_logit = pl.DataFrame({"a": [1.0, 1.0]}) result_logit = ps.run_propensity_score( - treated_set, control_set, model="dec_tree", hyperparam=hyperparam + treated_set, control_set, model="dec_tree", hyperparameters=hyperparameters ) assert result_logit.equals(expected_logit) # forest model expected_logit = pl.DataFrame({"a": [1.0, 1.0]}) result_logit = ps.run_propensity_score( - treated_set, control_set, model="forest", hyperparam=hyperparam + treated_set, control_set, model="forest", hyperparameters=hyperparameters ) assert result_logit.equals(expected_logit) @@ -91,7 +95,10 @@ def test_run_propensity_score(self) -> None: # logit model expected_logit = pl.DataFrame({"a": [1.0], "b": [3.0], "c": [5.0]}) result_logit = ps.run_propensity_score( - treated_set, control_set, covariates=covs, hyperparam=hyperparam_logit + treated_set, + control_set, + covariates=covs, + hyperparameters=hyperparameters_logit, ) assert result_logit.equals(expected_logit) @@ -102,7 +109,7 @@ def test_run_propensity_score(self) -> None: control_set, model="dec_tree", covariates=covs, - hyperparam=hyperparam, + hyperparameters=hyperparameters, ) assert result_logit.equals(expected_logit) @@ -113,7 +120,7 @@ def test_run_propensity_score(self) -> None: control_set, model="forest", covariates=covs, - hyperparam=hyperparam, + hyperparameters=hyperparameters, ) assert result_logit.equals(expected_logit) diff --git a/medmodels/treatment_effect/treatment_effect.py b/medmodels/treatment_effect/treatment_effect.py index c6971a7..ebe74c3 100644 --- a/medmodels/treatment_effect/treatment_effect.py +++ b/medmodels/treatment_effect/treatment_effect.py @@ -68,8 +68,8 @@ class TreatmentEffect: _filter_controls_query: Optional[NodeQuery] _matching_method: Optional[MatchingMethod] - _matching_essential_covariates: MedRecordAttributeInputList - _matching_one_hot_covariates: MedRecordAttributeInputList + _matching_essential_covariates: Optional[MedRecordAttributeInputList] + _matching_one_hot_covariates: Optional[MedRecordAttributeInputList] _matching_model: Model _matching_number_of_neighbors: int _matching_hyperparameters: Optional[Dict[str, Any]] From ede1a649cc6d4a82a46ad1b13dbf435af58f193e Mon Sep 17 00:00:00 2001 From: MarIniOnz Date: Tue, 17 Dec 2024 17:59:18 +0100 Subject: [PATCH 2/5] fix: fix PR comments --- .../treatment_effect/matching/matching.py | 41 +++++++++++++------ .../matching/tests/test_matching.py | 20 +++++---- 2 files changed, 42 insertions(+), 19 deletions(-) diff --git a/medmodels/treatment_effect/matching/matching.py b/medmodels/treatment_effect/matching/matching.py index f4f1c30..d74205c 100644 --- a/medmodels/treatment_effect/matching/matching.py +++ b/medmodels/treatment_effect/matching/matching.py @@ -55,9 +55,11 @@ def _preprocess_data( treated_set (Set[NodeIndex]): Set of control subjects. patients_group (Group): The group of patients. essential_covariates (Optional[MedRecordAttributeInputList]): - Covariates that are essential for matching. Defaults to None. + Covariates that are essential for matching. Defaults to None, meaning + all the attributes of the patients are used. one_hot_covariates (Optional[MedRecordAttributeInputList]): - Covariates that are one-hot encoded for matching. Defaults to None. + Covariates that are one-hot encoded for matching. Defaults to None, + meaning all the categorical attributes of the patients are used. Returns: Tuple[pl.DataFrame, pl.DataFrame]: Treated and control groups with their @@ -74,8 +76,6 @@ def _preprocess_data( medrecord.node[medrecord.nodes_in_group(patients_group)] ) ) - else: - essential_covariates = list(essential_covariates) control_set = self._check_nodes( medrecord=medrecord, @@ -85,9 +85,9 @@ def _preprocess_data( ) if "id" not in essential_covariates: - essential_covariates.append("id") + essential_covariates.append("id") # pyright: ignore[reportArgumentType] - # Dataframe wth the essential covariates + # Dataframe with the essential covariates data = pl.DataFrame( data=[ {"id": k, **v} @@ -104,11 +104,21 @@ def _preprocess_data( one_hot_covariates = [ covariate for covariate, values in attributes.items() - if "values" in values + if "Categorical" in values["type"] ] - if not all( - covariate in essential_covariates for covariate in one_hot_covariates + one_hot_covariates = [ + covariate + for covariate in one_hot_covariates + if covariate in essential_covariates + ] + + # If there are one-hot covariates, check if all are in the essential covariates + if ( + not all( + covariate in essential_covariates for covariate in one_hot_covariates + ) + and one_hot_covariates ): msg = "One-hot covariates must be in the essential covariates" raise AssertionError(msg) @@ -122,8 +132,8 @@ def _preprocess_data( # Add to essential covariates the new columns created by one-hot encoding and # delete the ones that were one-hot encoded - essential_covariates.extend(new_columns) - [essential_covariates.remove(col) for col in one_hot_covariates] + essential_covariates.extend(new_columns) # pyright: ignore[reportArgumentType] + [essential_covariates.remove(col) for col in one_hot_covariates] # pyright: ignore[reportArgumentType] data = data.select(essential_covariates) # Select the sets of treated and control subjects @@ -170,8 +180,15 @@ def query_essential_covariates( lambda node: query_essential_covariates(node, control_set) ) ) + if len(control_set) < self.number_of_neighbors * len(treated_set): - msg = "Not enough control subjects to match the treated subjects" + msg = ( + f"Not enough control subjects to match the treated subjects. " + f"Number of controls: {len(control_set)}, " + f"Number of treated subjects: {len(treated_set)}, " + f"Number of neighbors required per treated subject: {self.number_of_neighbors}, " + f"Total controls needed: {self.number_of_neighbors * len(treated_set)}." + ) raise ValueError(msg) if len(treated_set) != len( diff --git a/medmodels/treatment_effect/matching/tests/test_matching.py b/medmodels/treatment_effect/matching/tests/test_matching.py index 10ac49b..4e6b95c 100644 --- a/medmodels/treatment_effect/matching/tests/test_matching.py +++ b/medmodels/treatment_effect/matching/tests/test_matching.py @@ -30,12 +30,12 @@ def create_patients(patients_list: List[NodeIndex]) -> pd.DataFrame: "age": [20, 30, 40, 30, 40, 50, 60, 70, 80], "gender": [ "male", - "female", "male", "female", - "male", "female", "male", + "male", + "female", "female", "male", ], @@ -70,6 +70,7 @@ def create_medrecord(patients_list: Optional[List[NodeIndex]] = None) -> MedReco patients = create_patients(patients_list=patients_list) medrecord = MedRecord.from_pandas(nodes=[(patients, "index")]) medrecord.add_group(group="patients", nodes=patients["index"].to_list()) + medrecord.add_nodes(("P10", {}), "patients") return medrecord @@ -155,7 +156,7 @@ def test_match_controls(self) -> None: # Assert that the correct number of controls were matched assert len(matched_controls) == len(treated_set) - # Assert it works equally if no covariates are given (automatically assigned) + # It should do the same if no covariates are given (all attributes assigned) matched_controls_no_covariates_specified = neighbors_matching.match_controls( medrecord=self.medrecord, control_set=control_set, @@ -163,8 +164,7 @@ def test_match_controls(self) -> None: patients_group="patients", ) - assert matched_controls_no_covariates_specified.issubset(control_set) - assert len(matched_controls_no_covariates_specified) == len(treated_set) + assert matched_controls_no_covariates_specified == matched_controls def test_check_nodes(self) -> None: neighbors_matching = NeighborsMatching(number_of_neighbors=1) @@ -190,7 +190,10 @@ def test_invalid_check_nodes(self) -> None: # Test insufficient control subjects with pytest.raises( ValueError, - match="Not enough control subjects to match the treated subjects", + match="Not enough control subjects to match the treated subjects. " + + "Number of controls: 1, Number of treated subjects: 3, " + + "Number of neighbors required per treated subject: 1, " + + "Total controls needed: 3.", ): neighbors_matching._check_nodes( medrecord=self.medrecord, @@ -202,7 +205,10 @@ def test_invalid_check_nodes(self) -> None: neighbors_matching_two_neighbors = NeighborsMatching(number_of_neighbors=2) with pytest.raises( ValueError, - match="Not enough control subjects to match the treated subjects", + match="Not enough control subjects to match the treated subjects. " + + "Number of controls: 5, Number of treated subjects: 3, " + + "Number of neighbors required per treated subject: 2, " + + "Total controls needed: 6.", ): neighbors_matching_two_neighbors._check_nodes( medrecord=self.medrecord, From 00a562960b9d8041e5f84f3e5ce718d4a8008cca Mon Sep 17 00:00:00 2001 From: MarIniOnz Date: Tue, 14 Jan 2025 16:28:55 +0100 Subject: [PATCH 3/5] fix: PR comments --- .../treatment_effect/matching/matching.py | 47 +++++++++++-------- .../treatment_effect/matching/neighbors.py | 16 +++---- .../treatment_effect/matching/propensity.py | 17 ++++--- 3 files changed, 43 insertions(+), 37 deletions(-) diff --git a/medmodels/treatment_effect/matching/matching.py b/medmodels/treatment_effect/matching/matching.py index d74205c..194d83f 100644 --- a/medmodels/treatment_effect/matching/matching.py +++ b/medmodels/treatment_effect/matching/matching.py @@ -8,7 +8,16 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Literal, Optional, Set, Tuple, TypeAlias +from typing import ( + TYPE_CHECKING, + List, + Literal, + Optional, + Sequence, + Set, + Tuple, + TypeAlias, +) import polars as pl @@ -18,7 +27,7 @@ if TYPE_CHECKING: from medmodels.medrecord.medrecord import MedRecord from medmodels.medrecord.querying import NodeOperand - from medmodels.medrecord.types import Group, MedRecordAttributeInputList, NodeIndex + from medmodels.medrecord.types import Group, MedRecordAttribute, NodeIndex MatchingMethod: TypeAlias = Literal["propensity", "nearest_neighbors"] @@ -44,8 +53,8 @@ def _preprocess_data( control_set: Set[NodeIndex], treated_set: Set[NodeIndex], patients_group: Group, - essential_covariates: Optional[MedRecordAttributeInputList] = None, - one_hot_covariates: Optional[MedRecordAttributeInputList] = None, + essential_covariates: Optional[List[MedRecordAttribute]] = None, + one_hot_covariates: Optional[List[MedRecordAttribute]] = None, ) -> Tuple[pl.DataFrame, pl.DataFrame]: """Prepared the data for the matching algorithms. @@ -54,10 +63,10 @@ def _preprocess_data( control_set (Set[NodeIndex]): Set of treated subjects. treated_set (Set[NodeIndex]): Set of control subjects. patients_group (Group): The group of patients. - essential_covariates (Optional[MedRecordAttributeInputList]): + essential_covariates (Optional[List[MedRecordAttribute]], optional): Covariates that are essential for matching. Defaults to None, meaning all the attributes of the patients are used. - one_hot_covariates (Optional[MedRecordAttributeInputList]): + one_hot_covariates (Optional[List[MedRecordAttribute]], optional): Covariates that are one-hot encoded for matching. Defaults to None, meaning all the categorical attributes of the patients are used. @@ -71,10 +80,9 @@ def _preprocess_data( """ if essential_covariates is None: # If no essential covariates provided, use all attributes of patients group + nodes_attributes = medrecord.node[medrecord.nodes_in_group(patients_group)] essential_covariates = list( - extract_attribute_summary( - medrecord.node[medrecord.nodes_in_group(patients_group)] - ) + {key for attributes in nodes_attributes.values() for key in attributes} ) control_set = self._check_nodes( @@ -85,7 +93,7 @@ def _preprocess_data( ) if "id" not in essential_covariates: - essential_covariates.append("id") # pyright: ignore[reportArgumentType] + essential_covariates.append("id") # Dataframe with the essential covariates data = pl.DataFrame( @@ -132,8 +140,8 @@ def _preprocess_data( # Add to essential covariates the new columns created by one-hot encoding and # delete the ones that were one-hot encoded - essential_covariates.extend(new_columns) # pyright: ignore[reportArgumentType] - [essential_covariates.remove(col) for col in one_hot_covariates] # pyright: ignore[reportArgumentType] + essential_covariates.extend(new_columns) + [essential_covariates.remove(col) for col in one_hot_covariates] data = data.select(essential_covariates) # Select the sets of treated and control subjects @@ -147,7 +155,7 @@ def _check_nodes( medrecord: MedRecord, treated_set: Set[NodeIndex], control_set: Set[NodeIndex], - essential_covariates: MedRecordAttributeInputList, + essential_covariates: List[MedRecordAttribute], ) -> Set[NodeIndex]: """Check if the treated and control sets are disjoint. @@ -155,7 +163,7 @@ def _check_nodes( medrecord (MedRecord): MedRecord object containing the data. treated_set (Set[NodeIndex]): Set of treated subjects. control_set (Set[NodeIndex]): Set of control subjects. - essential_covariates (MedRecordAttributeInputList): Covariates that are + essential_covariates (List[MedRecordAttribute]): Covariates that are essential for matching. Returns: @@ -170,8 +178,7 @@ def query_essential_covariates( node: NodeOperand, patients_set: Set[NodeIndex] ) -> None: """Query the nodes that have all the essential covariates.""" - for attribute in essential_covariates: - node.has_attribute(attribute) + node.has_attribute(essential_covariates) node.index().is_in(list(patients_set)) @@ -208,8 +215,8 @@ def match_controls( medrecord: MedRecord, control_set: Set[NodeIndex], treated_set: Set[NodeIndex], - essential_covariates: Optional[MedRecordAttributeInputList] = None, - one_hot_covariates: Optional[MedRecordAttributeInputList] = None, + essential_covariates: Optional[Sequence[MedRecordAttribute]] = None, + one_hot_covariates: Optional[Sequence[MedRecordAttribute]] = None, ) -> Set[NodeIndex]: """Matches the controls based on the matching algorithm. @@ -217,9 +224,9 @@ def match_controls( medrecord (MedRecord): MedRecord object containing the data. control_set (Set[NodeIndex]): Set of control subjects. treated_set (Set[NodeIndex]): Set of treated subjects. - essential_covariates (Optional[MedRecordAttributeInputList], optional): + essential_covariates (Optional[Sequence[MedRecordAttribute]], optional): Covariates that are essential for matching. Defaults to None. - one_hot_covariates (Optional[MedRecordAttributeInputList], optional): + one_hot_covariates (Optional[Sequence[MedRecordAttribute]], optional): Covariates that are one-hot encoded for matching. Defaults to None. Returns: diff --git a/medmodels/treatment_effect/matching/neighbors.py b/medmodels/treatment_effect/matching/neighbors.py index 8df2c0a..4ac1f45 100644 --- a/medmodels/treatment_effect/matching/neighbors.py +++ b/medmodels/treatment_effect/matching/neighbors.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Set +from typing import TYPE_CHECKING, Optional, Sequence, Set from medmodels.treatment_effect.matching.algorithms.classic_distance_models import ( nearest_neighbor, @@ -11,7 +11,7 @@ if TYPE_CHECKING: from medmodels import MedRecord - from medmodels.medrecord.types import Group, MedRecordAttributeInputList, NodeIndex + from medmodels.medrecord.types import Group, MedRecordAttribute, NodeIndex class NeighborsMatching(Matching): @@ -45,8 +45,8 @@ def match_controls( control_set: Set[NodeIndex], treated_set: Set[NodeIndex], patients_group: Group, - essential_covariates: Optional[MedRecordAttributeInputList] = None, - one_hot_covariates: Optional[MedRecordAttributeInputList] = None, + essential_covariates: Optional[Sequence[MedRecordAttribute]] = None, + one_hot_covariates: Optional[Sequence[MedRecordAttribute]] = None, ) -> Set[NodeIndex]: """Matches the controls based on the nearest neighbor algorithm. @@ -55,10 +55,10 @@ def match_controls( control_set (Set[NodeIndex]): Set of control subjects. treated_set (Set[NodeIndex]): Set of treated subjects. patients_group (Group): Group of patients in the MedRecord. - essential_covariates (Optional[MedRecordAttributeInputList], optional): + essential_covariates (Optional[Sequence[MedRecordAttribute]], optional): Covariates that are essential for matching. Defaults to ["gender", "age"]. - one_hot_covariates (Optional[MedRecordAttributeInputList], optional): + one_hot_covariates (Optional[Sequence[MedRecordAttribute]], optional): Covariates that are one-hot encoded for matching. Defaults to ["gender"]. @@ -75,8 +75,8 @@ def match_controls( control_set=control_set, treated_set=treated_set, patients_group=patients_group, - essential_covariates=essential_covariates, - one_hot_covariates=one_hot_covariates, + essential_covariates=list(essential_covariates), + one_hot_covariates=list(one_hot_covariates), ) # Run the algorithm to find the matched controls diff --git a/medmodels/treatment_effect/matching/propensity.py b/medmodels/treatment_effect/matching/propensity.py index c8b949a..0c2ca92 100644 --- a/medmodels/treatment_effect/matching/propensity.py +++ b/medmodels/treatment_effect/matching/propensity.py @@ -2,13 +2,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Optional, Set +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Set import numpy as np import polars as pl from medmodels import MedRecord -from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex from medmodels.treatment_effect.matching.algorithms.classic_distance_models import ( nearest_neighbor, ) @@ -20,7 +19,7 @@ if TYPE_CHECKING: from medmodels import MedRecord - from medmodels.medrecord.types import Group, MedRecordAttributeInputList, NodeIndex + from medmodels.medrecord.types import Group, MedRecordAttribute, NodeIndex class PropensityMatching(Matching): @@ -66,8 +65,8 @@ def match_controls( control_set: Set[NodeIndex], treated_set: Set[NodeIndex], patients_group: Group, - essential_covariates: Optional[MedRecordAttributeInputList] = None, - one_hot_covariates: Optional[MedRecordAttributeInputList] = None, + essential_covariates: Optional[Sequence[MedRecordAttribute]] = None, + one_hot_covariates: Optional[Sequence[MedRecordAttribute]] = None, ) -> Set[NodeIndex]: """Matches the controls based on propensity score matching. @@ -76,10 +75,10 @@ def match_controls( control_set (Set[NodeIndex]): Set of control subjects. treated_set (Set[NodeIndex]): Set of treated subjects. patients_group (Group): Group of patients in MedRecord. - essential_covariates (Optional[MedRecordAttributeInputList], optional): + essential_covariates (Optional[Sequence[MedRecordAttribute]], optional): Covariates that are essential for matching. Defaults to ["gender", "age"]. - one_hot_covariates (Optional[MedRecordAttributeInputList], optional): + one_hot_covariates (Optional[Sequence[MedRecordAttribute]], optional): Covariates that are one-hot encoded for matching. Defaults to ["gender"]. @@ -97,8 +96,8 @@ def match_controls( treated_set=treated_set, control_set=control_set, patients_group=patients_group, - essential_covariates=essential_covariates, - one_hot_covariates=one_hot_covariates, + essential_covariates=list(essential_covariates), + one_hot_covariates=list(one_hot_covariates), ) # Convert the Polars DataFrames to NumPy arrays From 9ef8d28d665311dfb6a48881eeb4e9dfcb29e3ba Mon Sep 17 00:00:00 2001 From: MarIniOnz Date: Fri, 17 Jan 2025 14:03:51 +0100 Subject: [PATCH 4/5] fix: PR comments --- .../treatment_effect/matching/neighbors.py | 23 ++++++++----------- .../treatment_effect/matching/propensity.py | 23 ++++++++----------- 2 files changed, 20 insertions(+), 26 deletions(-) diff --git a/medmodels/treatment_effect/matching/neighbors.py b/medmodels/treatment_effect/matching/neighbors.py index 4ac1f45..34c87f4 100644 --- a/medmodels/treatment_effect/matching/neighbors.py +++ b/medmodels/treatment_effect/matching/neighbors.py @@ -55,28 +55,25 @@ def match_controls( control_set (Set[NodeIndex]): Set of control subjects. treated_set (Set[NodeIndex]): Set of treated subjects. patients_group (Group): Group of patients in the MedRecord. - essential_covariates (Optional[Sequence[MedRecordAttribute]], optional): - Covariates that are essential for matching. Defaults to - ["gender", "age"]. - one_hot_covariates (Optional[Sequence[MedRecordAttribute]], optional): - Covariates that are one-hot encoded for matching. Defaults to - ["gender"]. + essential_covariates (Optional[List[MedRecordAttribute]], optional): + Covariates that are essential for matching. Defaults to None, meaning + all the attributes of the patients are used. + one_hot_covariates (Optional[List[MedRecordAttribute]], optional): + Covariates that are one-hot encoded for matching. Defaults to None, + meaning all the categorical attributes of the patients are used. Returns: Set[NodeIndex]: Node Ids of the matched controls. """ - if essential_covariates is None: - essential_covariates = ["gender", "age"] - if one_hot_covariates is None: - one_hot_covariates = ["gender"] - data_treated, data_control = self._preprocess_data( medrecord=medrecord, control_set=control_set, treated_set=treated_set, patients_group=patients_group, - essential_covariates=list(essential_covariates), - one_hot_covariates=list(one_hot_covariates), + essential_covariates=list(essential_covariates) + if essential_covariates + else None, + one_hot_covariates=list(one_hot_covariates) if one_hot_covariates else None, ) # Run the algorithm to find the matched controls diff --git a/medmodels/treatment_effect/matching/propensity.py b/medmodels/treatment_effect/matching/propensity.py index 0c2ca92..3e327d1 100644 --- a/medmodels/treatment_effect/matching/propensity.py +++ b/medmodels/treatment_effect/matching/propensity.py @@ -75,29 +75,26 @@ def match_controls( control_set (Set[NodeIndex]): Set of control subjects. treated_set (Set[NodeIndex]): Set of treated subjects. patients_group (Group): Group of patients in MedRecord. - essential_covariates (Optional[Sequence[MedRecordAttribute]], optional): - Covariates that are essential for matching. Defaults to - ["gender", "age"]. - one_hot_covariates (Optional[Sequence[MedRecordAttribute]], optional): - Covariates that are one-hot encoded for matching. Defaults to - ["gender"]. + essential_covariates (Optional[List[MedRecordAttribute]], optional): + Covariates that are essential for matching. Defaults to None, meaning + all the attributes of the patients are used. + one_hot_covariates (Optional[List[MedRecordAttribute]], optional): + Covariates that are one-hot encoded for matching. Defaults to None, + meaning all the categorical attributes of the patients are used. Returns: Set[NodeIndex]: Node Ids of the matched controls. """ - if essential_covariates is None: - essential_covariates = ["gender", "age"] - if one_hot_covariates is None: - one_hot_covariates = ["gender"] - # Preprocess the data data_treated, data_control = self._preprocess_data( medrecord=medrecord, treated_set=treated_set, control_set=control_set, patients_group=patients_group, - essential_covariates=list(essential_covariates), - one_hot_covariates=list(one_hot_covariates), + essential_covariates=list(essential_covariates) + if essential_covariates + else None, + one_hot_covariates=list(one_hot_covariates) if one_hot_covariates else None, ) # Convert the Polars DataFrames to NumPy arrays From 2a405db1d8574bc5145548ad2a669fac2241294c Mon Sep 17 00:00:00 2001 From: MarIniOnz Date: Tue, 21 Jan 2025 11:31:22 +0100 Subject: [PATCH 5/5] fix: PR comments --- medmodels/treatment_effect/matching/neighbors.py | 4 ++-- medmodels/treatment_effect/matching/propensity.py | 4 ++-- medmodels/treatment_effect/treatment_effect.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/medmodels/treatment_effect/matching/neighbors.py b/medmodels/treatment_effect/matching/neighbors.py index 34c87f4..69ac821 100644 --- a/medmodels/treatment_effect/matching/neighbors.py +++ b/medmodels/treatment_effect/matching/neighbors.py @@ -55,10 +55,10 @@ def match_controls( control_set (Set[NodeIndex]): Set of control subjects. treated_set (Set[NodeIndex]): Set of treated subjects. patients_group (Group): Group of patients in the MedRecord. - essential_covariates (Optional[List[MedRecordAttribute]], optional): + essential_covariates (Optional[Sequence[MedRecordAttribute]], optional): Covariates that are essential for matching. Defaults to None, meaning all the attributes of the patients are used. - one_hot_covariates (Optional[List[MedRecordAttribute]], optional): + one_hot_covariates (Optional[Sequence[MedRecordAttribute]], optional): Covariates that are one-hot encoded for matching. Defaults to None, meaning all the categorical attributes of the patients are used. diff --git a/medmodels/treatment_effect/matching/propensity.py b/medmodels/treatment_effect/matching/propensity.py index 3e327d1..8782648 100644 --- a/medmodels/treatment_effect/matching/propensity.py +++ b/medmodels/treatment_effect/matching/propensity.py @@ -75,10 +75,10 @@ def match_controls( control_set (Set[NodeIndex]): Set of control subjects. treated_set (Set[NodeIndex]): Set of treated subjects. patients_group (Group): Group of patients in MedRecord. - essential_covariates (Optional[List[MedRecordAttribute]], optional): + essential_covariates (Optional[Sequence[MedRecordAttribute]], optional): Covariates that are essential for matching. Defaults to None, meaning all the attributes of the patients are used. - one_hot_covariates (Optional[List[MedRecordAttribute]], optional): + one_hot_covariates (Optional[Sequence[MedRecordAttribute]], optional): Covariates that are one-hot encoded for matching. Defaults to None, meaning all the categorical attributes of the patients are used. diff --git a/medmodels/treatment_effect/treatment_effect.py b/medmodels/treatment_effect/treatment_effect.py index ebe74c3..c6971a7 100644 --- a/medmodels/treatment_effect/treatment_effect.py +++ b/medmodels/treatment_effect/treatment_effect.py @@ -68,8 +68,8 @@ class TreatmentEffect: _filter_controls_query: Optional[NodeQuery] _matching_method: Optional[MatchingMethod] - _matching_essential_covariates: Optional[MedRecordAttributeInputList] - _matching_one_hot_covariates: Optional[MedRecordAttributeInputList] + _matching_essential_covariates: MedRecordAttributeInputList + _matching_one_hot_covariates: MedRecordAttributeInputList _matching_model: Model _matching_number_of_neighbors: int _matching_hyperparameters: Optional[Dict[str, Any]]