Skip to content

Commit

Permalink
fix: the way essential covariates and one-hot-encoded ones are proces…
Browse files Browse the repository at this point in the history
…sed (#256)
  • Loading branch information
MarIniOnz authored Jan 21, 2025
1 parent d56ee4f commit 2e4ee97
Show file tree
Hide file tree
Showing 7 changed files with 469 additions and 79 deletions.
3 changes: 2 additions & 1 deletion medmodels/treatment_effect/estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def _sort_subjects_in_groups(
else PropensityMatching(
number_of_neighbors=self._treatment_effect._matching_number_of_neighbors,
model=self._treatment_effect._matching_model,
hyperparam=self._treatment_effect._matching_hyperparameters,
hyperparameters=self._treatment_effect._matching_hyperparameters,
)
)

Expand All @@ -202,6 +202,7 @@ def _sort_subjects_in_groups(
medrecord=medrecord,
treated_set=treated_set,
control_set=control_set,
patients_group=self._treatment_effect._patients_group,
essential_covariates=self._treatment_effect._matching_essential_covariates,
one_hot_covariates=self._treatment_effect._matching_one_hot_covariates,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def calculate_propensity(
treated_test: NDArray[Union[np.int64, np.float64]],
control_test: NDArray[Union[np.int64, np.float64]],
model: Model = "logit",
hyperparam: Optional[Dict[str, Any]] = None,
hyperparameters: Optional[Dict[str, Any]] = None,
) -> Tuple[NDArray[np.float64], NDArray[np.float64]]:
"""Calculates the propensity/probabilities of a subject being in the treated group.
Expand All @@ -55,8 +55,8 @@ def calculate_propensity(
control group to predict probabilities.
model (Model, optional): Classification algorithm to use. Options: "logit",
"dec_tree", "forest".
hyperparam (Optional[Dict[str, Any]], optional): Manual hyperparameter settings.
Uses default if None.
hyperparameters (Optional[Dict[str, Any]], optional): Manual hyperparameter
settings. Uses default hyperparameters if None.
Returns:
Tuple[NDArray[np.float64], NDArray[np.float64]]: Probabilities of the positive
Expand All @@ -67,7 +67,7 @@ class for treated and control groups.
last class for treated and control sets, e.g., ([0.], [0.]).
"""
propensity_model = PROP_MODEL[model]
pm = propensity_model(**hyperparam) if hyperparam else propensity_model()
pm = propensity_model(**hyperparameters) if hyperparameters else propensity_model()
pm.fit(x_train, y_train)

# Predict the probability of the treated and control groups
Expand All @@ -82,7 +82,7 @@ def run_propensity_score(
control_set: pl.DataFrame,
model: Model = "logit",
number_of_neighbors: int = 1,
hyperparam: Optional[Dict[str, Any]] = None,
hyperparameters: Optional[Dict[str, Any]] = None,
covariates: Optional[MedRecordAttributeInputList] = None,
) -> pl.DataFrame:
"""Executes Propensity Score matching using a specified classification algorithm.
Expand All @@ -101,7 +101,7 @@ def run_propensity_score(
Options include "logit", "dec_tree", "forest".
number_of_neighbors (int, optional): Number of nearest neighbors to find for
each treated unit. Defaults to 1.
hyperparam (Optional[Dict[str, Any]], optional): Hyperparameters for model
hyperparameters (Optional[Dict[str, Any]], optional): Hyperparameters for model
tuning. Increases computation time if set. Uses default if None.
covariates (Optional[MedRecordAttributeInputList], optional): Features for
matching. Uses all if None.
Expand All @@ -125,7 +125,7 @@ def run_propensity_score(
y_train,
treated_array,
control_array,
hyperparam=hyperparam,
hyperparameters=hyperparameters,
model=model,
)

Expand Down
157 changes: 143 additions & 14 deletions medmodels/treatment_effect/matching/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,50 +8,94 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Literal, Optional, Set, Tuple, TypeAlias
from typing import (
TYPE_CHECKING,
List,
Literal,
Optional,
Sequence,
Set,
Tuple,
TypeAlias,
)

import polars as pl

from medmodels.medrecord._overview import extract_attribute_summary
from medmodels.medrecord.medrecord import MedRecord

if TYPE_CHECKING:
from medmodels.medrecord.medrecord import MedRecord
from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex
from medmodels.medrecord.querying import NodeOperand
from medmodels.medrecord.types import Group, MedRecordAttribute, NodeIndex

MatchingMethod: TypeAlias = Literal["propensity", "nearest_neighbors"]


class Matching(ABC):
"""The Abstract Class for matching."""

number_of_neighbors: int

def __init__(self, number_of_neighbors: int) -> None:
"""Initializes the matching class.
Args:
number_of_neighbors (int): Number of nearest neighbors to find for each
treated patient.
"""
self.number_of_neighbors = number_of_neighbors

def _preprocess_data(
self,
*,
medrecord: MedRecord,
control_set: Set[NodeIndex],
treated_set: Set[NodeIndex],
essential_covariates: MedRecordAttributeInputList,
one_hot_covariates: MedRecordAttributeInputList,
patients_group: Group,
essential_covariates: Optional[List[MedRecordAttribute]] = None,
one_hot_covariates: Optional[List[MedRecordAttribute]] = None,
) -> Tuple[pl.DataFrame, pl.DataFrame]:
"""Prepares the data for the matching algorithms.
Args:
medrecord (MedRecord): MedRecord object containing the data.
control_set (Set[NodeIndex]): Set of control subjects.
treated_set (Set[NodeIndex]): Set of treated subjects.
essential_covariates (MedRecordAttributeInputList): Covariates
that are essential for matching
one_hot_covariates (MedRecordAttributeInputList): Covariates that
are one-hot encoded for matching
patients_group (Group): The group of patients.
essential_covariates (Optional[List[MedRecordAttribute]], optional):
Covariates that are essential for matching. Defaults to None, meaning
all the attributes of the patients are used.
one_hot_covariates (Optional[List[MedRecordAttribute]], optional):
Covariates that are one-hot encoded for matching. Defaults to None,
meaning all the categorical attributes of the patients are used.
Returns:
Tuple[pl.DataFrame, pl.DataFrame]: Treated and control groups with their
preprocessed covariates
Raises:
AssertionError: If the one-hot covariates are not in the essential
covariates.
"""
essential_covariates = [str(covariate) for covariate in essential_covariates]
if essential_covariates is None:
# If no essential covariates provided, use all attributes of patients group
nodes_attributes = medrecord.node[medrecord.nodes_in_group(patients_group)]
essential_covariates = list(
{key for attributes in nodes_attributes.values() for key in attributes}
)

control_set = self._check_nodes(
medrecord=medrecord,
treated_set=treated_set,
control_set=control_set,
essential_covariates=essential_covariates,
)

if "id" not in essential_covariates:
essential_covariates.append("id")

# Dataframe
# Dataframe with the essential covariates
data = pl.DataFrame(
data=[
{"id": k, **v}
Expand All @@ -60,6 +104,33 @@ def _preprocess_data(
)
original_columns = data.columns

# If no one-hot covariates provided, use all categorical attributes of patients
if one_hot_covariates is None:
attributes = extract_attribute_summary(
medrecord.node[medrecord.nodes_in_group(patients_group)]
)
one_hot_covariates = [
covariate
for covariate, values in attributes.items()
if "Categorical" in values["type"]
]

one_hot_covariates = [
covariate
for covariate in one_hot_covariates
if covariate in essential_covariates
]

# If there are one-hot covariates, check if all are in the essential covariates
if (
not all(
covariate in essential_covariates for covariate in one_hot_covariates
)
and one_hot_covariates
):
msg = "One-hot covariates must be in the essential covariates"
raise AssertionError(msg)

# One-hot encode the categorical variables
data = data.to_dummies(
columns=[str(covariate) for covariate in one_hot_covariates],
Expand All @@ -79,25 +150,83 @@ def _preprocess_data(

return data_treated, data_control

def _check_nodes(
    self,
    medrecord: MedRecord,
    treated_set: Set[NodeIndex],
    control_set: Set[NodeIndex],
    essential_covariates: List[MedRecordAttribute],
) -> Set[NodeIndex]:
    """Validate the treated set and filter the control set for matching.

    The control set is narrowed to the nodes returned by a MedRecord query that
    requires the essential covariates and restricts the node index to the
    control set. The same query is run on the treated set, and the number of
    returned nodes must equal the size of the treated set.

    Args:
        medrecord (MedRecord): MedRecord object containing the data.
        treated_set (Set[NodeIndex]): Set of treated subjects.
        control_set (Set[NodeIndex]): Set of control subjects.
        essential_covariates (List[MedRecordAttribute]): Covariates that are
            essential for matching.

    Returns:
        Set[NodeIndex]: The control nodes selected by the covariates query.

    Raises:
        ValueError: If there are fewer selected controls than
            ``number_of_neighbors * len(treated_set)``.
        ValueError: If the query on the treated set returns a different number
            of nodes than the treated set contains.
    """

    def select_with_covariates(candidates: Set[NodeIndex]):
        """Run the essential-covariates query restricted to ``candidates``."""

        def query(node: NodeOperand) -> None:
            # Both constraints are registered on the same operand.
            node.has_attribute(essential_covariates)

            node.index().is_in(list(candidates))

        return medrecord.select_nodes(query)

    eligible_controls = set(select_with_covariates(control_set))

    # The control check runs first, so its error takes precedence when both
    # conditions fail.
    if len(eligible_controls) < self.number_of_neighbors * len(treated_set):
        msg = (
            f"Not enough control subjects to match the treated subjects. "
            f"Number of controls: {len(eligible_controls)}, "
            f"Number of treated subjects: {len(treated_set)}, "
            f"Number of neighbors required per treated subject: {self.number_of_neighbors}, "
            f"Total controls needed: {self.number_of_neighbors * len(treated_set)}."
        )
        raise ValueError(msg)

    treated_with_covariates = select_with_covariates(treated_set)
    if len(treated_set) != len(treated_with_covariates):
        msg = "Some treated nodes do not have all the essential covariates"
        raise ValueError(msg)

    return eligible_controls

@abstractmethod
def match_controls(
self,
*,
medrecord: MedRecord,
control_set: Set[NodeIndex],
treated_set: Set[NodeIndex],
essential_covariates: Optional[MedRecordAttributeInputList] = None,
one_hot_covariates: Optional[MedRecordAttributeInputList] = None,
essential_covariates: Optional[Sequence[MedRecordAttribute]] = None,
one_hot_covariates: Optional[Sequence[MedRecordAttribute]] = None,
) -> Set[NodeIndex]:
"""Matches the controls based on the matching algorithm.
Args:
medrecord (MedRecord): MedRecord object containing the data.
control_set (Set[NodeIndex]): Set of control subjects.
treated_set (Set[NodeIndex]): Set of treated subjects.
essential_covariates (Optional[MedRecordAttributeInputList], optional):
essential_covariates (Optional[Sequence[MedRecordAttribute]], optional):
Covariates that are essential for matching. Defaults to None.
one_hot_covariates (Optional[MedRecordAttributeInputList], optional):
one_hot_covariates (Optional[Sequence[MedRecordAttribute]], optional):
Covariates that are one-hot encoded for matching. Defaults to None.
Returns:
Expand Down
38 changes: 19 additions & 19 deletions medmodels/treatment_effect/matching/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Set
from typing import TYPE_CHECKING, Optional, Sequence, Set

from medmodels.treatment_effect.matching.algorithms.classic_distance_models import (
nearest_neighbor,
Expand All @@ -11,7 +11,7 @@

if TYPE_CHECKING:
from medmodels import MedRecord
from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex
from medmodels.medrecord.types import Group, MedRecordAttribute, NodeIndex


class NeighborsMatching(Matching):
Expand All @@ -36,44 +36,44 @@ def __init__(
number_of_neighbors (int, optional): Number of nearest neighbors to find for
each treated unit. Defaults to 1.
"""
self.number_of_neighbors = number_of_neighbors
super().__init__(number_of_neighbors)

def match_controls(
self,
*,
medrecord: MedRecord,
control_set: Set[NodeIndex],
treated_set: Set[NodeIndex],
essential_covariates: Optional[MedRecordAttributeInputList] = None,
one_hot_covariates: Optional[MedRecordAttributeInputList] = None,
patients_group: Group,
essential_covariates: Optional[Sequence[MedRecordAttribute]] = None,
one_hot_covariates: Optional[Sequence[MedRecordAttribute]] = None,
) -> Set[NodeIndex]:
"""Matches the controls based on the nearest neighbor algorithm.
Args:
medrecord (MedRecord): MedRecord object containing the data.
treated_set (Set[NodeIndex]): Set of treated subjects.
control_set (Set[NodeIndex]): Set of control subjects.
essential_covariates (Optional[MedRecordAttributeInputList], optional):
Covariates that are essential for matching. Defaults to
["gender", "age"].
one_hot_covariates (Optional[MedRecordAttributeInputList], optional):
Covariates that are one-hot encoded for matching. Defaults to
["gender"].
treated_set (Set[NodeIndex]): Set of treated subjects.
patients_group (Group): Group of patients in the MedRecord.
essential_covariates (Optional[Sequence[MedRecordAttribute]], optional):
Covariates that are essential for matching. Defaults to None, meaning
all the attributes of the patients are used.
one_hot_covariates (Optional[Sequence[MedRecordAttribute]], optional):
Covariates that are one-hot encoded for matching. Defaults to None,
meaning all the categorical attributes of the patients are used.
Returns:
Set[NodeIndex]: Node Ids of the matched controls.
"""
if essential_covariates is None:
essential_covariates = ["gender", "age"]
if one_hot_covariates is None:
one_hot_covariates = ["gender"]

data_treated, data_control = self._preprocess_data(
medrecord=medrecord,
control_set=control_set,
treated_set=treated_set,
essential_covariates=essential_covariates,
one_hot_covariates=one_hot_covariates,
patients_group=patients_group,
essential_covariates=list(essential_covariates)
if essential_covariates
else None,
one_hot_covariates=list(one_hot_covariates) if one_hot_covariates else None,
)

# Run the algorithm to find the matched controls
Expand Down
Loading

0 comments on commit 2e4ee97

Please sign in to comment.