Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: the way essential covariates and one-hot-encoded ones are processed #256

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion medmodels/treatment_effect/estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def _sort_subjects_in_groups(
else PropensityMatching(
number_of_neighbors=self._treatment_effect._matching_number_of_neighbors,
model=self._treatment_effect._matching_model,
hyperparam=self._treatment_effect._matching_hyperparameters,
hyperparameters=self._treatment_effect._matching_hyperparameters,
)
)

Expand All @@ -202,6 +202,7 @@ def _sort_subjects_in_groups(
medrecord=medrecord,
treated_set=treated_set,
control_set=control_set,
patients_group=self._treatment_effect._patients_group,
essential_covariates=self._treatment_effect._matching_essential_covariates,
one_hot_covariates=self._treatment_effect._matching_one_hot_covariates,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def calculate_propensity(
treated_test: NDArray[Union[np.int64, np.float64]],
control_test: NDArray[Union[np.int64, np.float64]],
model: Model = "logit",
hyperparam: Optional[Dict[str, Any]] = None,
hyperparameters: Optional[Dict[str, Any]] = None,
) -> Tuple[NDArray[np.float64], NDArray[np.float64]]:
"""Calculates the propensity/probabilities of a subject being in the treated group.

Expand All @@ -55,8 +55,8 @@ def calculate_propensity(
control group to predict probabilities.
model (Model, optional): Classification algorithm to use. Options: "logit",
"dec_tree", "forest".
hyperparam (Optional[Dict[str, Any]], optional): Manual hyperparameter settings.
Uses default if None.
hyperparameters (Optional[Dict[str, Any]], optional): Manual hyperparameter
settings. Uses default hyperparameters if None.

Returns:
Tuple[NDArray[np.float64], NDArray[np.float64]]: Probabilities of the positive
Expand All @@ -67,7 +67,7 @@ class for treated and control groups.
last class for treated and control sets, e.g., ([0.], [0.]).
"""
propensity_model = PROP_MODEL[model]
pm = propensity_model(**hyperparam) if hyperparam else propensity_model()
pm = propensity_model(**hyperparameters) if hyperparameters else propensity_model()
pm.fit(x_train, y_train)

# Predict the probability of the treated and control groups
Expand All @@ -82,7 +82,7 @@ def run_propensity_score(
control_set: pl.DataFrame,
model: Model = "logit",
number_of_neighbors: int = 1,
hyperparam: Optional[Dict[str, Any]] = None,
hyperparameters: Optional[Dict[str, Any]] = None,
covariates: Optional[MedRecordAttributeInputList] = None,
) -> pl.DataFrame:
"""Executes Propensity Score matching using a specified classification algorithm.
Expand All @@ -101,7 +101,7 @@ def run_propensity_score(
Options include "logit", "dec_tree", "forest".
number_of_neighbors (int, optional): Number of nearest neighbors to find for
each treated unit. Defaults to 1.
hyperparam (Optional[Dict[str, Any]], optional): Hyperparameters for model
hyperparameters (Optional[Dict[str, Any]], optional): Hyperparameters for model
tuning. Increases computation time if set. Uses default if None.
covariates (Optional[MedRecordAttributeInputList], optional): Features for
matching. Uses all if None.
Expand All @@ -125,7 +125,7 @@ def run_propensity_score(
y_train,
treated_array,
control_array,
hyperparam=hyperparam,
hyperparameters=hyperparameters,
model=model,
)

Expand Down
157 changes: 143 additions & 14 deletions medmodels/treatment_effect/matching/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,50 +8,94 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Literal, Optional, Set, Tuple, TypeAlias
from typing import (
TYPE_CHECKING,
List,
Literal,
Optional,
Sequence,
Set,
Tuple,
TypeAlias,
)

import polars as pl

from medmodels.medrecord._overview import extract_attribute_summary
from medmodels.medrecord.medrecord import MedRecord

if TYPE_CHECKING:
from medmodels.medrecord.medrecord import MedRecord
from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex
from medmodels.medrecord.querying import NodeOperand
from medmodels.medrecord.types import Group, MedRecordAttribute, NodeIndex

MatchingMethod: TypeAlias = Literal["propensity", "nearest_neighbors"]


class Matching(ABC):
"""The Abstract Class for matching."""

number_of_neighbors: int

def __init__(self, number_of_neighbors: int) -> None:
    """Initializes the matching class.

    Args:
        number_of_neighbors (int): Number of nearest neighbors to find for each
            treated patient.
    """
    # Kept on the instance: _check_nodes multiplies this by the number of
    # treated subjects to verify that enough controls are available.
    self.number_of_neighbors = number_of_neighbors

def _preprocess_data(
self,
*,
medrecord: MedRecord,
control_set: Set[NodeIndex],
treated_set: Set[NodeIndex],
essential_covariates: MedRecordAttributeInputList,
one_hot_covariates: MedRecordAttributeInputList,
patients_group: Group,
essential_covariates: Optional[List[MedRecordAttribute]] = None,
LauraBoenchenLB marked this conversation as resolved.
Show resolved Hide resolved
one_hot_covariates: Optional[List[MedRecordAttribute]] = None,
) -> Tuple[pl.DataFrame, pl.DataFrame]:
"""Prepared the data for the matching algorithms.

Args:
medrecord (MedRecord): MedRecord object containing the data.
control_set (Set[NodeIndex]): Set of control subjects.
treated_set (Set[NodeIndex]): Set of treated subjects.
essential_covariates (MedRecordAttributeInputList): Covariates
that are essential for matching
one_hot_covariates (MedRecordAttributeInputList): Covariates that
are one-hot encoded for matching
patients_group (Group): The group of patients.
essential_covariates (Optional[List[MedRecordAttribute]], optional):
Covariates that are essential for matching. Defaults to None, meaning
all the attributes of the patients are used.
one_hot_covariates (Optional[List[MedRecordAttribute]], optional):
Covariates that are one-hot encoded for matching. Defaults to None,
meaning all the categorical attributes of the patients are used.

Returns:
Tuple[pl.DataFrame, pl.DataFrame]: Treated and control groups with their
preprocessed covariates

Raises:
AssertionError: If the one-hot covariates are not in the essential
covariates.
"""
essential_covariates = [str(covariate) for covariate in essential_covariates]
if essential_covariates is None:
# If no essential covariates provided, use all attributes of patients group
nodes_attributes = medrecord.node[medrecord.nodes_in_group(patients_group)]
essential_covariates = list(
{key for attributes in nodes_attributes.values() for key in attributes}
)

control_set = self._check_nodes(
medrecord=medrecord,
treated_set=treated_set,
control_set=control_set,
essential_covariates=essential_covariates,
)

if "id" not in essential_covariates:
essential_covariates.append("id")

# Dataframe
# Dataframe with the essential covariates
data = pl.DataFrame(
data=[
{"id": k, **v}
Expand All @@ -60,6 +104,33 @@ def _preprocess_data(
)
original_columns = data.columns

# If no one-hot covariates provided, use all categorical attributes of patients
if one_hot_covariates is None:
MarIniOnz marked this conversation as resolved.
Show resolved Hide resolved
attributes = extract_attribute_summary(
medrecord.node[medrecord.nodes_in_group(patients_group)]
)
one_hot_covariates = [
covariate
for covariate, values in attributes.items()
if "Categorical" in values["type"]
]

one_hot_covariates = [
covariate
for covariate in one_hot_covariates
if covariate in essential_covariates
]

# If there are one-hot covariates, check if all are in the essential covariates
if (
not all(
covariate in essential_covariates for covariate in one_hot_covariates
)
and one_hot_covariates
):
msg = "One-hot covariates must be in the essential covariates"
raise AssertionError(msg)

# One-hot encode the categorical variables
data = data.to_dummies(
columns=[str(covariate) for covariate in one_hot_covariates],
Expand All @@ -79,25 +150,83 @@ def _preprocess_data(

return data_treated, data_control

def _check_nodes(
    self,
    medrecord: MedRecord,
    treated_set: Set[NodeIndex],
    control_set: Set[NodeIndex],
    essential_covariates: List[MedRecordAttribute],
) -> Set[NodeIndex]:
    """Filters the controls and validates that matching is possible.

    Controls that lack any of the essential covariates are dropped. The
    remaining controls must be numerous enough to provide
    ``number_of_neighbors`` matches per treated subject, and every treated
    subject must have all the essential covariates.

    Args:
        medrecord (MedRecord): MedRecord object containing the data.
        treated_set (Set[NodeIndex]): Set of treated subjects.
        control_set (Set[NodeIndex]): Set of control subjects.
        essential_covariates (List[MedRecordAttribute]): Covariates that are
            essential for matching.

    Returns:
        Set[NodeIndex]: The control subjects that have all the essential
            covariates.

    Raises:
        ValueError: If not enough control subjects to match the treated subjects.
        ValueError: If some treated nodes do not have all the essential covariates.
    """

    def query_essential_covariates(
        node: NodeOperand, patients_set: Set[NodeIndex]
    ) -> None:
        """Query the nodes that have all the essential covariates."""
        node.has_attribute(essential_covariates)

        node.index().is_in(list(patients_set))

    # Bind the incoming controls to a separate name so the query below does not
    # depend on when `control_set` gets rebound.
    candidate_controls = control_set
    control_set = set(
        medrecord.select_nodes(
            lambda node: query_essential_covariates(node, candidate_controls)
        )
    )

    # Each treated subject needs `number_of_neighbors` controls.
    controls_needed = self.number_of_neighbors * len(treated_set)
    if len(control_set) < controls_needed:
        msg = (
            f"Not enough control subjects to match the treated subjects. "
            f"Number of controls: {len(control_set)}, "
            f"Number of treated subjects: {len(treated_set)}, "
            f"Number of neighbors required per treated subject: {self.number_of_neighbors}, "
            f"Total controls needed: {controls_needed}."
        )
        raise ValueError(msg)

    # Unlike controls, treated subjects are never dropped: a treated node
    # missing an essential covariate is an error.
    treated_with_covariates = medrecord.select_nodes(
        lambda node: query_essential_covariates(node, treated_set)
    )
    if len(treated_set) != len(treated_with_covariates):
        msg = "Some treated nodes do not have all the essential covariates"
        raise ValueError(msg)

    return control_set

@abstractmethod
def match_controls(
self,
*,
medrecord: MedRecord,
control_set: Set[NodeIndex],
treated_set: Set[NodeIndex],
essential_covariates: Optional[MedRecordAttributeInputList] = None,
one_hot_covariates: Optional[MedRecordAttributeInputList] = None,
essential_covariates: Optional[Sequence[MedRecordAttribute]] = None,
one_hot_covariates: Optional[Sequence[MedRecordAttribute]] = None,
) -> Set[NodeIndex]:
"""Matches the controls based on the matching algorithm.

Args:
medrecord (MedRecord): MedRecord object containing the data.
control_set (Set[NodeIndex]): Set of control subjects.
treated_set (Set[NodeIndex]): Set of treated subjects.
essential_covariates (Optional[MedRecordAttributeInputList], optional):
essential_covariates (Optional[Sequence[MedRecordAttribute]], optional):
Covariates that are essential for matching. Defaults to None.
one_hot_covariates (Optional[MedRecordAttributeInputList], optional):
one_hot_covariates (Optional[Sequence[MedRecordAttribute]], optional):
Covariates that are one-hot encoded for matching. Defaults to None.

Returns:
Expand Down
38 changes: 19 additions & 19 deletions medmodels/treatment_effect/matching/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Set
from typing import TYPE_CHECKING, Optional, Sequence, Set

from medmodels.treatment_effect.matching.algorithms.classic_distance_models import (
nearest_neighbor,
Expand All @@ -11,7 +11,7 @@

if TYPE_CHECKING:
from medmodels import MedRecord
from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex
from medmodels.medrecord.types import Group, MedRecordAttribute, NodeIndex


class NeighborsMatching(Matching):
Expand All @@ -36,44 +36,44 @@ def __init__(
number_of_neighbors (int, optional): Number of nearest neighbors to find for
each treated unit. Defaults to 1.
"""
self.number_of_neighbors = number_of_neighbors
super().__init__(number_of_neighbors)

def match_controls(
self,
*,
medrecord: MedRecord,
control_set: Set[NodeIndex],
treated_set: Set[NodeIndex],
essential_covariates: Optional[MedRecordAttributeInputList] = None,
one_hot_covariates: Optional[MedRecordAttributeInputList] = None,
patients_group: Group,
essential_covariates: Optional[Sequence[MedRecordAttribute]] = None,
one_hot_covariates: Optional[Sequence[MedRecordAttribute]] = None,
) -> Set[NodeIndex]:
"""Matches the controls based on the nearest neighbor algorithm.

Args:
medrecord (MedRecord): MedRecord object containing the data.
treated_set (Set[NodeIndex]): Set of treated subjects.
control_set (Set[NodeIndex]): Set of control subjects.
essential_covariates (Optional[MedRecordAttributeInputList], optional):
Covariates that are essential for matching. Defaults to
["gender", "age"].
one_hot_covariates (Optional[MedRecordAttributeInputList], optional):
Covariates that are one-hot encoded for matching. Defaults to
["gender"].
treated_set (Set[NodeIndex]): Set of treated subjects.
patients_group (Group): Group of patients in the MedRecord.
essential_covariates (Optional[Sequence[MedRecordAttribute]], optional):
Covariates that are essential for matching. Defaults to None, meaning
all the attributes of the patients are used.
one_hot_covariates (Optional[Sequence[MedRecordAttribute]], optional):
Covariates that are one-hot encoded for matching. Defaults to None,
meaning all the categorical attributes of the patients are used.

Returns:
Set[NodeIndex]: Node Ids of the matched controls.
"""
if essential_covariates is None:
essential_covariates = ["gender", "age"]
if one_hot_covariates is None:
one_hot_covariates = ["gender"]

data_treated, data_control = self._preprocess_data(
medrecord=medrecord,
control_set=control_set,
treated_set=treated_set,
essential_covariates=essential_covariates,
one_hot_covariates=one_hot_covariates,
patients_group=patients_group,
essential_covariates=list(essential_covariates)
if essential_covariates
else None,
one_hot_covariates=list(one_hot_covariates) if one_hot_covariates else None,
)

# Run the algorithm to find the matched controls
Expand Down
Loading