Skip to content

Commit

Permalink
fix: fix PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
MarIniOnz committed Dec 17, 2024
1 parent 866207a commit d169a70
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 23 deletions.
52 changes: 34 additions & 18 deletions medmodels/treatment_effect/matching/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def _preprocess_data(
treated_set (Set[NodeIndex]): Set of treated subjects.
patients_group (Group): The group of patients.
essential_covariates (Optional[MedRecordAttributeInputList]):
Covariates that are essential for matching. Defaults to None.
Covariates that are essential for matching. Defaults to None, meaning
all the attributes of the patients are used.
one_hot_covariates (Optional[MedRecordAttributeInputList]):
Covariates that are one-hot encoded for matching. Defaults to None.
Expand All @@ -72,8 +73,6 @@ def _preprocess_data(
medrecord.node[medrecord.nodes_in_group(patients_group)]
)
)
else:
essential_covariates = [covariate for covariate in essential_covariates]

control_set = self._check_nodes(
medrecord=medrecord,
Expand All @@ -83,9 +82,9 @@ def _preprocess_data(
)

if "id" not in essential_covariates:
essential_covariates.append("id")
essential_covariates.append("id") # type: ignore

Check failure on line 85 in medmodels/treatment_effect/matching/matching.py

View workflow job for this annotation

GitHub Actions / lint

Argument of type "Literal['id']" cannot be assigned to parameter "object" of type "int" in function "append"   "Literal['id']" is incompatible with "int" (reportArgumentType)

# Dataframe wth the essential covariates
# Dataframe with the essential covariates
data = pl.DataFrame(
data=[
{"id": k, **v}
Expand All @@ -105,8 +104,18 @@ def _preprocess_data(
if "values" in values
]

if not all(
covariate in essential_covariates for covariate in one_hot_covariates
one_hot_covariates = [
covariate
for covariate in one_hot_covariates
if covariate in essential_covariates
]

# If there are one-hot covariates, check if all are in the essential covariates
if (
not all(
covariate in essential_covariates for covariate in one_hot_covariates
)
and one_hot_covariates
):
raise AssertionError(
"One-hot covariates must be in the essential covariates"
Expand All @@ -121,8 +130,8 @@ def _preprocess_data(

# Add to essential covariates the new columns created by one-hot encoding and
# delete the ones that were one-hot encoded
essential_covariates.extend(new_columns)
[essential_covariates.remove(col) for col in one_hot_covariates]
essential_covariates.extend(new_columns) # type: ignore

Check failure on line 133 in medmodels/treatment_effect/matching/matching.py

View workflow job for this annotation

GitHub Actions / lint

Argument of type "list[str]" cannot be assigned to parameter "iterable" of type "Iterable[int]" in function "extend"   "list[str]" is incompatible with "Iterable[int]"     Type parameter "_T_co@Iterable" is covariant, but "str" is not a subtype of "int"       "str" is incompatible with "int" (reportArgumentType)
[essential_covariates.remove(col) for col in one_hot_covariates] # type: ignore

Check failure on line 134 in medmodels/treatment_effect/matching/matching.py

View workflow job for this annotation

GitHub Actions / lint

Argument of type "MedRecordAttribute" cannot be assigned to parameter "value" of type "int" in function "remove"   Type "MedRecordAttribute" is incompatible with type "int"     "str" is incompatible with "int" (reportArgumentType)

Check failure on line 134 in medmodels/treatment_effect/matching/matching.py

View workflow job for this annotation

GitHub Actions / lint

Argument of type "MedRecordAttribute" cannot be assigned to parameter "value" of type "str" in function "remove"   Type "MedRecordAttribute" is incompatible with type "str"     "int" is incompatible with "str" (reportArgumentType)
data = data.select(essential_covariates)

# Select the sets of treated and control subjects
Expand Down Expand Up @@ -152,6 +161,9 @@ def _check_nodes(
Raises:
ValueError: If not enough control subjects to match the treated subjects.
ValueError: If some treated nodes do not have all the essential covariates
or if there are not enough control subjects to match the treated
subjects.
"""

def query_essential_covariates(
Expand All @@ -163,23 +175,27 @@ def query_essential_covariates(

node.index().is_in(list(patients_set))

control_set = set(
if len(treated_set) != len(
medrecord.select_nodes(
lambda node: query_essential_covariates(node, control_set)
lambda node: query_essential_covariates(node, treated_set)
)
)
if len(control_set) < self.number_of_neighbors * len(treated_set):
):
raise ValueError(
"Not enough control subjects to match the treated subjects"
"Some treated nodes do not have all the essential covariates"
)

if len(treated_set) != len(
control_set = set(
medrecord.select_nodes(
lambda node: query_essential_covariates(node, treated_set)
lambda node: query_essential_covariates(node, control_set)
)
):
)
if len(control_set) < self.number_of_neighbors * len(treated_set):
raise ValueError(
"Some treated nodes do not have all the essential covariates"
f"Not enough control subjects to match the treated subjects. "
f"Number of controls: {len(control_set)}, "
f"Number of treated subjects: {len(treated_set)}, "
f"Number of neighbors required per treated subject: {self.number_of_neighbors}, "
f"Total controls needed: {self.number_of_neighbors * len(treated_set)}."
)

return control_set
Expand Down
18 changes: 13 additions & 5 deletions medmodels/treatment_effect/matching/tests/test_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ def create_patients(patient_list: List[NodeIndex]) -> pd.DataFrame:
"age": [20, 30, 40, 30, 40, 50, 60, 70, 80],
"gender": [
"male",
"female",
"male",
"female",
"male",
"female",
"male",
"male",
"female",
"female",
"male",
],
Expand Down Expand Up @@ -59,6 +59,7 @@ def create_medrecord(
patients = create_patients(patient_list=patient_list)
medrecord = MedRecord.from_pandas(nodes=[(patients, "index")])
medrecord.add_group(group="patients", nodes=patients["index"].to_list())
medrecord.add_nodes(("P10", {}), "patients")
return medrecord


Expand Down Expand Up @@ -144,14 +145,15 @@ def test_match_controls(self):
# Assert that the correct number of controls were matched
self.assertEqual(len(matched_controls), len(treated_set))

# Assert it works equally if no covariates are given (automatically assigned)
# It should work equally if no covariates are given (all attributes automatically assigned)
matched_controls_no_covariates_specified = neighbors_matching.match_controls(
medrecord=self.medrecord,
control_set=control_set,
treated_set=treated_set,
patients_group="patients",
)

self.assertEqual(matched_controls_no_covariates_specified, matched_controls)
self.assertTrue(matched_controls_no_covariates_specified.issubset(control_set))
self.assertEqual(
len(matched_controls_no_covariates_specified), len(treated_set)
Expand Down Expand Up @@ -188,7 +190,10 @@ def test_invalid_check_nodes(self):
)
self.assertEqual(
str(context.exception),
"Not enough control subjects to match the treated subjects",
"Not enough control subjects to match the treated subjects. "
+ "Number of controls: 1, Number of treated subjects: 3, "
+ "Number of neighbors required per treated subject: 1, "
+ "Total controls needed: 3.",
)

with self.assertRaises(ValueError) as context:
Expand All @@ -201,7 +206,10 @@ def test_invalid_check_nodes(self):
)
self.assertEqual(
str(context.exception),
"Not enough control subjects to match the treated subjects",
"Not enough control subjects to match the treated subjects. "
+ "Number of controls: 5, Number of treated subjects: 3, "
+ "Number of neighbors required per treated subject: 2, "
+ "Total controls needed: 6.",
)

# Test missing essential covariates in treated set
Expand Down

0 comments on commit d169a70

Please sign in to comment.