From 8e4a5bffdd607516f270694cc2aaf83bc5732dd8 Mon Sep 17 00:00:00 2001 From: Oliver Spohngellert Date: Mon, 7 Oct 2024 14:08:16 -0400 Subject: [PATCH] Control to Treatment Matching (#797) * Many to one matching without replacement. * feat: first draft allowing control to treatment. * fix: writing tests, fixing code. * fix: formatting. --- causalml/match.py | 58 +++++++++++++++++++++++++-------------------- tests/test_match.py | 21 ++++++++++++++++ 2 files changed, 53 insertions(+), 26 deletions(-) diff --git a/causalml/match.py b/causalml/match.py index 91cc64c0..395a0777 100644 --- a/causalml/match.py +++ b/causalml/match.py @@ -91,6 +91,8 @@ class NearestNeighborMatch: ratio (int): ratio of control / treatment to be matched. shuffle (bool): whether to shuffle the treatment group data before matching + treatment_to_control (bool): whether to match treatment to control + or control to treatment random_state (numpy.random.RandomState or int): RandomState or an int seed n_jobs (int): The number of parallel jobs to run for neighbors search. @@ -103,6 +105,7 @@ def __init__( replace=False, ratio=1, shuffle=True, + treatment_to_control=True, random_state=None, n_jobs=-1, ): @@ -123,6 +126,7 @@ def __init__( self.replace = replace self.ratio = ratio self.shuffle = shuffle + self.treatment_to_control = treatment_to_control self.random_state = check_random_state(random_state) self.n_jobs = n_jobs @@ -144,16 +148,19 @@ def match(self, data, treatment_col, score_cols): treatment = data.loc[data[treatment_col] == 1, score_cols] control = data.loc[data[treatment_col] == 0, score_cols] + # Picks whether to use treatment or control for matching direction + match_from = treatment if self.treatment_to_control else control + match_to = control if self.treatment_to_control else treatment sdcal = self.caliper * np.std(data[score_cols].values) if self.replace: scaler = StandardScaler() scaler.fit(data[score_cols]) - treatment_scaled = pd.DataFrame( - scaler.transform(treatment), index=treatment.index + match_from_scaled = pd.DataFrame( + scaler.transform(match_from), index=match_from.index ) - control_scaled = pd.DataFrame( - scaler.transform(control), index=control.index + match_to_scaled = pd.DataFrame( + scaler.transform(match_to), index=match_to.index ) # SD is the same as caliper because we use a StandardScaler above @@ -162,21 +169,20 @@ def match(self, data, treatment_col, score_cols): matching_model = NearestNeighbors( n_neighbors=self.ratio, n_jobs=self.n_jobs ) - matching_model.fit(control_scaled) - distances, indices = matching_model.kneighbors(treatment_scaled) - + matching_model.fit(match_to_scaled) + distances, indices = matching_model.kneighbors(match_from_scaled) # distances and indices are (n_obs, self.ratio) matrices. # To index easily, reshape distances, indices and treatment into # the (n_obs * self.ratio, 1) matrices and data frame. distances = distances.T.flatten() indices = indices.T.flatten() - treatment_scaled = pd.concat([treatment_scaled] * self.ratio, axis=0) + match_from_scaled = pd.concat([match_from_scaled] * self.ratio, axis=0) cond = (distances / np.sqrt(len(score_cols))) < sdcal # Deduplicate the indices of the treatment group - t_idx_matched = np.unique(treatment_scaled.loc[cond].index) + from_idx_matched = np.unique(match_from_scaled.loc[cond].index) # XXX: Should we deduplicate the indices of the control group too? - c_idx_matched = np.array(control_scaled.iloc[indices[cond]].index) + to_idx_matched = np.array(match_to_scaled.iloc[indices[cond]].index) else: assert len(score_cols) == 1, ( "Matching on multiple columns is only supported using the " @@ -187,31 +193,31 @@ def match(self, data, treatment_col, score_cols): score_col = score_cols[0] if self.shuffle: - t_indices = self.random_state.permutation(treatment.index) + from_indices = self.random_state.permutation(match_from.index) else: - t_indices = treatment.index + from_indices = match_from.index - t_idx_matched = [] - c_idx_matched = [] - control["unmatched"] = True + from_idx_matched = [] + to_idx_matched = [] + match_to["unmatched"] = True - for t_idx in t_indices: + for from_idx in from_indices: dist = np.abs( - control.loc[control.unmatched, score_col] - - treatment.loc[t_idx, score_col] + match_to.loc[match_to.unmatched, score_col] + - match_from.loc[from_idx, score_col] ) # Gets self.ratio lowest dists - c_np_idx_list = np.argpartition(dist, self.ratio)[: self.ratio] - c_idx_list = dist.index[c_np_idx_list] - for i, c_idx in enumerate(c_idx_list): - if dist[c_idx] <= sdcal: + to_np_idx_list = np.argpartition(dist, self.ratio)[: self.ratio] + to_idx_list = dist.index[to_np_idx_list] + for i, to_idx in enumerate(to_idx_list): + if dist[to_idx] <= sdcal: if i == 0: - t_idx_matched.append(t_idx) - c_idx_matched.append(c_idx) - control.loc[c_idx, "unmatched"] = False + from_idx_matched.append(from_idx) + to_idx_matched.append(to_idx) + match_to.loc[to_idx, "unmatched"] = False return data.loc[ - np.concatenate([np.array(t_idx_matched), np.array(c_idx_matched)]) + np.concatenate([np.array(from_idx_matched), np.array(to_idx_matched)]) ] def match_by_group(self, data, treatment_col, score_cols, groupby_col): diff --git a/tests/test_match.py b/tests/test_match.py index 4933ee10..e5785e44 100644 --- a/tests/test_match.py +++ b/tests/test_match.py @@ -57,6 +57,27 @@ def test_nearest_neighbor_match_by_group(generate_unmatched_data): assert sum(matched[TREATMENT_COL] == 0) == sum(matched[TREATMENT_COL] != 0) +def test_nearest_neighbor_match_control_to_treatment(generate_unmatched_data): + """ + Tests whether control to treatment matching is working. Does so + by using: + + replace=True + treatment_to_control=False + ratio=2 + + + And testing if we get 2x the number of control matches than treatment + """ + df, features = generate_unmatched_data() + + psm = NearestNeighborMatch( + replace=True, ratio=2, treatment_to_control=False, random_state=RANDOM_SEED + ) + matched = psm.match(data=df, treatment_col=TREATMENT_COL, score_cols=[SCORE_COL]) + assert 2 * sum(matched[TREATMENT_COL] == 0) == sum(matched[TREATMENT_COL] != 0) + + def test_match_optimizer(generate_unmatched_data): df, features = generate_unmatched_data()