Skip to content

Commit

Permalink
fix: PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
MarIniOnz committed Jan 20, 2025
1 parent c1dcd9e commit 39ad1cd
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 26 deletions.
23 changes: 10 additions & 13 deletions medmodels/treatment_effect/matching/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,28 +55,25 @@ def match_controls(
control_set (Set[NodeIndex]): Set of control subjects.
treated_set (Set[NodeIndex]): Set of treated subjects.
patients_group (Group): Group of patients in the MedRecord.
essential_covariates (Optional[Sequence[MedRecordAttribute]], optional):
Covariates that are essential for matching. Defaults to
["gender", "age"].
one_hot_covariates (Optional[Sequence[MedRecordAttribute]], optional):
Covariates that are one-hot encoded for matching. Defaults to
["gender"].
essential_covariates (Optional[List[MedRecordAttribute]], optional):
Covariates that are essential for matching. Defaults to None, meaning
all the attributes of the patients are used.
one_hot_covariates (Optional[List[MedRecordAttribute]], optional):
Covariates that are one-hot encoded for matching. Defaults to None,
meaning all the categorical attributes of the patients are used.
Returns:
Set[NodeIndex]: Node Ids of the matched controls.
"""
if essential_covariates is None:
essential_covariates = ["gender", "age"]
if one_hot_covariates is None:
one_hot_covariates = ["gender"]

data_treated, data_control = self._preprocess_data(
medrecord=medrecord,
control_set=control_set,
treated_set=treated_set,
patients_group=patients_group,
essential_covariates=list(essential_covariates),
one_hot_covariates=list(one_hot_covariates),
essential_covariates=list(essential_covariates)
if essential_covariates
else None,
one_hot_covariates=list(one_hot_covariates) if one_hot_covariates else None,
)

# Run the algorithm to find the matched controls
Expand Down
23 changes: 10 additions & 13 deletions medmodels/treatment_effect/matching/propensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,29 +75,26 @@ def match_controls(
control_set (Set[NodeIndex]): Set of control subjects.
treated_set (Set[NodeIndex]): Set of treated subjects.
patients_group (Group): Group of patients in MedRecord.
essential_covariates (Optional[Sequence[MedRecordAttribute]], optional):
Covariates that are essential for matching. Defaults to
["gender", "age"].
one_hot_covariates (Optional[Sequence[MedRecordAttribute]], optional):
Covariates that are one-hot encoded for matching. Defaults to
["gender"].
essential_covariates (Optional[List[MedRecordAttribute]], optional):
Covariates that are essential for matching. Defaults to None, meaning
all the attributes of the patients are used.
one_hot_covariates (Optional[List[MedRecordAttribute]], optional):
Covariates that are one-hot encoded for matching. Defaults to None,
meaning all the categorical attributes of the patients are used.
Returns:
Set[NodeIndex]: Node Ids of the matched controls.
"""
if essential_covariates is None:
essential_covariates = ["gender", "age"]
if one_hot_covariates is None:
one_hot_covariates = ["gender"]

# Preprocess the data
data_treated, data_control = self._preprocess_data(
medrecord=medrecord,
treated_set=treated_set,
control_set=control_set,
patients_group=patients_group,
essential_covariates=list(essential_covariates),
one_hot_covariates=list(one_hot_covariates),
essential_covariates=list(essential_covariates)
if essential_covariates
else None,
one_hot_covariates=list(one_hot_covariates) if one_hot_covariates else None,
)

# Convert the Polars DataFrames to NumPy arrays
Expand Down

0 comments on commit 39ad1cd

Please sign in to comment.