-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
36 lines (24 loc) · 990 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from typing import List
import pandas as pd
from pandas import Series
from src.plots import ScorePlotter
from src.prepare_data import PrepData
from src.propensity_matching import PScorer
from src.find_similarities import ObsMatcher
# Script configuration, consumed by the PrepData call in the __main__ block.
# Passed to PrepData as `group`; presumably the name of the column that marks
# group membership -- confirm against PrepData.
PS_GROUP = 'acquirer'
# Passed to PrepData as `target`; presumably the outcome column name -- confirm
# against PrepData.
TARGET = 'target'
# Passed to PrepData as its first positional argument; looks like a path to
# the input CSV, relative to the working directory.
FILE_PATH = 'data/df.csv'
if __name__ == "__main__":
    # Build the prepared dataset object; the rest of the script reads its
    # `input`, `group_label` and `target_label` attributes.
    data = PrepData(
        FILE_PATH,
        group=PS_GROUP,
        target=TARGET,
        index_col="id",
    )

    # Fit the scorer on the inputs against the group label, then score the
    # very same rows and plot a ROC curve of those scores.
    scorer = PScorer()
    scorer.fit(data.input, data.group_label)
    ps_scores: Series = scorer.predict(data.input)
    ScorePlotter.plot_roc_curve(ps_scores, data.group_label)

    # Match observations on their scores: one match each, caliper of 0.001.
    matcher = ObsMatcher(n_matches=1, caliper=0.001)
    matched_index: List[int] = matcher.match_scores(ps_scores, data.group_label)

    # Plot the SMD comparison for the inputs, given the matched index.
    ScorePlotter.plot_smd_comparison(
        data=data.input, matched_index=matched_index, treatment=data.group_label
    )

    # Keep only rows whose index label is among the matches, then attach the
    # target and group labels (assumes `data.input` is a pandas DataFrame --
    # TODO confirm against PrepData).
    is_matched = data.input.index.isin(matched_index)
    matched_data = (
        data.input[is_matched]
        .join(data.target_label)
        .join(data.group_label)
    )