Skip to content

Commit

Permalink
fix: changing some incorrect variable naming and deleting how default…
Browse files Browse the repository at this point in the history
… values are assigned in essential and hot-encoded covariates for matching
  • Loading branch information
MarIniOnz committed Dec 17, 2024
1 parent b81e0c4 commit 866207a
Show file tree
Hide file tree
Showing 10 changed files with 448 additions and 94 deletions.
42 changes: 19 additions & 23 deletions medmodels/treatment_effect/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class TreatmentEffectBuilder:
matching_one_hot_covariates: Optional[MedRecordAttributeInputList]
matching_model: Optional[Model]
matching_number_of_neighbors: Optional[int]
matching_hyperparam: Optional[Dict[str, Any]]
matching_hyperparameters: Optional[Dict[str, Any]]

def with_treatment(self, treatment: Group) -> TreatmentEffectBuilder:
"""Sets the treatment group for the treatment effect estimation.
Expand Down Expand Up @@ -218,27 +218,25 @@ def filter_controls(self, query: NodeQuery) -> TreatmentEffectBuilder:

def with_propensity_matching(
self,
essential_covariates: MedRecordAttributeInputList = ["gender", "age"],
one_hot_covariates: MedRecordAttributeInputList = ["gender"],
essential_covariates: Optional[MedRecordAttributeInputList] = None,
one_hot_covariates: Optional[MedRecordAttributeInputList] = None,
model: Model = "logit",
number_of_neighbors: int = 1,
hyperparam: Optional[Dict[str, Any]] = None,
hyperparameters: Optional[Dict[str, Any]] = None,
) -> TreatmentEffectBuilder:
"""Adjust the treatment effect estimate using propensity score matching.
Args:
essential_covariates (MedRecordAttributeInputList, optional):
Covariates that are essential for matching. Defaults to
["gender", "age"].
one_hot_covariates (MedRecordAttributeInputList, optional):
Covariates that are one-hot encoded for matching. Defaults to
["gender"].
essential_covariates (Optional[MedRecordAttributeInputList]):
Covariates that are essential for matching. Defaults to None.
one_hot_covariates (Optional[MedRecordAttributeInputList]):
Covariates that are one-hot encoded for matching. Defaults to None.
model (Model, optional): Model to choose for the matching. Defaults to
"logit".
number_of_neighbors (int, optional): Number of neighbors to consider
for the matching. Defaults to 1.
hyperparam (Optional[Dict[str, Any]], optional): Hyperparameters for the
matching model. Defaults to None.
hyperparameters (Optional[Dict[str, Any]], optional): Hyperparameters for
the matching model. Defaults to None.
Returns:
TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder
Expand All @@ -249,27 +247,25 @@ def with_propensity_matching(
self.matching_one_hot_covariates = one_hot_covariates
self.matching_model = model
self.matching_number_of_neighbors = number_of_neighbors
self.matching_hyperparam = hyperparam
self.matching_hyperparameters = hyperparameters

return self

def with_nearest_neighbors_matching(
self,
essential_covariates: MedRecordAttributeInputList = ["gender", "age"],
one_hot_covariates: MedRecordAttributeInputList = ["gender"],
essential_covariates: Optional[MedRecordAttributeInputList] = None,
one_hot_covariates: Optional[MedRecordAttributeInputList] = None,
number_of_neighbors: int = 1,
) -> TreatmentEffectBuilder:
"""Adjust the treatment effect estimate using nearest neighbors matching.
Args:
essential_covariates (MedRecordAttributeInputList, optional):
Covariates that are essential for matching. Defaults to
["gender", "age"].
one_hot_covariates (MedRecordAttributeInputList, optional):
Covariates that are one-hot encoded for matching. Defaults to
["gender"].
number_of_neighbors (int, optional): Number of neighbors to consider for the
matching. Defaults to 1.
essential_covariates (Optional[MedRecordAttributeInputList]):
Covariates that are essential for matching. Defaults to None.
one_hot_covariates (Optional[MedRecordAttributeInputList]):
Covariates that are one-hot encoded for matching. Defaults to None.
number_of_neighbors (int, optional): Number of neighbors to consider for
the matching. Defaults to 1.
Returns:
TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder
Expand Down
3 changes: 2 additions & 1 deletion medmodels/treatment_effect/estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def _sort_subjects_in_groups(
else PropensityMatching(
number_of_neighbors=self._treatment_effect._matching_number_of_neighbors,
model=self._treatment_effect._matching_model,
hyperparam=self._treatment_effect._matching_hyperparam,
hyperparameters=self._treatment_effect._matching_hyperparameters,
)
)

Expand All @@ -191,6 +191,7 @@ def _sort_subjects_in_groups(
medrecord=medrecord,
treated_set=treated_set,
control_set=control_set,
patients_group=self._treatment_effect._patients_group,
essential_covariates=self._treatment_effect._matching_essential_covariates,
one_hot_covariates=self._treatment_effect._matching_one_hot_covariates,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def calculate_propensity(
treated_test: NDArray[Union[np.int64, np.float64]],
control_test: NDArray[Union[np.int64, np.float64]],
model: Model = "logit",
hyperparam: Optional[Dict[str, Any]] = None,
hyperparameters: Optional[Dict[str, Any]] = None,
) -> Tuple[NDArray[np.float64], NDArray[np.float64]]:
"""Trains a classification algorithm on training data, predicts the probability of being in the last class for treated and control test datasets, and returns these probabilities.
Expand All @@ -49,8 +49,8 @@ def calculate_propensity(
control group to predict probabilities.
model (Model, optional): Classification algorithm to use. Options: "logit",
"dec_tree", "forest".
hyperparam (Optional[Dict[str, Any]], optional): Manual hyperparameter settings.
Uses default if None.
hyperparameters (Optional[Dict[str, Any]], optional): Manual hyperparameter
settings. Uses default hyperparameters if None.
Returns:
Tuple[NDArray[np.float64], NDArray[np.float64]: Probabilities of the positive
Expand All @@ -61,7 +61,7 @@ class for treated and control groups.
last class for treated and control sets, e.g., ([0.], [0.]).
"""
propensity_model = PROP_MODEL[model]
pm = propensity_model(**hyperparam) if hyperparam else propensity_model()
pm = propensity_model(**hyperparameters) if hyperparameters else propensity_model()
pm.fit(x_train, y_train)

# Predict the probability of the treated and control groups
Expand All @@ -76,7 +76,7 @@ def run_propensity_score(
control_set: pl.DataFrame,
model: Model = "logit",
number_of_neighbors: int = 1,
hyperparam: Optional[Dict[str, Any]] = None,
hyperparameters: Optional[Dict[str, Any]] = None,
covariates: Optional[MedRecordAttributeInputList] = None,
) -> pl.DataFrame:
"""Executes Propensity Score matching using a specified classification algorithm.
Expand All @@ -95,7 +95,7 @@ def run_propensity_score(
Options include "logit", "dec_tree", "forest".
number_of_neighbors (int, optional): Number of nearest neighbors to find for
each treated unit. Defaults to 1.
hyperparam (Optional[Dict[str, Any]], optional): Hyperparameters for model
hyperparameters (Optional[Dict[str, Any]], optional): Hyperparameters for model
tuning. Increases computation time if set. Uses default if None.
covariates (Optional[MedRecordAttributeInputList], optional): Features for
matching. Uses all if None.
Expand All @@ -119,7 +119,7 @@ def run_propensity_score(
y_train,
treated_array,
control_array,
hyperparam=hyperparam,
hyperparameters=hyperparameters,
model=model,
)

Expand Down
129 changes: 117 additions & 12 deletions medmodels/treatment_effect/matching/matching.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING, Literal, Set, Tuple
from typing import TYPE_CHECKING, Literal, Optional, Set, Tuple

import polars as pl

from medmodels.medrecord._overview import extract_attribute_summary
from medmodels.medrecord.medrecord import MedRecord
from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex
from medmodels.medrecord.querying import NodeOperand
from medmodels.medrecord.types import Group, MedRecordAttributeInputList, NodeIndex

if TYPE_CHECKING:
import sys
Expand All @@ -22,36 +24,68 @@
class Matching(metaclass=ABCMeta):
"""The Base Class for matching."""

number_of_neighbors: int

def __init__(self, number_of_neighbors: int) -> None:
"""Initializes the matching class.
Args:
number_of_neighbors (int): Number of nearest neighbors to find for each treated unit.
"""
self.number_of_neighbors = number_of_neighbors

def _preprocess_data(
self,
*,
medrecord: MedRecord,
control_set: Set[NodeIndex],
treated_set: Set[NodeIndex],
essential_covariates: MedRecordAttributeInputList,
one_hot_covariates: MedRecordAttributeInputList,
patients_group: Group,
essential_covariates: Optional[MedRecordAttributeInputList] = None,
one_hot_covariates: Optional[MedRecordAttributeInputList] = None,
) -> Tuple[pl.DataFrame, pl.DataFrame]:
"""Prepared the data for the matching algorithms.
Args:
medrecord (MedRecord): MedRecord object containing the data.
control_set (Set[NodeIndex]): Set of treated subjects.
treated_set (Set[NodeIndex]): Set of control subjects.
essential_covariates (MedRecordAttributeInputList): Covariates
that are essential for matching
one_hot_covariates (MedRecordAttributeInputList): Covariates that
are one-hot encoded for matching
patients_group (Group): The group of patients.
essential_covariates (Optional[MedRecordAttributeInputList]):
Covariates that are essential for matching. Defaults to None.
one_hot_covariates (Optional[MedRecordAttributeInputList]):
Covariates that are one-hot encoded for matching. Defaults to None.
Returns:
Tuple[pl.DataFrame, pl.DataFrame]: Treated and control groups with their
preprocessed covariates
Raises:
ValueError: If not enough control subjects to match the treated subjects.
ValueError: If some treated nodes do not have all the essential covariates.
AssertionError: If the one-hot covariates are not in the essential covariates.
"""
essential_covariates = [str(covariate) for covariate in essential_covariates]
if essential_covariates is None:
# If no essential covariates are provided, use all the attributes of the patients
essential_covariates = list(
extract_attribute_summary(
medrecord.node[medrecord.nodes_in_group(patients_group)]
)
)
else:
essential_covariates = [covariate for covariate in essential_covariates]

control_set = self._check_nodes(
medrecord=medrecord,
treated_set=treated_set,
control_set=control_set,
essential_covariates=essential_covariates,
)

if "id" not in essential_covariates:
essential_covariates.append("id")

# Dataframe
# Dataframe wth the essential covariates
data = pl.DataFrame(
data=[
{"id": k, **v}
Expand All @@ -60,6 +94,24 @@ def _preprocess_data(
)
original_columns = data.columns

if one_hot_covariates is None:
# If no one-hot covariates are provided, use all the categorical attributes of the patients
attributes = extract_attribute_summary(
medrecord.node[medrecord.nodes_in_group(patients_group)]
)
one_hot_covariates = [
covariate
for covariate, values in attributes.items()
if "values" in values
]

if not all(
covariate in essential_covariates for covariate in one_hot_covariates
):
raise AssertionError(
"One-hot covariates must be in the essential covariates"
)

# One-hot encode the categorical variables
data = data.to_dummies(
columns=[str(covariate) for covariate in one_hot_covariates],
Expand All @@ -79,13 +131,66 @@ def _preprocess_data(

return data_treated, data_control

def _check_nodes(
self,
medrecord: MedRecord,
treated_set: Set[NodeIndex],
control_set: Set[NodeIndex],
essential_covariates: MedRecordAttributeInputList,
) -> Set[NodeIndex]:
"""Check if the treated and control sets are disjoint.
Args:
medrecord (MedRecord): MedRecord object containing the data.
treated_set (Set[NodeIndex]): Set of treated subjects.
control_set (Set[NodeIndex]): Set of control subjects.
essential_covariates (MedRecordAttributeInputList): Covariates that are
essential for matching.
Returns:
Set[NodeIndex]: The control set.
Raises:
ValueError: If not enough control subjects to match the treated subjects.
"""

def query_essential_covariates(
node: NodeOperand, patients_set: Set[NodeIndex]
) -> None:
"""Query the nodes that have all the essential covariates."""
for attribute in essential_covariates:
node.has_attribute(attribute)

node.index().is_in(list(patients_set))

control_set = set(
medrecord.select_nodes(
lambda node: query_essential_covariates(node, control_set)
)
)
if len(control_set) < self.number_of_neighbors * len(treated_set):
raise ValueError(
"Not enough control subjects to match the treated subjects"
)

if len(treated_set) != len(
medrecord.select_nodes(
lambda node: query_essential_covariates(node, treated_set)
)
):
raise ValueError(
"Some treated nodes do not have all the essential covariates"
)

return control_set

@abstractmethod
def match_controls(
self,
*,
control_set: Set[NodeIndex],
treated_set: Set[NodeIndex],
medrecord: MedRecord,
essential_covariates: MedRecordAttributeInputList = ["gender", "age"],
one_hot_covariates: MedRecordAttributeInputList = ["gender"],
essential_covariates: Optional[MedRecordAttributeInputList] = None,
one_hot_covariates: Optional[MedRecordAttributeInputList] = None,
) -> Set[NodeIndex]: ...
Loading

0 comments on commit 866207a

Please sign in to comment.