Skip to content

Commit

Permalink
update with ECOOT each and VAE feature matching
Browse files Browse the repository at this point in the history
  • Loading branch information
jykr committed Oct 10, 2024
1 parent e60170e commit d088eac
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 36 deletions.
8 changes: 7 additions & 1 deletion perturbot/cv/run_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def parse_args():
parser.add_argument("--log", type=str, default=None)
parser.add_argument("--rerun", action="store_true")
parser.add_argument("--all", action="store_true")
parser.add_argument("--load-existing", action="store_true")
parser.add_argument("--feature", action="store_true")
return parser.parse_args()

Expand All @@ -44,8 +45,13 @@ def parse_args():
)
elif args.all:
submit_all_run(
data_paths[args.data],
(
full_data_paths[args.data]
if "VAE" in args.method
else data_paths[args.data]
),
args.method,
load_existing=args.load_existing,
)
elif args.feature:
submit_feature_run(
Expand Down
71 changes: 52 additions & 19 deletions perturbot/perturbot/eval/all.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import pickle as pkl
from functools import partial
import numpy as np

import torch
from perturbot.eval.match import get_FOSCTTM, get_diag_fracs

from perturbot.eval.utils import get_Ts_from_nn_multKs
from perturbot.match.cot_labels import get_coupling_cotl_sinkhorn

from perturbot.match.ott_egwl import (
Expand All @@ -17,13 +17,17 @@
get_coupling_leot_ott,
get_coupling_egw_ott,
)
from perturbot.match.cot import get_coupling_cot_sinkhorn
from perturbot.match.cot import (
get_coupling_cot_sinkhorn,
get_coupling_each_cot_sinkhorn,
)
from perturbot.match.gw_labels import get_coupling_egw_labels
from perturbot.predict.scvi_vae import train_vae_model
from perturbot.predict.scvi_vae import train_vae_model, infer_from_Xs, infer_from_Ys

ot_method_map = {
"ECOOTL": get_coupling_cotl_sinkhorn,
"ECOOT": get_coupling_cot_sinkhorn,
"ECOOT_each": get_coupling_each_cot_sinkhorn,
"EGWL": get_coupling_egw_labels,
"EOT_ott": get_coupling_eot_ott,
"LEOT_ott": get_coupling_leot_ott,
Expand Down Expand Up @@ -94,21 +98,50 @@ def main(args):
train_Z = Zs_dict
train_data = (train_X, train_Y)
print(f"Calculating matching with {match_eps}")
if args.log_filepath is not None:
with open(args.log_filepath, "rb") as f:
d = pkl.load(f)
Ts_matching = d["T"]
log_matching = ""
else:
Ts_matching, log_matching = ot_method_map[args.method](
train_data, match_eps
)
if "VAE" in args.method:
dim_X = data_dict["Xs_dict"][labels[0]].shape[1]
dim_Y = data_dict["Xt_dict"][labels[0]].shape[1]
latent_Y = infer_from_Ys(train_Y, Ts_matching, dim_X)
latent_X = infer_from_Xs(train_X, Ts_matching, dim_Y)
_, mean_foscttm = get_FOSCTTM(
Ts_matching,
latent_X,
latent_Y,
use_agg="mean",
use_barycenter=False,
)
ks = [1, 5, 10, 50, 100]
k_to_Ts = get_Ts_from_nn_multKs(latent_X, latent_Y, ks) # k -> T
dfracs = {}
rel_dfracs = {}
for k, Ts in k_to_Ts.items():
dfracs[k], rel_dfracs[k] = get_diag_fracs(
Ts, train_X, train_Y, train_Z, train_Z
)

Ts_matching, log_matching = ot_method_map[args.method](train_data, match_eps)
if isinstance(Ts_matching, dict):
total_sum = 0
for k, v in Ts_matching.items():
total_sum += v.sum()
Ts_matching = {
k: v.astype(np.double) / total_sum for k, v in Ts_matching.items()
}
else:
Ts_matching = Ts_matching.astype(np.double) / Ts_matching.sum()
_, mean_foscttm = get_FOSCTTM(Ts_matching, train_X, train_Y, use_agg="mean")
dfracs, rel_dfracs = get_diag_fracs(
Ts_matching, train_X, train_Y, train_Z, train_Z
)
if isinstance(Ts_matching, dict):
total_sum = 0
for k, v in Ts_matching.items():
total_sum += v.sum()
Ts_matching = {
k: v.astype(np.double) / total_sum for k, v in Ts_matching.items()
}
else:
Ts_matching = Ts_matching.astype(np.double) / Ts_matching.sum()
_, mean_foscttm = get_FOSCTTM(Ts_matching, train_X, train_Y, use_agg="mean")
dfracs, rel_dfracs = get_diag_fracs(
Ts_matching, train_X, train_Y, train_Z, train_Z
)

# if not all_to_all:
mean_mean_foscttm = mean_foscttm.mean()
Expand All @@ -127,14 +160,14 @@ def main(args):
logs["log"] = (log_matching,)
except Exception as e:
with open(f"all_{args.method}.{args.eps}.tmp.pkl", "wb") as f:
pkl.dump(logs, f)
pkl.dump({"log": logs, "T": Ts_matching}, f)
raise e

with open(f"all_{args.method}.{args.eps}.pkl", "wb") as f:
pkl.dump(logs, f)


def submit_all_run(data_path, ot_method_label):
def submit_all_run(data_path, ot_method_label, load_existing=False):
epsilons = [1e-2, 1e-3, 1e-4, 1e-5]
for eps in epsilons:
run_label = f"all.{ot_method_label}.{eps}"
Expand Down
12 changes: 10 additions & 2 deletions perturbot/perturbot/eval/cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from functools import partial
from itertools import product
import argparse
import os

os.environ["OPENBLAS_NUM_THREADS"] = "1"
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
Expand Down Expand Up @@ -32,7 +35,11 @@
get_coupling_leot_ott,
get_coupling_egw_ott,
)
from perturbot.match.cot import get_coupling_cot, get_coupling_cot_sinkhorn
from perturbot.match.cot import (
get_coupling_cot,
get_coupling_cot_sinkhorn,
get_coupling_each_cot_sinkhorn,
)
from perturbot.match.gw_labels import get_coupling_gw_labels, get_coupling_egw_labels
from perturbot.predict.scvi_vae import train_vae_model

Expand All @@ -49,6 +56,7 @@
"ECOOTL": get_coupling_cotl_sinkhorn,
"COOT": get_coupling_cot,
"ECOOT": get_coupling_cot_sinkhorn,
"ECOOT_each": get_coupling_each_cot_sinkhorn,
"GW_all": get_coupling_gw_all,
# "EGW_all": get_coupling_egw_all,
"GW": get_coupling_gw_cg,
Expand All @@ -72,7 +80,7 @@
1e-5,
1e-6,
]
for method in ["ECOOTL", "ECOOT"]:
for method in ["ECOOTL", "ECOOT", "ECOOT_each"]:
ot_method_hyperparams[method] = [0.1, 0.05, 0.01, 0.005, 0.001]
ot_method_all_to_all = ["GW_all", "EGW_all_ott", "EOT_all_ott"]

Expand Down
23 changes: 21 additions & 2 deletions perturbot/perturbot/eval/cv_inner_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
import argparse
import pickle as pkl
import jax
from functools import partial
from itertools import product
from multiprocessing import Pool
Expand All @@ -12,7 +13,9 @@
from perturbot.eval.utils import get_Ts_from_nn_multKs
from sklearn.model_selection import KFold
import torch

from tqdm import tqdm
import psutil
import gc
from perturbot.eval.utils import _pop_keys, _pop_key
from perturbot.eval.match import get_FOSCTTM, get_diag_fracs
from perturbot.eval.prediction import get_evals_preds, get_evals
Expand All @@ -36,7 +39,11 @@
get_coupling_leot_ott,
get_coupling_egw_ott,
)
from perturbot.match.cot import get_coupling_cot, get_coupling_cot_sinkhorn
from perturbot.match.cot import (
get_coupling_cot,
get_coupling_cot_sinkhorn,
get_coupling_each_cot_sinkhorn,
)
from perturbot.match.gw_labels import get_coupling_gw_labels, get_coupling_egw_labels
from perturbot.predict.scvi_vae import train_vae_model
from perturbot.predict.linear_regression import (
Expand All @@ -51,6 +58,7 @@

ot_method_map = {
"ECOOTL": get_coupling_cotl_sinkhorn,
"ECOOT_each": get_coupling_each_cot_sinkhorn,
"ECOOT": get_coupling_cot_sinkhorn,
"EGWL": get_coupling_egw_labels,
"EOT_ott": get_coupling_eot_ott,
Expand Down Expand Up @@ -100,6 +108,7 @@ def parse_args():
"EGW_all_ott",
"EGWL_ott",
"ECOOT",
"ECOOT_each",
"ECOOTL",
]:
ot_method_hyperparams[method] = [
Expand Down Expand Up @@ -175,13 +184,23 @@ def main(args):
train_Z_prod = train_Zs * len(epsilons)
print("Len eps", eps_prod)
if args.log_filepath is None and "VAE" not in args.method:
Ts_list = []
logs = []
# for i, (_train_data, _eps) in tqdm(enumerate(zip(train_data_prod, eps_prod))):
# print(f"{i}th iteration with {_eps}")
# _T, _log = ot_method_map[args.method](_train_data, _eps)
# Ts_list.append(_T)
# logs.append(_log)
# jax.clear_caches()
try:
with Pool(5 * len(epsilons)) as p:
# with Pool(5) as p:
Ts_list, logs = zip(
*p.starmap(
ot_method_map[args.method], zip(train_data_prod, eps_prod)
)
)

except KeyboardInterrupt:
p.terminate()
finally:
Expand Down
44 changes: 38 additions & 6 deletions perturbot/perturbot/eval/feature_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import numpy as np
from perturbot.match.fot import get_coupling_fot
from perturbot.eval.utils import make_G
from perturbot.predict.scvi_vae import train_vae_model, infer_from_Xs, infer_from_Ys
from perturbot.eval.utils import get_Ts_from_nn_multKs


def parse_args():
Expand All @@ -33,6 +35,12 @@ def parse_args():
default=None,
help="One of perfect, random, by_conc",
)
parser.add_argument(
"--best-k",
type=int,
default=None,
help="Best k to use for VAE",
)
return parser.parse_args()


Expand Down Expand Up @@ -63,6 +71,15 @@ def main(args):
X_dict = data_dict["Xs_dict"]
Y_dict = data_dict["Xt_dict"]
Z_dict = data_dict["Zs_dict"]["dosage"]

if "VAE" in args.method:
labels = list(X_dict.keys())
dim_X = data_dict["Xs_dict"][labels[0]].shape[1]
dim_Y = data_dict["Xt_dict"][labels[0]].shape[1]
latent_Y = infer_from_Ys(Y_dict, Ts, dim_X)
latent_X = infer_from_Xs(X_dict, Ts, dim_Y)
Ts = get_Ts_from_nn_multKs(latent_X, latent_Y, [args.best_k])[args.best_k]

if Ts is None:
if args.method == "random":
Ts = {
Expand Down Expand Up @@ -100,21 +117,36 @@ def submit_feature_run(data_path, ot_method_label):
if ot_method_label in ["perfect", "random", "by_conc"]:
best_eps = 0
else:
best_k_dict = {}
for eps in epsilons:
try:
with open(f"all_{ot_method_label}.{eps}.pkl", "rb") as f:
d = pkl.load(f)
rel_dfracs.append(d["matching_evals"][0]["rel_dfracs"])
except:
rel_dfracs.append(-1)
# try:
with open(f"all_{ot_method_label}.{eps}.pkl", "rb") as f:
d = pkl.load(f)
_rel_dfracs = d["matching_evals"][0]["rel_dfracs"]
if isinstance(_rel_dfracs, dict):
max_rel_dfracs = -10
for k, v in _rel_dfracs.items():
if max_rel_dfracs < v:
max_rel_dfracs = v
best_k_dict[eps] = k
_rel_dfracs = max_rel_dfracs
rel_dfracs.append(_rel_dfracs)
# except Exception as e:
# print(e)
# rel_dfracs.append(-1)
best_eps = epsilons[rel_dfracs.index(max(rel_dfracs))]
if "VAE" in ot_method_label:
best_k = best_k_dict[best_eps]
for eps in epsilons:
run_label = f"FM.{ot_method_label}.{eps}"
f = open(f"{run_label}.bsub", "w")
f.write("source ~/.bashrc\n")
f.write("pwd\n")
f.write("conda activate ot \n")
command = f"python /gpfs/scratchfs01/site/u/ryuj6/OT/software/perturbot/perturbot/eval/feature_matching.py {ot_method_label} {data_path} {best_eps} {eps}"

if "VAE" in ot_method_label:
command += f" --best-k {best_k}"
f.write(f"echo {command}\n")
f.write(f"{command}\n")
f.close()
Expand Down
Loading

0 comments on commit d088eac

Please sign in to comment.