diff --git a/model4/README.md b/model4/README.md
new file mode 100644
index 0000000..23698b3
--- /dev/null
+++ b/model4/README.md
@@ -0,0 +1,6 @@
+## To train models
+
+`python run_models.py --config configs/config_ccle.yaml`
+
+This command will train the model using four different drug encoders for 10 different data splits.
+More details will be added.
\ No newline at end of file
diff --git a/model4/configs/config_ccle.yaml b/model4/configs/config_ccle.yaml
new file mode 100644
index 0000000..412e934
--- /dev/null
+++ b/model4/configs/config_ccle.yaml
@@ -0,0 +1,5 @@
+epochs: 100
+data_type: CCLE
+metric: auc
+out_dir: cmp_ccle
+cuda_id: 1
diff --git a/model4/configs/config_ccle_gexp.yaml b/model4/configs/config_ccle_gexp.yaml
new file mode 100644
index 0000000..a8a9803
--- /dev/null
+++ b/model4/configs/config_ccle_gexp.yaml
@@ -0,0 +1,6 @@
+epochs: 100
+data_type: CCLE
+metric: auc
+out_dir: ccle_gexp
+cuda_id: 4
+use_proteomics_data: 0 # 0: use gene expression, 1: use proteomics
diff --git a/model4/configs/config_ccle_prot.yaml b/model4/configs/config_ccle_prot.yaml
new file mode 100644
index 0000000..18507c4
--- /dev/null
+++ b/model4/configs/config_ccle_prot.yaml
@@ -0,0 +1,5 @@
+epochs: 100
+data_type: CCLE
+metric: auc
+out_dir: ccle_prot
+cuda_id: 2
diff --git a/model4/configs/config_ctrpv2.yaml b/model4/configs/config_ctrpv2.yaml
new file mode 100644
index 0000000..d24dfa1
--- /dev/null
+++ b/model4/configs/config_ctrpv2.yaml
@@ -0,0 +1,5 @@
+epochs: 100
+data_type: CTRPv2
+metric: auc
+out_dir: cmp_ctrpv2
+cuda_id: 4
diff --git a/model4/configs/config_ctrpv2_gexp.yaml b/model4/configs/config_ctrpv2_gexp.yaml
new file mode 100644
index 0000000..5f2121d
--- /dev/null
+++ b/model4/configs/config_ctrpv2_gexp.yaml
@@ -0,0 +1,6 @@
+epochs: 100
+data_type: CTRPv2
+metric: auc
+out_dir: ctrpv2_gexp
+cuda_id: 4
+use_proteomics_data: 0 # 0: use gene expression, 1: use proteomics
diff --git a/model4/configs/config_ctrpv2_prot.yaml b/model4/configs/config_ctrpv2_prot.yaml
new file mode 100644
index 0000000..60784be
--- /dev/null
+++ b/model4/configs/config_ctrpv2_prot.yaml
@@ -0,0 +1,5 @@
+epochs: 100
+data_type: CTRPv2
+metric: auc
+out_dir: ctrpv2_prot
+cuda_id: 3
diff --git a/model4/data_utils.py b/model4/data_utils.py
new file mode 100644
index 0000000..4666365
--- /dev/null
+++ b/model4/data_utils.py
@@ -0,0 +1,459 @@
+import os
+from typing import Union, List
+import urllib
+import pandas as pd
+from rdkit import Chem
+#------------------
+# 1.
download data +#------------------ + +# version = 'benchmark-data-pilot1' +# ftp_dir = f'https://ftp.mcs.anl.gov/pub/candle/public/improve/benchmarks/single_drug_drp/{version}/csa_data/raw_data/' +# version = 'benchmark-data-imp-2023' +# ftp_dir = f'https://ftp.mcs.anl.gov/pub/candle/public/improve/benchmarks/single_drug_drp/{version}/csa_data/' + +candle_data_dict = { + 'ccle_candle': "CCLE", + 'ctrpv2_candle':"CTRPv2", + 'gdscv1_candle':"GDSCv1", + 'gdscv2_candle':"GDSCv2", + 'gcsi_candle': "gCSI"} + + +class Downloader: + + def __init__(self, version): + self.version = version + if version == 'benchmark-data-pilot1': + self.ftp_dir = f'https://ftp.mcs.anl.gov/pub/candle/public/improve/benchmarks/single_drug_drp/{version}/csa_data/raw_data/' + elif version == 'benchmark-data-imp-2023': + self.ftp_dir = f'https://ftp.mcs.anl.gov/pub/candle/public/improve/benchmarks/single_drug_drp/{version}/csa_data/' + + def download_candle_data(self, data_type="CCLE", split_id=100, data_dest='Data/'): + self.download_candle_split_data(data_type=data_type, split_id=split_id, data_dest=data_dest) + self.download_candle_resp_data(data_dest=data_dest) + self.download_candle_gexp_data(data_dest=data_dest) + self.download_candle_mut_data(data_dest=data_dest) + self.download_candle_smiles_data(data_dest=data_dest) + self.download_candle_drug_ecfp4_data(data_dest=data_dest) + self.download_landmark_genes(data_dest=data_dest) + + + def download_deepttc_vocabs(self, data_dest='Data/'): + + src_dir = 'https://raw.githubusercontent.com/jianglikun/DeepTTC/main/ESPF/' + fnames = ['drug_codes_chembl_freq_1500.txt', 'subword_units_map_chembl_freq_1500.csv'] + + for fname in fnames: + src = os.path.join(src_dir, fname) + dest = os.path.join(data_dest, fname) + if not os.path.exists(dest): + urllib.request.urlretrieve(src, dest) + + + + def download_candle_split_data(self, data_type="CCLE", split_id=0, data_dest='Data/'): + print(f'downloading {data_type} split {split_id} data') + + + split_src = os.path.join(self.ftp_dir, 'splits') + train_split_name = f'{data_type}_split_{split_id}_train.txt' + val_split_name = f'{data_type}_split_{split_id}_val.txt' + test_split_name = f'{data_type}_split_{split_id}_test.txt' + + + # download split data + for file in [train_split_name, val_split_name, test_split_name]: + src = os.path.join(split_src, file) + dest = os.path.join(data_dest, file) + + if not os.path.exists(dest): + urllib.request.urlretrieve(src, dest) + + def download_candle_resp_data(self, data_dest='Data/'): + # ftp_dir = 'https://ftp.mcs.anl.gov/pub/candle/public/improve/benchmarks/single_drug_drp/benchmark-data-pilot1/csa_data/raw_data/' + + print('downloading response data') + # download response data + resp_name = 'response.tsv' + if self.version=='benchmark-data-imp-2023': + resp_name = 'response.txt' + + src = os.path.join(self.ftp_dir, 'y_data', resp_name) + dest = os.path.join(data_dest, resp_name) + + if not os.path.exists(dest): + urllib.request.urlretrieve(src, dest) + + def download_candle_gexp_data(self, data_dest='Data/'): + print('downloading expression data') + gexp_name = 'cancer_gene_expression.tsv' + if self.version=='benchmark-data-imp-2023': + gexp_name = 'cancer_gene_expression.txt' + + src = os.path.join(self.ftp_dir, 'x_data', gexp_name) + dest = os.path.join(data_dest, gexp_name) + + if not os.path.exists(dest): + urllib.request.urlretrieve(src, dest) + + def download_candle_mut_data(self, data_dest='Data/'): + # ftp_dir = 
'https://ftp.mcs.anl.gov/pub/candle/public/improve/benchmarks/single_drug_drp/benchmark-data-pilot1/csa_data/raw_data/' + # gene mutation data + print('downloading mutation data') + gmut_name = 'cancer_mutation_count.tsv' + if self.version=='benchmark-data-imp-2023': + gmut_name = 'cancer_mutation_count.txt' + + src = os.path.join(self.ftp_dir, 'x_data', gmut_name) + dest = os.path.join(data_dest, gmut_name) + if not os.path.exists(dest): + urllib.request.urlretrieve(src, dest) + + def download_candle_smiles_data(self, data_dest='Data/'): + # ftp_dir = 'https://ftp.mcs.anl.gov/pub/candle/public/improve/benchmarks/single_drug_drp/benchmark-data-pilot1/csa_data/raw_data/' + # gene mutation data + print('downloading smiles data') + + smiles_name = 'drug_SMILES.tsv' + if self.version=='benchmark-data-imp-2023': + smiles_name = 'drug_SMILES.txt' + + src = os.path.join(self.ftp_dir, 'x_data', smiles_name) + dest = os.path.join(data_dest, smiles_name) + if not os.path.exists(dest): + urllib.request.urlretrieve(src, dest) + + def download_candle_drug_ecfp4_data(self, data_dest='Data/'): + # gene mutation data + print('downloading drug_ecfp4 data') + name = 'drug_ecfp4_nbits512.tsv' + if self.version=='benchmark-data-imp-2023': + name = 'drug_ecfp4_512bit.txt' + + src = os.path.join(self.ftp_dir, 'x_data', name) + dest = os.path.join(data_dest, name) + if not os.path.exists(dest): + urllib.request.urlretrieve(src, dest) + + def download_landmark_genes(self, data_dest='Data/'): + urllib.request.urlretrieve('https://raw.githubusercontent.com/gihanpanapitiya/GraphDRP/to_candle/landmark_genes', data_dest+'/landmark_genes') + + +def add_smiles(smiles_df, df, metric): + + # df = rs_train.copy() + # smiles_df = data_utils.load_smiles_data(data_dir) + data_smiles_df = pd.merge(df, smiles_df, on = "improve_chem_id", how='left') + data_smiles_df = data_smiles_df.dropna(subset=[metric]) + data_smiles_df = data_smiles_df[['improve_sample_id', 'smiles', 'improve_chem_id', metric]] + data_smiles_df = data_smiles_df.drop_duplicates() + data_smiles_df.dropna(inplace=True) + data_smiles_df = data_smiles_df.reset_index(drop=True) + + return data_smiles_df + + +class DataProcessor: + def __init__(self, version): + self.version = version + + def load_drug_response_data(self, data_path, data_type="CCLE", + split_id=100, split_type='train', response_type='ic50', sep="\t", + dropna=True): + """ + Returns datarame with cancer ids, drug ids, and drug response values. Samples + from the original drug response file are filtered based on the specified + sources. 
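+        When split_id is an int, only the samples listed in the corresponding
+        {data_type}_split_{split_id}_{split_type}.txt file are returned; otherwise all
+        samples from the given source are kept.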
+ + Args: + source (str or list of str): DRP source name (str) or multiple sources (list of strings) + split(int or None): split id (int), None (load all samples) + split_type (str or None): one of the following: 'train', 'val', 'test' + y_col_name (str): name of drug response measure/score (e.g., AUC, IC50) + + Returns: + pd.Dataframe: dataframe that contains drug response values + """ + # TODO: at this point, this func implements the loading a single source + y_file_path = os.path.join(data_path, 'response.tsv') + if self.version=='benchmark-data-imp-2023': + y_file_path = os.path.join(data_path, 'response.txt') + + df = pd.read_csv(y_file_path, sep=sep) + + # import pdb; pdb.set_trace() + if isinstance(split_id, int): + # Get a subset of samples + ids = self.load_split_file(data_path, data_type, split_id, split_type) + df = df.loc[ids] + else: + # Get the full dataset for a given source + df = df[df["source"].isin([data_type])] + + cols = ["source", + "improve_chem_id", + "improve_sample_id", + response_type] + df = df[cols] # [source, drug id, cancer id, response] + if dropna: + df.dropna(axis=0, inplace=True) + df = df.reset_index(drop=True) + return df + + + def load_split_file(self, + data_path: str, + data_type: str, + split_id: Union[int, None]=None, + split_type: Union[str, List[str], None]=None) -> list: + """ + Args: + source (str): DRP source name (str) + + Returns: + ids (list): list of id integers + """ + if isinstance(split_type, str): + split_type = [split_type] + + # Check if the split file exists and load + ids = [] + for st in split_type: + fpath = os.path.join(data_path, f"{data_type}_split_{split_id}_{st}.txt") + # assert fpath.exists(), f"Splits file not found: {fpath}" + ids_ = pd.read_csv(fpath, header=None)[0].tolist() + ids.extend(ids_) + return ids + +#----------------------------------- +# 2. preprocess data to swnet format +#----------------------------------- +# def process_response_data(df_resp, response_type='ic50'): +# # df = pd.read_csv('response.tsv', sep='\t') +# drd = df_resp[['improve_chem_id', 'improve_sample_id', response_type]] +# drd.columns =['drug','cell_line','IC50'] +# # drd = drd.dropna(axis=0) +# drd.reset_index(drop=True, inplace=True) +# # drd.to_csv('tmp/Paccmann_MCA/Data/response.csv') +# return drd + + + def load_smiles_data(self, + data_dir, + sep: str="\t", + verbose: bool=True) -> pd.DataFrame: + """ + IMPROVE-specific func. + Read smiles data. + src_raw_data_dir : data dir where the raw DRP data is stored + """ + + smiles_path = os.path.join(data_dir, 'drug_SMILES.tsv') + if self.version=='benchmark-data-imp-2023': + smiles_path = os.path.join(data_dir, 'drug_SMILES.txt') + + df = pd.read_csv(smiles_path, sep=sep) + + # TODO: updated this after we update the data + df.columns = ["improve_chem_id", "smiles"] + + if verbose: + print(f"SMILES data: {df.shape}") + # print(df.dtypes) + # print(df.dtypes.value_counts()) + return df + + + + def load_morgan_fingerprint_data(self, + data_dir, + sep: str="\t", + verbose: bool=True) -> pd.DataFrame: + """ + Return Morgan fingerprints data. 
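+        The returned dataframe is indexed by improve_chem_id, with one 512-bit ECFP4
+        fingerprint per drug.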
+ """ + + path = os.path.join(data_dir, 'drug_ecfp4_nbits512.tsv') + if self.version=='benchmark-data-imp-2023': + path = os.path.join(data_dir, 'drug_ecfp4_512bit.txt') + + df = pd.read_csv(path, sep=sep) + df = df.set_index('improve_chem_id') + return df + + def load_gene_expression_data(self, + data_dir, + gene_system_identifier: Union[str, List[str]]="Gene_Symbol", + sep: str="\t", + verbose: bool=True) -> pd.DataFrame: + """ + Returns gene expression data. + + Args: + gene_system_identifier (str or list of str): gene identifier system to use + options: "Entrez", "Gene_Symbol", "Ensembl", "all", or any list + combination of ["Entrez", "Gene_Symbol", "Ensembl"] + + Returns: + pd.DataFrame: dataframe with the omic data + """ + gene_expression_file_path = os.path.join(data_dir, 'cancer_gene_expression.tsv') + if self.version=='benchmark-data-imp-2023': + gene_expression_file_path = os.path.join(data_dir, 'cancer_gene_expression.txt') + + canc_col_name= "improve_sample_id" + # level_map encodes the relationship btw the column and gene identifier system + level_map = {"Ensembl": 0, "Entrez": 1, "Gene_Symbol": 2} + header = [i for i in range(len(level_map))] + + df = pd.read_csv(gene_expression_file_path, sep=sep, index_col=0, header=header) + + df.index.name = canc_col_name # assign index name + df = set_col_names_in_multilevel_dataframe(df, level_map, gene_system_identifier) + if verbose: + print(f"Gene expression data: {df.shape}") + # print(df.dtypes) + # print(df.dtypes.value_counts()) + return df + + def load_cell_mutation_data(self, + data_dir, + gene_system_identifier: Union[str, List[str]]="Gene_Symbol", + sep: str="\t", verbose: bool=True) -> pd.DataFrame: + """ + Returns gene expression data. + + Args: + gene_system_identifier (str or list of str): gene identifier system to use + options: "Entrez", "Gene_Symbol", "Ensembl", "all", or any list + combination of ["Entrez", "Gene_Symbol", "Ensembl"] + + Returns: + pd.DataFrame: dataframe with the omic data + """ + cell_mutation_file_path = os.path.join(data_dir, 'cancer_mutation_count.tsv') + if self.version=='benchmark-data-imp-2023': + cell_mutation_file_path = os.path.join(data_dir, 'cancer_mutation_count.txt') + canc_col_name= "improve_sample_id" + # level_map encodes the relationship btw the column and gene identifier system + level_map = {"Ensembl": 0, "Entrez": 1, "Gene_Symbol": 2} + header = [i for i in range(len(level_map))] + + df = pd.read_csv(cell_mutation_file_path, sep=sep, index_col=0, header=header) + + df.index.name = canc_col_name # assign index name + df = set_col_names_in_multilevel_dataframe(df, level_map, gene_system_identifier) + if verbose: + print(f"cell mutation data: {df.shape}") + # print(df.dtypes) + # print(df.dtypes.value_counts()) + return df + + + def load_landmark_genes(self, data_path): + genes = pd.read_csv(os.path.join(data_path, 'landmark_genes'), header=None) + genes = genes.values.ravel().tolist() + return genes + +def set_col_names_in_multilevel_dataframe( + df: pd.DataFrame, + level_map: dict, + gene_system_identifier: Union[str, List[str]]="Gene_Symbol") -> pd.DataFrame: + """ Util function that supports loading of the omic data files. + Returns the input dataframe with the multi-level column names renamed as + specified by the gene_system_identifier arg. 
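+    Passing "all" keeps every identifier level, a single identifier collapses the
+    columns to that level, and a list keeps only the listed levels.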
+ + Args: + df (pd.DataFrame): omics dataframe + level_map (dict): encodes the column level and the corresponding identifier systems + gene_system_identifier (str or list of str): gene identifier system to use + options: "Entrez", "Gene_Symbol", "Ensembl", "all", or any list + combination of ["Entrez", "Gene_Symbol", "Ensembl"] + + Returns: + pd.DataFrame: the input dataframe with the specified multi-level column names + """ + df = df.copy() + + level_names = list(level_map.keys()) + level_values = list(level_map.values()) + n_levels = len(level_names) + + if isinstance(gene_system_identifier, list) and len(gene_system_identifier) == 1: + gene_system_identifier = gene_system_identifier[0] + + # print(gene_system_identifier) + # import pdb; pdb.set_trace() + if isinstance(gene_system_identifier, str): + if gene_system_identifier == "all": + df.columns = df.columns.rename(level_names, level=level_values) # assign multi-level col names + else: + df.columns = df.columns.get_level_values(level_map[gene_system_identifier]) # retian specific column level + else: + assert len(gene_system_identifier) <= n_levels, f"'gene_system_identifier' can't contain more than {n_levels} items." + set_diff = list(set(gene_system_identifier).difference(set(level_names))) + assert len(set_diff) == 0, f"Passed unknown gene identifiers: {set_diff}" + kk = {i: level_map[i] for i in level_map if i in gene_system_identifier} + # print(list(kk.keys())) + # print(list(kk.values())) + df.columns = df.columns.rename(list(kk.keys()), level=kk.values()) # assign multi-level col names + drop_levels = list(set(level_map.values()).difference(set(kk.values()))) + df = df.droplevel(level=drop_levels, axis=1) + return df + + +def remove_smiles_with_noneighbor_frags(smiles_df): + + remove_smiles=[] + for i in smiles_df.index: + smiles = smiles_df.loc[i, 'smiles'] + has_atoms_wothout_neighbors = check_for_atoms_without_neighbors(smiles) + if has_atoms_wothout_neighbors: + remove_smiles.append(smiles) + + smiles_df = smiles_df[~smiles_df.smiles.isin(remove_smiles)] + smiles_df.dropna(inplace=True) + smiles_df.reset_index(drop=True, inplace=True) + + return smiles_df + +def check_for_atoms_without_neighbors(smiles): + + mol = Chem.AddHs(Chem.MolFromSmiles(smiles)) + frags = Chem.GetMolFrags(mol, asMols=True) + frag_atoms = [i.GetNumAtoms() for i in frags] + has_atoms_wothout_neighbors = any([i==1 for i in frag_atoms]) + + + return has_atoms_wothout_neighbors + + + +def load_generic_expression_data(file_path, + gene_system_identifier = "Gene_Symbol", + sep: str="\t", + verbose: bool=True) -> pd.DataFrame: + """ + Returns gene expression data. 
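+    Behaves like DataProcessor.load_gene_expression_data, except the file path is
+    supplied directly instead of being derived from a data directory and benchmark version.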
+ + Args: + gene_system_identifier (str or list of str): gene identifier system to use + options: "Entrez", "Gene_Symbol", "Ensembl", "all", or any list + combination of ["Entrez", "Gene_Symbol", "Ensembl"] + + Returns: + pd.DataFrame: dataframe with the omic data + """ + + + canc_col_name= "improve_sample_id" + # level_map encodes the relationship btw the column and gene identifier system + level_map = {"Ensembl": 0, "Entrez": 1, "Gene_Symbol": 2} + header = [i for i in range(len(level_map))] + + df = pd.read_csv(file_path, sep=sep, index_col=0, header=header) + + df.index.name = canc_col_name # assign index name + df = set_col_names_in_multilevel_dataframe(df, level_map, gene_system_identifier) + return df diff --git a/model4/des_finetune.py b/model4/des_finetune.py new file mode 100644 index 0000000..e71a9f1 --- /dev/null +++ b/model4/des_finetune.py @@ -0,0 +1,270 @@ +from data_utils import Downloader, DataProcessor +from data_utils import add_smiles +from gnn_utils import create_data_list +from torch_geometric.loader import DataLoader +from model import Model +import torch +import torch.nn as nn +from gnn_utils import EarlyStopping +from gnn_utils import test_fn +import os +import argparse +import pandas as pd +from gnn_utils import CreateData +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler + +from torch_geometric.data import Data +from tqdm import tqdm +import torch.nn as nn +import matplotlib.pyplot as plt +from rdkit import Chem +from rdkit.Chem import AllChem +import torch +import torch.nn.functional as F +from rdkit import Chem +from mordred import Calculator, descriptors +from sklearn.preprocessing import StandardScaler +import numpy as np +from sklearn.model_selection import train_test_split +import pickle + + +# Pretrained model +from torch.nn import Linear +class DescriptorEncoder(nn.Module): + def __init__(self, n_descriptors): + super(DescriptorEncoder, self).__init__() + + self.fc1 = Linear( n_descriptors, 1024) + self.fc2 = Linear( 1024, 512) + self.fc3 = Linear( 512, 256) + self.do1 = nn.Dropout(p = 0.1) + self.do2 = nn.Dropout(p = 0.1) + self.act1 = nn.ReLU() + self.act2 = nn.ReLU() + self.out = Linear(256, 1) + + def forward(self, data): + fp = data.fp + + e = self.do1(self.act1(self.fc1(fp))) + e = self.do2(self.act2(self.fc2(e))) + e = self.fc3(e) + out = self.out(e) + + return out + +with open('feature_names.pkl', 'rb') as f: + feature_names = pickle.load(f) + +n_descriptors = len(feature_names) + + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +modelpt = DescriptorEncoder(n_descriptors) +modelpt.to(device); + + + + + + + +parser = argparse.ArgumentParser(prog='ProgramName', description='What the program does') +parser.add_argument('--metric', default='auc', help='') +parser.add_argument('--run_id', default='0', help='') +parser.add_argument('--epochs', default=100, help='') +parser.add_argument('--batch_size', type=int, default=64, help='') +parser.add_argument('--data_split_seed', default=-10, help='') +parser.add_argument('--data_split_id', default=0, help='') +parser.add_argument('--encoder_type', default='descriptor', help='') +parser.add_argument('--data_path', default='Descriptorenc/cmp_ctrpv2/Data', help='') +parser.add_argument('--data_type', default='CTRPv2', help='') +parser.add_argument('--data_version', default='benchmark-data-pilot1', help='benchmark-data-imp-2023 or benchmark-data-pilot1') +parser.add_argument('--out_dir', default='cmp_ctrpv2', help='') 
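+# --feature_path: optional CSV of precomputed drug descriptors; when set, it is merged
+# onto the drug-response tables on the smiles column and the descriptor columns are
+# standard-scaled before training.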
+parser.add_argument('--feature_path', type=str, default='../drug_features_pilot1.csv', help='') +parser.add_argument('--scale_gexp', type=bool, default=False, help='') + +args = parser.parse_args('') + + +pc = DataProcessor(args.data_version) + +metric = args.metric +bs = args.batch_size +lr = 1e-4 +n_epochs = int(args.epochs) +out_dir = args.out_dir +run_id = args.run_id +data_split_seed = int(args.data_split_seed) +encoder_type = args.encoder_type +data_type = args.data_type +data_split_id = args.data_split_id + +out_dir = os.path.join(out_dir, run_id ) + + +# os.makedirs( out_dir, exist_ok=True ) +# os.makedirs( args.data_path, exist_ok=True ) +# ckpt_path = os.path.join(out_dir, 'best.pt') + +# dw = Downloader(args.data_version) +# dw.download_candle_data(data_type=data_type, split_id=data_split_id, data_dest=args.data_path) +# dw.download_deepttc_vocabs(data_dest=args.data_path) + + + +train = pc.load_drug_response_data(data_path=args.data_path, data_type=data_type, + split_id=data_split_id, split_type='train', response_type=metric, sep="\t", + dropna=True) + +val = pc.load_drug_response_data(data_path=args.data_path, data_type=data_type, + split_id=data_split_id, split_type='val', response_type=metric, sep="\t", + dropna=True) + +test = pc.load_drug_response_data(data_path=args.data_path, data_type=data_type, + split_id=data_split_id, split_type='test', response_type=metric, sep="\t", + dropna=True) + +smiles_df = pc.load_smiles_data(args.data_path) + +train = add_smiles(smiles_df=smiles_df, df=train, metric=metric) +val = add_smiles(smiles_df=smiles_df, df=val, metric=metric) +test = add_smiles(smiles_df=smiles_df, df=test, metric=metric) +df_all = pd.concat([train, val, test], axis=0) +df_all.reset_index(drop=True, inplace=True) + +# args.feature_path +# len(feature_names) +# len(feature_names) + +if data_split_seed > -1: + print("using random splitting") + train, val, test = split_df(df=df_all, seed=data_split_seed) +else: + print("using predefined splits") + + +gene_exp = pc.load_gene_expression_data(args.data_path) + +lm = pc.load_landmark_genes(args.data_path) +lm = list(set(lm).intersection(gene_exp.columns)) +gexp = gene_exp.loc[:, lm] + +if args.scale_gexp: + scgexp = StandardScaler() + gexp.loc[:,:] = scgexp.fit_transform(gexp) + + +n_descriptors=None +features=None +if args.feature_path: + print("feature path exists") + features = pd.read_csv(args.feature_path) + # n_descriptors = features.shape[1] - 1 + # feature_names = features.drop(['smiles'], axis=1).columns.tolist() + + test = pd.merge(test, features, on='smiles', how='left') + train = pd.merge(train, features, on='smiles', how='left') + val = pd.merge(val, features, on='smiles', how='left') + + sc = StandardScaler() + train.loc[:, feature_names] = sc.fit_transform(train.loc[:, feature_names]) + test.loc[:, feature_names] = sc.transform(test.loc[:, feature_names]) + val.loc[:, feature_names] = sc.transform(val.loc[:, feature_names]) +else: + feature_names=None + +print(args.feature_path) +n_descriptors = len(feature_names) +print(metric, n_descriptors, len(feature_names)) + +data_creater = CreateData(gexp=gexp, metric=metric, encoder_type=encoder_type, data_path=args.data_path, feature_names=feature_names) + +train_ds = data_creater.create_data(train) +val_ds = data_creater.create_data(val) +test_ds = data_creater.create_data(test) +train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True, drop_last=True) +val_loader = DataLoader(val_ds, batch_size=bs, shuffle=False, drop_last=False) +test_loader = 
DataLoader(test_ds, batch_size=bs, shuffle=False, drop_last=False) + +model = Model(gnn_features = 65, n_descriptors=n_descriptors, encoder_type=encoder_type).to(device) +batch = next(iter(train_loader)) + + +model.drug_encoder.load_state_dict(modelpt.state_dict(), strict=False) + + +for name, p in model.drug_encoder.named_parameters(): + if 'out' not in name: + p.requires_grad = False + +adam = torch.optim.Adam(model.parameters(), lr = lr ) +optimizer = adam +ckpt_path = 'des_finetune.pt' + +early_stopping = EarlyStopping(patience = n_epochs, verbose=True, chkpoint_name = ckpt_path) +criterion = nn.MSELoss() + + +# train the model +hist = {"train_rmse":[], "val_rmse":[]} +for epoch in range(0, n_epochs): + model.train() + loss_all = 0 + for data in train_loader: + data = data.to(device) + optimizer.zero_grad() + output = model(data) + output = output.reshape(-1,) + + loss = criterion(output, data.y) + loss.backward() + optimizer.step() + + + # train_rmse = gnn_utils.test_fn(train_loader, model, device) + val_rmse, _, _ = test_fn(val_loader, model, device) + early_stopping(val_rmse, model) + + if early_stopping.early_stop: + print("Early stopping") + break + + # hist["train_rmse"].append(train_rmse) + hist["val_rmse"].append(val_rmse) + # print(f'Epoch: {epoch}, Train_rmse: {train_rmse:.3}, Val_rmse: {val_rmse:.3}') + print(f'Epoch: {epoch}, Val_rmse: {val_rmse:.3}') + +# print(f"training completed at {datetime.datetime.now()}") + +model.load_state_dict(torch.load(ckpt_path)) + +test_rmse, true, pred = test_fn(test_loader, model, device) +test['true'] = true +test['pred'] = pred + +if args.feature_path: + test = test[['improve_sample_id', 'smiles', 'improve_chem_id', 'auc', 'true', 'pred']] + +test.to_csv( os.path.join('test_predictions_ft.csv'), index=False ) + + + + + + + + + + + + + + + + +# if __name__ == "__main__": + + diff --git a/model4/des_pretrain.py b/model4/des_pretrain.py new file mode 100644 index 0000000..7d2ae39 --- /dev/null +++ b/model4/des_pretrain.py @@ -0,0 +1,201 @@ +from data_utils import Downloader, DataProcessor +from data_utils import add_smiles +from gnn_utils import create_data_list +from torch_geometric.loader import DataLoader +from model import Model +import torch +import torch.nn as nn +from gnn_utils import EarlyStopping +# dw = Downloader('benchmark-data-imp-2023') +from gnn_utils import test_fn +import os +import argparse +import pandas as pd +from gnn_utils import CreateData +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler + +from torch_geometric.data import Data +from tqdm import tqdm +import torch.nn as nn +import matplotlib.pyplot as plt +from rdkit import Chem +from rdkit.Chem import AllChem +import torch +import torch.nn.functional as F +from rdkit import Chem +from mordred import Calculator, descriptors +from sklearn.preprocessing import StandardScaler +import numpy as np +from sklearn.model_selection import train_test_split +import pickle + +def split_df(df, seed): + + train, test = train_test_split(df, random_state=seed, test_size=0.2) + val, test = train_test_split(test, random_state=seed, test_size=0.5) + + train.reset_index(drop=True, inplace=True) + val.reset_index(drop=True, inplace=True) + test.reset_index(drop=True, inplace=True) + + return train, val, test + + + +# des = pd.read_csv('dataset.csv') +# des +# calc = Calculator(descriptors, ignore_3D=True) +# mols = [Chem.MolFromSmiles(i) for i in des.SMILES] + +df_mdm = pd.read_csv('mdm_somas.csv') +df_mdm = df_mdm.iloc[:, :-2] + 
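+# Flag descriptor columns that cannot be cast to float (non-numeric Mordred outputs)
+# so they are dropped from the feature set below.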
+rem=[] +for i in df_mdm.columns: + try: + df_mdm[i].astype(float) + except: + rem.append(i) + +feature_names = list(set(df_mdm.columns).difference(rem)) +features = pd.read_csv('../drug_features_pilot1.csv') +feature_names = list(set(feature_names).intersection(features.columns[1:])) + + +df_mdm = df_mdm.loc[:, feature_names] +# 'GATS1Z' in df_mdm.columns +feature_names = list(set(df_mdm.columns).difference(['GATS1Z'])) +sc = StandardScaler() +np.where(np.isnan(sc.fit_transform(df_mdm))) +# calc = Calculator(descriptors.Autocorrelation, ignore_3D=True) +# mols = [Chem.MolFromSmiles(i) for i in des.SMILES] + +# df_des = calc.pandas(mols) +# df_des['smiles'] = des['SMILES'] +# df_des['mfrags'] = [len(Chem.GetMolFrags(i)) for i in mols] +# df_des.to_csv('des_autoc.csv', index=False) +# df_des = pd.read_csv('des_autoc.csv') +# df_des.shape + + + +with open('feature_names.pkl', 'wb') as f: + pickle.dump(feature_names, f) + +sc = StandardScaler() +df_mdm.loc[:, feature_names] = sc.fit_transform(df_mdm.loc[:, feature_names]) +df_mdm.shape + + + +train, test = train_test_split(df_mdm, test_size=.2) +val, test = train_test_split(test, test_size=.5) + +target_name = 'GATS1Z' +'GATS1Z' in feature_names +def create_descriptor_data(df): + + df.reset_index(drop=True, inplace=True) + # data = df.copy() + # data.set_index('smiles', inplace=True) + + data_list = [] + for i in tqdm(range(df.shape[0])): + # smiles = data.loc[i, 'smiles'] + y = df.loc[i, target_name] + # improve_sample_id = data.loc[i, 'improve_sample_id'] + # feature_list = self.features.loc[smiles, :].values.tolist() + feature_list = df.loc[i, feature_names].values.tolist() + # ge = self.gexp.loc[improve_sample_id, :].values.tolist() + + data = Data(fp=torch.tensor([feature_list], dtype=torch.float), + y=torch.tensor([y],dtype=torch.float),) + data_list.append(data) + + return data_list + + + +# 'GATS1Z' in data.columns +bs = 64 +lr = 1e-4 +n_epochs = 500 +train_ds = create_descriptor_data(train) +val_ds = create_descriptor_data(val) +test_ds = create_descriptor_data(test) +train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True, drop_last=True) +val_loader = DataLoader(val_ds, batch_size=bs, shuffle=False, drop_last=False) +test_loader = DataLoader(test_ds, batch_size=bs, shuffle=False, drop_last=False) +n_descriptors = len(feature_names) +from torch.nn import Linear +class DescriptorEncoder(nn.Module): + def __init__(self, n_descriptors): + super(DescriptorEncoder, self).__init__() + + self.fc1 = Linear( n_descriptors, 1024) + self.fc2 = Linear( 1024, 512) + self.fc3 = Linear( 512, 256) + self.do1 = nn.Dropout(p = 0.1) + self.do2 = nn.Dropout(p = 0.1) + self.act1 = nn.ReLU() + self.act2 = nn.ReLU() + self.out = Linear(256, 1) + + def forward(self, data): + fp = data.fp + + e = self.do1(self.act1(self.fc1(fp))) + e = self.do2(self.act2(self.fc2(e))) + e = self.fc3(e) + out = self.out(e) + + return out + + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +model = DescriptorEncoder(n_descriptors) +model.to(device); + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +adam = torch.optim.Adam(model.parameters(), lr = lr ) +optimizer = adam + +ckpt_path = 'pretrain.pt' +early_stopping = EarlyStopping(patience = 20, verbose=True, chkpoint_name = ckpt_path) +criterion = nn.MSELoss() + + +hist = {"train_rmse":[], "val_rmse":[]} +for epoch in range(0, n_epochs): + model.train() + loss_all = 0 + for data in train_loader: + data = data.to(device) + optimizer.zero_grad() + output = model(data) + 
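+        # flatten the (batch, 1) model output to (batch,) so it matches data.y for the MSE loss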
output = output.reshape(-1,) + + loss = criterion(output, data.y) + loss.backward() + optimizer.step() + + + # train_rmse = gnn_utils.test_fn(train_loader, model, device) + val_rmse, _, _ = test_fn(val_loader, model, device) + early_stopping(val_rmse, model) + + if early_stopping.early_stop: + print("Early stopping") + break + + # hist["train_rmse"].append(train_rmse) + hist["val_rmse"].append(val_rmse) + # print(f'Epoch: {epoch}, Train_rmse: {train_rmse:.3}, Val_rmse: {val_rmse:.3}') + print(f'Epoch: {epoch}, Val_rmse: {val_rmse:.3}') + +# print(f"training completed at {datetime.datetime.now()}") + +#model.load_state_dict(torch.load(ckpt_path)) + + diff --git a/model4/gnn_utils.py b/model4/gnn_utils.py new file mode 100644 index 0000000..f7bf035 --- /dev/null +++ b/model4/gnn_utils.py @@ -0,0 +1,498 @@ + +from rdkit import Chem +from torch_geometric.data import Data +import matplotlib.pyplot as plt +from rdkit import Chem +import numpy as np +import torch +import pandas as pd +import os +import time +from tqdm import tqdm +from sklearn.metrics import mean_squared_error +from sklearn.metrics import mean_absolute_error +from sklearn.metrics import r2_score +from scipy.stats import spearmanr +import random +import pickle, gzip +from subword_nmt.apply_bpe import BPE +import codecs +from rdkit.Chem import AllChem +# import config + + +def set_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + np.random.seed(seed) + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + +def one_of_k_encoding(x, allowable_set): + if x not in allowable_set: + raise Exception("input {0} not in allowable set{1}:".format(x, allowable_set)) + return list(map(lambda s: x == s, allowable_set)) + +def one_of_k_encoding_unk(x, allowable_set): + """Maps inputs not in the allowable set to the last element.""" + if x not in allowable_set: + x = allowable_set[-1] + return list(map(lambda s: x == s, allowable_set)) + +def get_intervals(l): + """For list of lists, gets the cumulative products of the lengths""" + intervals = len(l) * [0] + # Initalize with 1 + intervals[0] = 1 + for k in range(1, len(l)): + intervals[k] = (len(l[k]) + 1) * intervals[k - 1] + return intervals + +def safe_index(l, e): + """Gets the index of e in l, providing an index of len(l) if not found""" + try: + return l.index(e) + except: + return len(l) + +def best_fit_slope_and_intercept(xs,ys): + m = (((np.mean(xs)*np.mean(ys)) - np.mean(xs*ys)) / + ((np.mean(xs)*np.mean(xs)) - np.mean(xs*xs))) + + b = np.mean(ys) - m*np.mean(xs) + + return m, b + +possible_atom_list = [ + 'C', 'N', 'O', 'S', 'F', 'P', 'Cl', 'Mg', 'Na', 'Br', 'Fe', 'Ca', 'Cu', + 'Mc', 'Pd', 'Pb', 'K', 'I', 'Al', 'Ni', 'Mn' +] +possible_numH_list = [0, 1, 2, 3, 4] +possible_valence_list = [0, 1, 2, 3, 4, 5, 6] +possible_formal_charge_list = [-3, -2, -1, 0, 1, 2, 3] +possible_hybridization_list = [ + Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2, + Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D, + Chem.rdchem.HybridizationType.SP3D2 +] +possible_number_radical_e_list = [0, 1, 2] +possible_chirality_list = ['R', 'S'] + +reference_lists = [ + possible_atom_list, possible_numH_list, possible_valence_list, + possible_formal_charge_list, possible_number_radical_e_list, + possible_hybridization_list, possible_chirality_list +] + +intervals = get_intervals(reference_lists) + + +def get_feature_list(atom): + features = 6 * [0] + 
features[0] = safe_index(possible_atom_list, atom.GetSymbol()) + features[1] = safe_index(possible_numH_list, atom.GetTotalNumHs()) + features[2] = safe_index(possible_valence_list, atom.GetImplicitValence()) + features[3] = safe_index(possible_formal_charge_list, atom.GetFormalCharge()) + features[4] = safe_index(possible_number_radical_e_list, + atom.GetNumRadicalElectrons()) + features[5] = safe_index(possible_hybridization_list, atom.GetHybridization()) + return features + +def features_to_id(features, intervals): + """Convert list of features into index using spacings provided in intervals""" + id = 0 + for k in range(len(intervals)): + id += features[k] * intervals[k] + + # Allow 0 index to correspond to null molecule 1 + id = id + 1 + return id + +def id_to_features(id, intervals): + features = 6 * [0] + + # Correct for null + id -= 1 + + for k in range(0, 6 - 1): + # print(6-k-1, id) + features[6 - k - 1] = id // intervals[6 - k - 1] + id -= features[6 - k - 1] * intervals[6 - k - 1] + # Correct for last one + features[0] = id + return features + +def atom_to_id(atom): + """Return a unique id corresponding to the atom type""" + features = get_feature_list(atom) + return features_to_id(features, intervals) + +def atom_features(atom, bool_id_feat=False, explicit_H=False, use_chirality=False): + if bool_id_feat: + return np.array([atom_to_id(atom)]) + else: + from rdkit import Chem + + results = one_of_k_encoding_unk(atom.GetSymbol(),['Ag','Al','As','B','Br','C','Ca','Cd','Cl','Cu','F', + 'Fe','Ge','H','Hg','I','K','Li','Mg','Mn','N','Na', + 'O','P','Pb','Pt','S','Se','Si','Sn','Sr','Tl','Zn', + 'Unknown'])\ + + one_of_k_encoding(atom.GetDegree(),[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + \ + one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6]) + \ + [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()] + \ + one_of_k_encoding_unk(atom.GetHybridization(), [ + Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2, + Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType. 
+ SP3D, Chem.rdchem.HybridizationType.SP3D2 + ]) + [atom.GetIsAromatic()] + + if not explicit_H: + results = results + one_of_k_encoding_unk(atom.GetTotalNumHs(), + [0, 1, 2, 3, 4]) + if use_chirality: + try: + results = results + one_of_k_encoding_unk(atom.GetProp('_CIPCode'), + ['R', 'S']) + [atom.HasProp('_ChiralityPossible')] + except: + results = results + [False, False] + [atom.HasProp('_ChiralityPossible')] + + return np.array(results) + +def bond_features(bond, use_chirality=False): + from rdkit import Chem + bt = bond.GetBondType() + bond_feats = [ + bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, + bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, + bond.GetIsConjugated(), + bond.IsInRing()] + + if use_chirality: + bond_feats = bond_feats + one_of_k_encoding_unk(str(bond.GetStereo()), + ["STEREONONE", "STEREOANY", "STEREOZ", "STEREOE"]) + return np.array(bond_feats) + + +def get_bond_pair(mol): + bonds = mol.GetBonds() + res = [[],[]] + for bond in bonds: + res[0] += [bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()] + res[1] += [bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()] + return res + + + +class SmilesTokenizer: + def __init__(self, vocab_path=None, + subword_path=None): + + # vocab_path = "./DeepTTC/ESPF/drug_codes_chembl_freq_1500.txt" + # sub_csv = pd.read_csv("./DeepTTC/ESPF/subword_units_map_chembl_freq_1500.csv") + + vocab_path = vocab_path + sub_csv = pd.read_csv(subword_path) + + + + bpe_codes_drug = codecs.open(vocab_path) + self.dbpe = BPE(bpe_codes_drug, merges=-1, separator='') + + idx2word_d = sub_csv['index'].values + self.words2idx_d = dict(zip(idx2word_d, range(0, len(idx2word_d)))) + + + self.max_d = 50 + + + def tokenize(self, smile): + + t1 = self.dbpe.process_line(smile).split() # split + try: + i1 = np.asarray( [ self.words2idx_d[i] for i in t1] ) # index + except: + i1 = np.array([0]) + + l = len(i1) + if l < self.max_d: + i = np.pad(i1, (0, self.max_d - l), 'constant', constant_values=0) + input_mask = ([1] * l) + ([0] * (self.max_d - l)) + else: + i = i1[:self.max_d] + input_mask = [1] * self.max_d + + return i, np.asarray(input_mask) + + +def GNNData(smiles, y, ge): + mol = Chem.MolFromSmiles(smiles) + atoms = mol.GetAtoms() + bonds = mol.GetBonds() + node_f= [atom_features(atom) for atom in atoms] + edge_index = get_bond_pair(mol) + + edge_attr=[] + for bond in bonds: + edge_attr.append(bond_features(bond, use_chirality=False)) + edge_attr.append(bond_features(bond, use_chirality=False)) + + data = Data(x=torch.tensor(node_f, dtype=torch.float), + edge_index=torch.tensor(edge_index, dtype=torch.long), + edge_attr=torch.tensor(edge_attr,dtype=torch.float), + y=torch.tensor([y],dtype=torch.float), + ge = torch.tensor([ge],dtype=torch.float), + + ) + return data + + +def TransformerData(tokens, y, ge): + + data = Data(tokens=torch.tensor([tokens], dtype=torch.long), + y=torch.tensor([y],dtype=torch.float), + ge = torch.tensor([ge],dtype=torch.float),) + return data + +def MorganFPData(fp, y, ge): + + data = Data(fp=torch.tensor([fp], dtype=torch.float), + y=torch.tensor([y],dtype=torch.float), + ge = torch.tensor([ge],dtype=torch.float),) + return data + +def DescriptorData(fp, y, ge): + + data = Data(fp=torch.tensor([fp], dtype=torch.float), + y=torch.tensor([y],dtype=torch.float), + ge = torch.tensor([ge],dtype=torch.float),) + return data + +def create_data_list(data, gexp, metric='ic50'): + data_list = [] + for i in tqdm(range(data.shape[0])): + smiles = data.loc[i, 'smiles'] + y = data.loc[i, metric] + 
improve_sample_id = data.loc[i, 'improve_sample_id'] + ge = gexp.loc[improve_sample_id, :].values.tolist() + data_list.append(GNNData(smiles=smiles, y=y, ge=ge)) + return data_list + + + + +class CreateData: + def __init__(self, gexp, metric='ic50', encoder_type='gnn', data_path=None, feature_names=None): + + # vocab_path = "./DeepTTC/ESPF/drug_codes_chembl_freq_1500.txt" + # sub_csv = pd.read_csv("./DeepTTC/ESPF/subword_units_map_chembl_freq_1500.csv") + self.tokenizer = SmilesTokenizer(vocab_path = os.path.join(data_path, 'drug_codes_chembl_freq_1500.txt'), + subword_path = os.path.join(data_path, 'subword_units_map_chembl_freq_1500.csv') + ) + self.metric = metric + self.gexp = gexp + self.encoder_type=encoder_type + self.feature_names = feature_names + + # if isinstance(features, pd.DataFrame): + # self.features = features + # self.features.set_index('smiles', inplace=True) + + def create_data(self, data): + if self.encoder_type=='gnn': + print('creating gnm data') + return self.create_gnn_data(data) + elif self.encoder_type=='transformer': + print('creating transformer data') + return self.create_transformer_data(data) + elif self.encoder_type=='morganfp': + print('creating morganfp data') + return self.create_morganfp_data(data) + elif self.encoder_type=='descriptor': + print('creating descriptor data') + return self.create_descriptor_data(data) + + def create_gnn_data(self, data): + data_list = [] + for i in tqdm(range(data.shape[0])): + smiles = data.loc[i, 'smiles'] + y = data.loc[i, self.metric] + improve_sample_id = data.loc[i, 'improve_sample_id'] + ge = self.gexp.loc[improve_sample_id, :].values.tolist() + data_list.append(GNNData(smiles=smiles, y=y, ge=ge)) + return data_list + + + def create_transformer_data(self, data): + data_list = [] + for i in tqdm(range(data.shape[0])): + smiles = data.loc[i, 'smiles'] + y = data.loc[i, self.metric] + improve_sample_id = data.loc[i, 'improve_sample_id'] + ge = self.gexp.loc[improve_sample_id, :].values.tolist() + tokens, _ = self.tokenizer.tokenize(smiles) + data_list.append(TransformerData(tokens=tokens, y=y, ge=ge)) + + return data_list + + def create_morganfp_data(self, data): + data_list = [] + for i in tqdm(range(data.shape[0])): + smiles = data.loc[i, 'smiles'] + y = data.loc[i, self.metric] + improve_sample_id = data.loc[i, 'improve_sample_id'] + ge = self.gexp.loc[improve_sample_id, :].values.tolist() + + mol = Chem.MolFromSmiles(smiles) + fp = np.array(AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=1024)) + + data_list.append(MorganFPData(fp=fp, y=y, ge=ge)) + + return data_list + + def create_descriptor_data(self, data): + data_list = [] + for i in tqdm(range(data.shape[0])): + smiles = data.loc[i, 'smiles'] + y = data.loc[i, self.metric] + improve_sample_id = data.loc[i, 'improve_sample_id'] + # feature_list = self.features.loc[smiles, :].values.tolist() + feature_list = data.loc[i, self.feature_names].values.tolist() + ge = self.gexp.loc[improve_sample_id, :].values.tolist() + + data_list.append(DescriptorData(fp=feature_list, y=y, ge=ge)) + + return data_list + + + + + + +def create_data(): + + train = pd.read_csv(config.data_dir+"train.csv") + val = pd.read_csv(config.data_dir+"val.csv") + test = pd.read_csv(config.data_dir+"test.csv") + + train.reset_index(drop=True, inplace=True) + val.reset_index(drop=True, inplace=True) + test.reset_index(drop=True, inplace=True) + + print("checking for duplicates") + if len(list(set(train.smiles.values).intersection(set(test.smiles.values)) )) == 0: + print("no duplicates 
in train and test") + + if len(list(set(train.smiles.values).intersection(set(val.smiles.values)) )) == 0: + print("no duplicates in train and valid") + + if len(list( set(test.smiles.values).intersection(set(val.smiles.values)) )) == 0: + print("no duplicates in test and valid") + print(" ") + + print(f"train set size = {train.shape}, unique smiles in the train set = {len(set(train.smiles.values))}") + print(f"train set size = {val.shape}, unique smiles in the valid set = {len(set(val.smiles.values))}") + print(f"train set size = {test.shape}, unique smiles in the test set = {len(set(test.smiles.values))}") + print(" ") + + print("creating train data") + train_X = create_data_list(train) + print("creating valid data") + val_X = create_data_list(val) + print("creating test data") + test_X = create_data_list(test) + + with gzip.open(config.gnn_data_dir+"train.pkl.gz", "wb") as f: + pickle.dump(train_X, f, protocol=4) + with gzip.open(config.gnn_data_dir+"val.pkl.gz", "wb") as f: + pickle.dump(val_X, f, protocol=4) + with gzip.open(config.gnn_data_dir+"test.pkl.gz", "wb") as f: + pickle.dump(test_X, f, protocol=4) + + + + +class EarlyStopping: + """Early stops the training if validation loss doesn't improve after a given patience.""" + def __init__(self, patience=7, verbose=False, delta=0, chkpoint_name = 'gnn_best.pt' ): + """ + Args: + patience (int): How long to wait after last time validation loss improved. + Default: 7 + verbose (bool): If True, prints a message for each validation loss improvement. + Default: False + delta (float): Minimum change in the monitored quantity to qualify as an improvement. + Default: 0 + """ + self.patience = patience + self.verbose = verbose + self.counter = 0 + self.best_score = None + self.early_stop = False + self.val_loss_min = np.Inf + self.delta = delta + self.chkpoint_name = chkpoint_name + + def __call__(self, val_loss, model): + + score = -val_loss + + if self.best_score is None: + self.best_score = score + self.save_checkpoint(val_loss, model) + elif score < self.best_score + self.delta: + self.counter += 1 + print(f'EarlyStopping counter: {self.counter} out of {self.patience}') + if self.counter >= self.patience: + self.early_stop = True + else: + self.best_score = score + self.save_checkpoint(val_loss, model) + self.counter = 0 + + def save_checkpoint(self, val_loss, model): + '''Saves model when validation loss decrease.''' + if self.verbose: + print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). 
Saving model ...') + torch.save(model.state_dict(), self.chkpoint_name) + self.val_loss_min = val_loss + + +def test_fn(loader, model, device): + model.eval() + with torch.no_grad(): + target, predicted = [], [] + for data in loader: + data = data.to(device) + output = model(data) + pred = output + + target += list(data.y.cpu().numpy().ravel() ) + predicted += list(pred.cpu().numpy().ravel() ) + + return mean_squared_error(y_true=target, y_pred=predicted), target, predicted + + + +def get_results(db_name, loader, model, device): + + print(f"{db_name} results") + test_t, test_p = test_fn_plotting(loader, model, device) + + r2 = r2_score(y_pred = test_p, y_true = test_t) + rmse = mean_squared_error(y_pred = test_p, y_true = test_t)**.5 + sp = spearmanr(test_p, test_t)[0] + mae = mean_absolute_error(y_pred=test_p, y_true=test_t) + + print("r2: {0:.4f}".format(r2) ) + print("rmse: {0:.4f}".format(rmse) ) + print("sp: {0:.4f}".format(sp) ) + print("mae: {0:.4f}".format(mae) ) + + plt.figure() + plt.plot( test_t, test_p, 'o') + plt.xlabel("True (logS)", fontsize=15, fontweight='bold'); + plt.ylabel("Predicted (logS)", fontsize=15, fontweight='bold'); + plt.show() + diff --git a/model4/model.py b/model4/model.py new file mode 100644 index 0000000..573fa22 --- /dev/null +++ b/model4/model.py @@ -0,0 +1,275 @@ +import numpy as np +from rdkit import Chem +from torch_geometric.data import Data +from sklearn.metrics import mean_squared_error +import torch.nn as nn +import matplotlib.pyplot as plt +from rdkit import Chem +from rdkit.Chem import AllChem +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset +from rdkit.Chem import Draw +from sklearn.metrics import r2_score +from scipy.stats import spearmanr +import pandas as pd +from random import randrange +import itertools +import random +import os +from pickle import dump, load +from sklearn.metrics import mean_absolute_error +import pickle +import gzip, pickle +# from torch_geometric.data import DataLoader +import gnn_utils +# import gnn_model +# from gnn_model import GNN +# import config +import datetime + +import torch +import torch.nn as nn +from torch_geometric.nn import GCNConv +from torch_geometric.nn import global_add_pool, global_mean_pool +from torch_geometric.nn import EdgeConv +from torch.nn import Linear + +params = {'a1': 0, 'a2': 2, 'a3': 1, 'a4': 2, 'bs': 1, 'd1': 0.015105134306121593, 'd2': 0.3431295462686682, \ + 'd3': 0.602688496976768, 'd4': 0.9532038077650021, 'e1': 256.0, 'eact1': 0, 'edo1': 0.4813038851902818,\ + 'f1': 256.0, 'f2': 256.0, 'f3': 160.0, 'f4': 24.0, 'g1': 256.0, 'g2': 320.0, 'g21': 448.0,\ + 'g22': 512.0, 'gact1': 2, 'gact2': 2, 'gact21': 2, 'gact22': 0, 'gact31': 2, 'gact32': 1, 'gact33': 1,\ + 'gdo1': 0.9444250299450242, 'gdo2': 0.8341272742321129, 'gdo21': 0.7675340644596443,\ + 'gdo22': 0.21498171859119775, 'gdo31': 0.8236003195596049, 'gdo32': 0.6040220843354102,\ + 'gdo33': 0.21007469160431758, 'lr': 0, 'nfc': 0, 'ngl': 1, 'opt': 0} +act = {0: torch.nn.ReLU(), 1:torch.nn.SELU(), 2:torch.nn.Sigmoid()} + + + +from transformer import TransformerModel + +class Model(torch.nn.Module): + + def __init__(self, gnn_features, n_descriptors=None, encoder_type=None, n_genes=958): + super(Model, self).__init__() + + self.encoder_type= encoder_type + self.gnn_features = gnn_features + + if encoder_type=='gnn': + self.drug_encoder = GNNEncoder(gnn_features) + + elif encoder_type=='transformer': + args = {'vocab_size':2586, + 'masked_token_train': False, + 'finetune': False} + self.drug_encoder 
= TransformerModel(args) + elif encoder_type=='morganfp': + self.drug_encoder = MorganFPEncoder() + elif encoder_type=='descriptor': + self.drug_encoder = DescriptorEncoder(n_descriptors) + + # self.out2 = Linear(int(params['f2']), 1) + # self.out3 = Linear(int(params['f3']), 1) + # self.out4 = Linear(int(params['f4']), 1) + + self.dropout1 = nn.Dropout(p = params['d1'] ) + self.act1 = act[params['a1']] + + self.dropout2 = nn.Dropout(p = params['d2'] ) + self.act2 = act[params['a2']] + + self.transformer_lin = Linear(512,256) + + self.gexp_lin1 = Linear(n_genes, 958) + self.gexp_lin2 = Linear(958, 256) + + self.cat1 = Linear(512, 256) + self.cat2 = Linear(256, 128) + self.out = Linear(128, 1) + + + + def forward(self, data): + # node_x, edge_x, edge_index = data.x, data.edge_attr, data.edge_index + + if self.encoder_type in ['gnn', 'morganfp', 'descriptor']: + drug = self.drug_encoder(data) + + elif self.encoder_type == 'transformer': + _,_, drug = self.drug_encoder(data) + drug = self.transformer_lin(drug) + + gexp = data.ge + gexp = self.gexp_lin1(gexp) + gexp = self.gexp_lin2(gexp) + + drug_gene = torch.cat((drug, gexp), 1) + + + x3 = self.dropout1(self.act1(self.cat1( drug_gene ))) + x3 = self.dropout2(self.act2(self.cat2( x3 ))) + x3 = self.out(x3) + return x3 + + +class MorganFPEncoder(nn.Module): + def __init__(self,): + super(MorganFPEncoder, self).__init__() + + self.fc1 = Linear( 1024, 1024) + self.fc2 = Linear( 1024, 512) + self.fc3 = Linear( 512, 256) + self.do1 = nn.Dropout(p = 0.1) + self.do2 = nn.Dropout(p = 0.1) + self.act1 = nn.ReLU() + self.act2 = nn.ReLU() + + def forward(self, data): + fp = data.fp + + e = self.do1(self.act1(self.fc1(fp))) + e = self.do2(self.act2(self.fc2(e))) + e = self.fc3(e) + + return e + +class DescriptorEncoder(nn.Module): + def __init__(self, n_descriptors): + super(DescriptorEncoder, self).__init__() + + self.fc1 = Linear( n_descriptors, 1024) + self.fc2 = Linear( 1024, 512) + self.fc3 = Linear( 512, 256) + self.do1 = nn.Dropout(p = 0.1) + self.do2 = nn.Dropout(p = 0.1) + self.act1 = nn.ReLU() + self.act2 = nn.ReLU() + + def forward(self, data): + fp = data.fp + + e = self.do1(self.act1(self.fc1(fp))) + e = self.do2(self.act2(self.fc2(e))) + e = self.fc3(e) + + return e + +class GNNEncoder(torch.nn.Module): + + def __init__(self, n_features): + super(GNNEncoder, self).__init__() + self.n_features = n_features + self.gcn1 = GCNConv(self.n_features, int(params['g1']), cached=False) + self.gcn2 = GCNConv( int(params['g1']), int(params['g2']), cached=False) + self.gcn21 = GCNConv( int(params['g2']), int(params['g21']), cached=False) + self.gcn22 = GCNConv( int(params['g21']), int(params['g22']), cached=False) + + self.gcn31 = GCNConv(int(params['g2']), int(params['e1']), cached=False) + self.gcn32 = GCNConv(int(params['g21']), int(params['e1']), cached=False) + self.gcn33 = GCNConv(int(params['g22']), int(params['e1']), cached=False) + + self.gdo1 = nn.Dropout(p = params['gdo1'] ) + self.gdo2 = nn.Dropout(p = params['gdo2'] ) + self.gdo31 = nn.Dropout(p = params['gdo31'] ) + self.gdo21 = nn.Dropout(p = params['gdo21'] ) + self.gdo32 = nn.Dropout(p = params['gdo32'] ) + self.gdo22 = nn.Dropout(p = params['gdo22'] ) + self.gdo33 = nn.Dropout(p = params['gdo33'] ) + + self.gact1 = act[params['gact1'] ] + self.gact2 = act[params['gact2'] ] + self.gact31 = act[params['gact31']] + self.gact21 = act[params['gact21'] ] + self.gact32 = act[params['gact32'] ] + self.gact22 = act[params['gact22'] ] + self.gact33 = act[params['gact33'] ] + + self.ecn1 = 
EdgeConv(nn = nn.Sequential(nn.Linear(n_features*2, int(params['e1']) ), + nn.ReLU(), + nn.Linear( int(params['e1']) , int(params['f1']) ),)) + + self.edo1 = nn.Dropout(p = params['edo1'] ) + self.eact1 = act[params['eact1'] ] + + + self.fc1 = Linear( int(params['e1'])+ int(params['f1']), int(params['f1'])) + self.dropout1 = nn.Dropout(p = params['d1'] ) + self.act1 = act[params['a1']] + + self.fc2 = Linear(int(params['f1']), int(params['f2'])) + self.dropout2 = nn.Dropout(p = params['d2'] ) + self.act2 = act[params['a2']] + + self.fc3 = Linear(int(params['f2']), int(params['f3'])) + self.dropout3 = nn.Dropout(p = params['d3'] ) + self.act3 = act[params['a3']] + + self.fc4 = Linear(int(params['f3']), int(params['f4'])) + self.dropout4 = nn.Dropout(p = params['d4'] ) + self.act4 = act[params['a4']] + + + + + def forward(self, data): + node_x, edge_x, edge_index = data.x, data.edge_attr, data.edge_index + + + + x1 = self.gdo1(self.gact1( self.gcn1( node_x, edge_index ) ) ) + x1 = self.gdo2(self.gact2(self.gcn2(x1, edge_index)) ) + x1 = self.gdo21(self.gact21(self.gcn21(x1, edge_index)) ) + x1 = self.gdo32(self.gact32(self.gcn32(x1, edge_index)) ) + + x2 = self.edo1(self.eact1(self.ecn1(node_x, edge_index)) ) + x3 = torch.cat((x1,x2), 1) + x3 = global_add_pool(x3, data.batch) + + + x3 = self.act1(self.fc1( x3 )) # 256 + + return x3 + + + +# class Model(torch.nn.Module): + +# def __init__(self, gnn_features): +# super(Model, self).__init__() + +# self.gnn_features = gnn_features +# self.drug_encoder = GNNEncoder(gnn_features) + +# self.dropout1 = nn.Dropout(p = params['d1'] ) +# self.act1 = act[params['a1']] + +# self.dropout2 = nn.Dropout(p = params['d2'] ) +# self.act2 = act[params['a2']] + + +# self.gexp_lin1 = Linear(958, 958) +# self.gexp_lin2 = Linear(958, 256) + +# self.cat1 = Linear(512, 256) +# self.cat2 = Linear(256, 128) +# self.out = Linear(128, 1) + + + +# def forward(self, data): +# # node_x, edge_x, edge_index = data.x, data.edge_attr, data.edge_index + +# drug = self.drug_encoder(data) + +# gexp = data.ge +# gexp = self.gexp_lin1(gexp) +# gexp = self.gexp_lin2(gexp) + +# drug_gene = torch.cat((drug, gexp), 1) + + +# x3 = self.dropout1(self.act1(self.cat1( drug_gene ))) +# x3 = self.dropout2(self.act2(self.cat2( x3 ))) +# x3 = self.out(x3) +# return x3 diff --git a/model4/run.sh b/model4/run.sh new file mode 100644 index 0000000..bedf879 --- /dev/null +++ b/model4/run.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +#python run_models.py --config config_ctrpv2.yaml +python run_models.py --config configs/config_ccle.yaml diff --git a/model4/run_des.py b/model4/run_des.py new file mode 100644 index 0000000..401152f --- /dev/null +++ b/model4/run_des.py @@ -0,0 +1,35 @@ +import os +import pickle +import argparse +from omegaconf import OmegaConf + + + +if __name__ == "__main__": + + + parser = argparse.ArgumentParser() + parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='config.yaml') + args = parser.parse_args() + + if args.config: # args priority is higher than yaml + args_ = OmegaConf.load(args.config) + OmegaConf.resolve(args_) + args=args_ + + + os.environ["CUDA_VISIBLE_DEVICES"] = str(3) + + epochs = args.epochs + data_type = args.data_type + metric = args.metric + out_dir = args.out_dir + + # for i, seed in enumerate(seeds): + for i in range(10): + + + os.system(f"python train.py --encoder_type descriptor --out_dir Descriptorenc/{out_dir} --data_version benchmark-data-pilot1 \ + --data_path Descriptorenc/{out_dir}/Data --data_split_seed -10 
--data_split_id {i} --metric {metric} --data_type {data_type} --epochs {epochs} --run_id {i} \ + --feature_path ../drug_features_pilot1.csv") + # break diff --git a/model4/run_gnn.py b/model4/run_gnn.py new file mode 100644 index 0000000..d698dda --- /dev/null +++ b/model4/run_gnn.py @@ -0,0 +1,34 @@ +import os +import pickle +import argparse +from omegaconf import OmegaConf + + + +if __name__ == "__main__": + + + parser = argparse.ArgumentParser() + parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='config.yaml') + args = parser.parse_args() + + if args.config: # args priority is higher than yaml + args_ = OmegaConf.load(args.config) + OmegaConf.resolve(args_) + args=args_ + + + os.environ["CUDA_VISIBLE_DEVICES"] = str(3) + + epochs = args.epochs + data_type = args.data_type + metric = args.metric + out_dir = args.out_dir + + # for i, seed in enumerate(seeds): + for i in range(10): + + + os.system(f"python train.py --batch_size 64 --encoder_type gnn --out_dir GNNenc/{out_dir} --data_version benchmark-data-pilot1 \ + --data_path GNNenc/{out_dir}/Data --data_split_seed -10 --data_split_id {i} --metric {metric} --data_type {data_type} --epochs {epochs} --run_id {i}") + diff --git a/model4/run_models.py b/model4/run_models.py new file mode 100644 index 0000000..614d91c --- /dev/null +++ b/model4/run_models.py @@ -0,0 +1,44 @@ +import os +import pickle +import argparse +from omegaconf import OmegaConf + + + +if __name__ == "__main__": + + + parser = argparse.ArgumentParser() + parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='config.yaml') + args = parser.parse_args() + + if args.config: # args priority is higher than yaml + args_ = OmegaConf.load(args.config) + OmegaConf.resolve(args_) + args=args_ + + + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda_id) + + epochs = args.epochs + data_type = args.data_type + metric = args.metric + out_dir = args.out_dir + + # for i, seed in enumerate(seeds): + for i in range(10): + + + os.system(f"python train.py --encoder_type gnn --out_dir GNNenc/{out_dir} --data_version benchmark-data-pilot1 \ + --data_path GNNenc/{out_dir}/Data --data_split_seed -10 --data_split_id {i} --metric {metric} --data_type {data_type} --epochs {epochs} --run_id {i}") + + os.system(f"python train.py --encoder_type transformer --out_dir Trnsfenc/{out_dir} --data_version benchmark-data-pilot1 \ + --data_path Trnsfenc/{out_dir}/Data --data_split_seed -10 --data_split_id {i} --metric {metric} --data_type {data_type} --epochs {epochs} --run_id {i}") + + os.system(f"python train.py --encoder_type morganfp --out_dir Morganfpenc/{out_dir} --data_version benchmark-data-pilot1 \ + --data_path Morganfpenc/{out_dir}/Data --data_split_seed -10 --data_split_id {i} --metric {metric} --data_type {data_type} --epochs {epochs} --run_id {i}") + + os.system(f"python train.py --encoder_type descriptor --out_dir Descriptorenc/{out_dir} --data_version benchmark-data-pilot1 \ + --data_path Descriptorenc/{out_dir}/Data --data_split_seed -10 --data_split_id {i} --metric {metric} --data_type {data_type} --epochs {epochs} --run_id {i} \ + --feature_path ../drug_features_pilot1.csv") + # break diff --git a/model4/run_models_ctrpv2.py b/model4/run_models_ctrpv2.py new file mode 100644 index 0000000..f4b7a4e --- /dev/null +++ b/model4/run_models_ctrpv2.py @@ -0,0 +1,37 @@ +import os +import pickle +import argparse +from omegaconf import OmegaConf + + + +if __name__ == "__main__": + + + parser = 
argparse.ArgumentParser() + parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='config.yaml') + args = parser.parse_args() + + if args.config: # args priority is higher than yaml + args_ = OmegaConf.load(args.config) + OmegaConf.resolve(args_) + args=args_ + + + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda_id) + + epochs = args.epochs + data_type = args.data_type + metric = args.metric + out_dir = args.out_dir + + # for i, seed in enumerate(seeds): + for i in range(10): + + + os.system(f"python train.py --encoder_type gnn --out_dir GNNenc/{out_dir} --data_version benchmark-data-pilot1 \ + --data_path GNNenc/{out_dir}/Data --data_split_seed -10 --data_split_id {i} --metric {metric} --data_type {data_type} --epochs {epochs} --run_id {i}") + + #os.system(f"python train.py --encoder_type morganfp --out_dir Morganfpenc/{out_dir} --data_version benchmark-data-pilot1 \ + #--data_path Morganfpenc/{out_dir}/Data --data_split_seed -10 --data_split_id {i} --metric {metric} --data_type {data_type} --epochs {epochs} --run_id {i}") + diff --git a/model4/run_mrgn.py b/model4/run_mrgn.py new file mode 100644 index 0000000..e994eb2 --- /dev/null +++ b/model4/run_mrgn.py @@ -0,0 +1,35 @@ +import os +import pickle +import argparse +from omegaconf import OmegaConf + + + +if __name__ == "__main__": + + + parser = argparse.ArgumentParser() + parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='config.yaml') + args = parser.parse_args() + + if args.config: # args priority is higher than yaml + args_ = OmegaConf.load(args.config) + OmegaConf.resolve(args_) + args=args_ + + + os.environ["CUDA_VISIBLE_DEVICES"] = str(5) + + epochs = args.epochs + data_type = args.data_type + metric = args.metric + out_dir = args.out_dir + + # for i, seed in enumerate(seeds): + for i in range(10): + + + + os.system(f"python train.py --encoder_type morganfp --out_dir Morganfpenc/{out_dir} --data_version benchmark-data-pilot1 \ + --data_path Morganfpenc/{out_dir}/Data --data_split_seed -10 --data_split_id {i} --metric {metric} --data_type {data_type} --epochs {epochs} --run_id {i}") + diff --git a/model4/run_prot.py b/model4/run_prot.py new file mode 100644 index 0000000..023696a --- /dev/null +++ b/model4/run_prot.py @@ -0,0 +1,51 @@ +import os +import pickle +import argparse +from omegaconf import OmegaConf + + + +if __name__ == "__main__": + + + parser = argparse.ArgumentParser() + parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='config.yaml') + args = parser.parse_args() + + if args.config: # args priority is higher than yaml + args_ = OmegaConf.load(args.config) + OmegaConf.resolve(args_) + args=args_ + + + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda_id) + + epochs = args.epochs + data_type = args.data_type + metric = args.metric + out_dir = args.out_dir + use_proteomics_data = args.use_proteomics_data + + + + # for i, seed in enumerate(seeds): + for i in range(10): + + + #os.system(f"python train_prot.py --encoder_type gnn --out_dir GNNenc/{out_dir} --data_version benchmark-data-pilot1 \ + #--data_path GNNenc/{out_dir}/Data --data_split_seed -10 --data_split_id {i} --metric {metric} --data_type {data_type} --epochs {epochs} --run_id {i} \ + # --use_proteomics_data {use_proteomics_data}") + + os.system(f"python train_prot.py --encoder_type transformer --out_dir Trnsfenc/{out_dir} --data_version benchmark-data-pilot1 \ + --data_path Trnsfenc/{out_dir}/Data --data_split_seed 
-10 --data_split_id {i} --metric {metric} --data_type {data_type} --epochs {epochs} --run_id {i} \ + --use_proteomics_data {use_proteomics_data}") + + #os.system(f"python train_prot.py --encoder_type morganfp --out_dir Morganfpenc/{out_dir} --data_version benchmark-data-pilot1 \ + #--data_path Morganfpenc/{out_dir}/Data --data_split_seed -10 --data_split_id {i} --metric {metric} --data_type {data_type} --epochs {epochs} --run_id {i} \ + #--use_proteomics_data {use_proteomics_data}") + + #os.system(f"python train_prot.py --encoder_type descriptor --out_dir Descriptorenc/{out_dir} --data_version benchmark-data-pilot1 \ + #--data_path Descriptorenc/{out_dir}/Data --data_split_seed -10 --data_split_id {i} --metric {metric} --data_type {data_type} --epochs {epochs} --run_id {i} \ + #--feature_path ../drug_features_pilot1.csv \ + #--use_proteomics_data {use_proteomics_data}") + break diff --git a/model4/run_prot.sh b/model4/run_prot.sh new file mode 100644 index 0000000..c82f41a --- /dev/null +++ b/model4/run_prot.sh @@ -0,0 +1,7 @@ +#!/bin/bash +#python run_prot.py --config config_ccle_prot.yaml +python run_prot.py --config config_ccle_gexp.yaml + +# python run_prot.py --config config_ctrpv2_prot.yaml +#python run_prot.py --config config_ctrpv2_gexp.yaml + diff --git a/model4/run_trnsf.py b/model4/run_trnsf.py new file mode 100644 index 0000000..f5d130a --- /dev/null +++ b/model4/run_trnsf.py @@ -0,0 +1,33 @@ +import os +import pickle +import argparse +from omegaconf import OmegaConf + + +if __name__ == "__main__": + + + parser = argparse.ArgumentParser() + parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='config.yaml') + args = parser.parse_args() + + if args.config: # args priority is higher than yaml + args_ = OmegaConf.load(args.config) + OmegaConf.resolve(args_) + args=args_ + + + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + + epochs = args.epochs + data_type = args.data_type + metric = args.metric + out_dir = args.out_dir + + # for i, seed in enumerate(seeds): + for i in range(0,10): + + os.system(f"python train.py --encoder_type transformer --out_dir Trnsfenc/{out_dir} --data_version benchmark-data-pilot1 \ + --data_path Trnsfenc/{out_dir}/Data --data_split_seed -10 --data_split_id {i} --metric {metric} --data_type {data_type} --epochs {epochs} --run_id {i}") + #break + diff --git a/model4/train.py b/model4/train.py new file mode 100644 index 0000000..ed647d4 --- /dev/null +++ b/model4/train.py @@ -0,0 +1,203 @@ +from data_utils import Downloader, DataProcessor +from data_utils import add_smiles +from gnn_utils import create_data_list +from torch_geometric.loader import DataLoader +from model import Model +import torch +import torch.nn as nn +from gnn_utils import EarlyStopping +# dw = Downloader('benchmark-data-imp-2023') +from gnn_utils import test_fn +import os +import argparse +import pandas as pd +from gnn_utils import CreateData +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler + + +def split_df(df, seed): + + train, test = train_test_split(df, random_state=seed, test_size=0.2) + val, test = train_test_split(test, random_state=seed, test_size=0.5) + + train.reset_index(drop=True, inplace=True) + val.reset_index(drop=True, inplace=True) + test.reset_index(drop=True, inplace=True) + + return train, val, test + +if __name__ == "__main__": + + + parser = argparse.ArgumentParser(prog='ProgramName', description='What the program does') + parser.add_argument('--metric', default='auc', 
help='') + parser.add_argument('--run_id', default='0', help='') + parser.add_argument('--epochs', default=1, help='') + parser.add_argument('--batch_size', type=int, default=64, help='') + parser.add_argument('--data_split_seed', default=1, help='') + parser.add_argument('--data_split_id', default=0, help='') + parser.add_argument('--encoder_type', default='gnn', help='') + parser.add_argument('--data_path', default='', help='') + parser.add_argument('--data_type', default='CCLE', help='') + parser.add_argument('--data_version', default='benchmark-data-pilot1', help='benchmark-data-imp-2023 or benchmark-data-pilot1') + parser.add_argument('--out_dir', default='', help='') + parser.add_argument('--feature_path', type=str, default=None, help='') + parser.add_argument('--scale_gexp', type=bool, default=False, help='') + + args = parser.parse_args() + + + pc = DataProcessor(args.data_version) + + metric = args.metric + bs = args.batch_size + lr = 1e-4 + n_epochs = int(args.epochs) + out_dir = args.out_dir + run_id = args.run_id + data_split_seed = int(args.data_split_seed) + encoder_type = args.encoder_type + data_type = args.data_type + data_split_id = args.data_split_id + + out_dir = os.path.join(out_dir, run_id ) + os.makedirs( out_dir, exist_ok=True ) + os.makedirs( args.data_path, exist_ok=True ) + ckpt_path = os.path.join(out_dir, 'best.pt') + + dw = Downloader(args.data_version) + dw.download_candle_data(data_type=data_type, split_id=data_split_id, data_dest=args.data_path) + dw.download_deepttc_vocabs(data_dest=args.data_path) + + + + train = pc.load_drug_response_data(data_path=args.data_path, data_type=data_type, + split_id=data_split_id, split_type='train', response_type=metric, sep="\t", + dropna=True) + + val = pc.load_drug_response_data(data_path=args.data_path, data_type=data_type, + split_id=data_split_id, split_type='val', response_type=metric, sep="\t", + dropna=True) + + test = pc.load_drug_response_data(data_path=args.data_path, data_type=data_type, + split_id=data_split_id, split_type='test', response_type=metric, sep="\t", + dropna=True) + + smiles_df = pc.load_smiles_data(args.data_path) + + train = add_smiles(smiles_df=smiles_df, df=train, metric=metric) + val = add_smiles(smiles_df=smiles_df, df=val, metric=metric) + test = add_smiles(smiles_df=smiles_df, df=test, metric=metric) + df_all = pd.concat([train, val, test], axis=0) + df_all.reset_index(drop=True, inplace=True) + + if data_split_seed > -1: + print("using random splitting") + train, val, test = split_df(df=df_all, seed=data_split_seed) + else: + print("using predefined splits") + + + gene_exp = pc.load_gene_expression_data(args.data_path) + + lm = pc.load_landmark_genes(args.data_path) + lm = list(set(lm).intersection(gene_exp.columns)) + gexp = gene_exp.loc[:, lm] + + if args.scale_gexp: + scgexp = StandardScaler() + gexp.loc[:,:] = scgexp.fit_transform(gexp) + + + n_descriptors=None + features=None + if args.feature_path: + print("feature path exists") + features = pd.read_csv(args.feature_path) + n_descriptors = features.shape[1] - 1 + feature_names = features.drop(['smiles'], axis=1).columns.tolist() + + test = pd.merge(test, features, on='smiles', how='left') + train = pd.merge(train, features, on='smiles', how='left') + val = pd.merge(val, features, on='smiles', how='left') + + sc = StandardScaler() + train.loc[:, feature_names] = sc.fit_transform(train.loc[:, feature_names]) + test.loc[:, feature_names] = sc.transform(test.loc[:, feature_names]) + val.loc[:, feature_names] = sc.transform(val.loc[:, 
feature_names]) + else: + feature_names=None + + + data_creater = CreateData(gexp=gexp, metric=metric, encoder_type=encoder_type, data_path=args.data_path, feature_names=feature_names) + + train_ds = data_creater.create_data(train) + val_ds = data_creater.create_data(val) + test_ds = data_creater.create_data(test) + + # train_ds = create_data_list(train, gexp, metric=metric) + # val_ds = create_data_list(val, gexp, metric=metric) + # test_ds = create_data_list(test, gexp, metric=metric) + + + train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True, drop_last=True) + val_loader = DataLoader(val_ds, batch_size=bs, shuffle=False, drop_last=False) + test_loader = DataLoader(test_ds, batch_size=bs, shuffle=False, drop_last=False) + + # train_loader_no_shuffle = DataLoader(train_X, batch_size = bs, shuffle=False, drop_last=False) + # val_loader_no_shuffle = DataLoader(val_X, batch_size = bs, shuffle=False, drop_last=False) + + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = Model(gnn_features = 65, n_descriptors=n_descriptors, encoder_type=encoder_type).to(device) + adam = torch.optim.Adam(model.parameters(), lr = lr ) + optimizer = adam + + + early_stopping = EarlyStopping(patience = n_epochs, verbose=True, chkpoint_name = ckpt_path) + criterion = nn.MSELoss() + + + # train the model + hist = {"train_rmse":[], "val_rmse":[]} + for epoch in range(0, n_epochs): + model.train() + loss_all = 0 + for data in train_loader: + data = data.to(device) + optimizer.zero_grad() + output = model(data) + output = output.reshape(-1,) + + loss = criterion(output, data.y) + loss.backward() + optimizer.step() + + + # train_rmse = gnn_utils.test_fn(train_loader, model, device) + val_rmse, _, _ = test_fn(val_loader, model, device) + early_stopping(val_rmse, model) + + if early_stopping.early_stop: + print("Early stopping") + break + + # hist["train_rmse"].append(train_rmse) + hist["val_rmse"].append(val_rmse) + # print(f'Epoch: {epoch}, Train_rmse: {train_rmse:.3}, Val_rmse: {val_rmse:.3}') + print(f'Epoch: {epoch}, Val_rmse: {val_rmse:.3}') + + # print(f"training completed at {datetime.datetime.now()}") + + model.load_state_dict(torch.load(ckpt_path)) + + test_rmse, true, pred = test_fn(test_loader, model, device) + test['true'] = true + test['pred'] = pred + + if args.feature_path: + test = test[['improve_sample_id', 'smiles', 'improve_chem_id', 'auc', 'true', 'pred']] + + test.to_csv( os.path.join(out_dir, 'test_predictions.csv'), index=False ) + # r2_score(y_pred=pred, y_true=true) diff --git a/model4/train_prot.py b/model4/train_prot.py new file mode 100644 index 0000000..1045705 --- /dev/null +++ b/model4/train_prot.py @@ -0,0 +1,232 @@ +from data_utils import Downloader, DataProcessor +from data_utils import add_smiles +from gnn_utils import create_data_list +from torch_geometric.loader import DataLoader +from model import Model +import torch +import torch.nn as nn +from gnn_utils import EarlyStopping +# dw = Downloader('benchmark-data-imp-2023') +from gnn_utils import test_fn +import os +import argparse +import pandas as pd +from gnn_utils import CreateData +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from data_utils import load_generic_expression_data + +def split_df(df, seed): + + train, test = train_test_split(df, random_state=seed, test_size=0.2) + val, test = train_test_split(test, random_state=seed, test_size=0.5) + + train.reset_index(drop=True, inplace=True) + val.reset_index(drop=True, 
inplace=True) + test.reset_index(drop=True, inplace=True) + + return train, val, test + +if __name__ == "__main__": + + + parser = argparse.ArgumentParser(prog='ProgramName', description='What the program does') + parser.add_argument('--metric', default='auc', help='') + parser.add_argument('--run_id', default='0', help='') + parser.add_argument('--epochs', default=1, help='') + parser.add_argument('--batch_size', type=int, default=64, help='') + parser.add_argument('--data_split_seed', default=1, help='') + parser.add_argument('--data_split_id', default=0, help='') + parser.add_argument('--encoder_type', default='gnn', help='') + parser.add_argument('--data_path', default='', help='') + parser.add_argument('--data_type', default='CCLE', help='') + parser.add_argument('--data_version', default='benchmark-data-pilot1', help='benchmark-data-imp-2023 or benchmark-data-pilot1') + parser.add_argument('--out_dir', default='', help='') + parser.add_argument('--feature_path', type=str, default=None, help='') + parser.add_argument('--scale_gexp', type=bool, default=False, help='') + parser.add_argument('--use_proteomics_data', type=int, default=1, help='if 1 use proteomics data elif 0, use gexp') + + args = parser.parse_args() + + + + pc = DataProcessor(args.data_version) + + metric = args.metric + bs = args.batch_size + lr = 1e-4 + n_epochs = int(args.epochs) + out_dir = args.out_dir + run_id = args.run_id + data_split_seed = int(args.data_split_seed) + encoder_type = args.encoder_type + data_type = args.data_type + data_split_id = args.data_split_id + + out_dir = os.path.join(out_dir, run_id ) + os.makedirs( out_dir, exist_ok=True ) + os.makedirs( args.data_path, exist_ok=True ) + ckpt_path = os.path.join(out_dir, 'best.pt') + + dw = Downloader(args.data_version) + dw.download_candle_data(data_type=data_type, split_id=data_split_id, data_dest=args.data_path) + dw.download_deepttc_vocabs(data_dest=args.data_path) + + + + train = pc.load_drug_response_data(data_path=args.data_path, data_type=data_type, + split_id=data_split_id, split_type='train', response_type=metric, sep="\t", + dropna=True) + + val = pc.load_drug_response_data(data_path=args.data_path, data_type=data_type, + split_id=data_split_id, split_type='val', response_type=metric, sep="\t", + dropna=True) + + test = pc.load_drug_response_data(data_path=args.data_path, data_type=data_type, + split_id=data_split_id, split_type='test', response_type=metric, sep="\t", + dropna=True) + + smiles_df = pc.load_smiles_data(args.data_path) + + train = add_smiles(smiles_df=smiles_df, df=train, metric=metric) + val = add_smiles(smiles_df=smiles_df, df=val, metric=metric) + test = add_smiles(smiles_df=smiles_df, df=test, metric=metric) + + gene_exp_ = load_generic_expression_data('proteomics_restructure_with_knn_impute.tsv') + + + if args.use_proteomics_data == 1: + print('using proteomics data') + gene_exp = gene_exp_ + else: + print('using gene expression data') + gene_exp = pc.load_gene_expression_data(args.data_path) + + + use_improve_ids = gene_exp_.index.values + train = train[train.improve_sample_id.isin(use_improve_ids)] + val = val[val.improve_sample_id.isin(use_improve_ids)] + test = test[test.improve_sample_id.isin(use_improve_ids)] + + train.reset_index(drop=True, inplace=True) + val.reset_index(drop=True, inplace=True) + test.reset_index(drop=True, inplace=True) + + + + df_all = pd.concat([train, val, test], axis=0) + df_all.reset_index(drop=True, inplace=True) + + if data_split_seed > -1: + print("using random splitting") + train, val, 
test = split_df(df=df_all, seed=data_split_seed) + else: + print("using predefined splits") + + + # gene_exp = pc.load_gene_expression_data(args.data_path) + + + lm = pc.load_landmark_genes(args.data_path) + lm = list(set(lm).intersection(gene_exp.columns)) + gexp = gene_exp.loc[:, lm] + + n_genes = len(lm) + + if args.scale_gexp: + scgexp = StandardScaler() + gexp.loc[:,:] = scgexp.fit_transform(gexp) + + + n_descriptors=None + features=None + if args.feature_path: + print("feature path exists") + features = pd.read_csv(args.feature_path) + n_descriptors = features.shape[1] - 1 + feature_names = features.drop(['smiles'], axis=1).columns.tolist() + + test = pd.merge(test, features, on='smiles', how='left') + train = pd.merge(train, features, on='smiles', how='left') + val = pd.merge(val, features, on='smiles', how='left') + + sc = StandardScaler() + train.loc[:, feature_names] = sc.fit_transform(train.loc[:, feature_names]) + test.loc[:, feature_names] = sc.transform(test.loc[:, feature_names]) + val.loc[:, feature_names] = sc.transform(val.loc[:, feature_names]) + else: + feature_names=None + + + data_creater = CreateData(gexp=gexp, metric=metric, encoder_type=encoder_type, data_path=args.data_path, feature_names=feature_names) + + train_ds = data_creater.create_data(train) + val_ds = data_creater.create_data(val) + test_ds = data_creater.create_data(test) + + # train_ds = create_data_list(train, gexp, metric=metric) + # val_ds = create_data_list(val, gexp, metric=metric) + # test_ds = create_data_list(test, gexp, metric=metric) + + + train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True, drop_last=True) + val_loader = DataLoader(val_ds, batch_size=bs, shuffle=False, drop_last=False) + test_loader = DataLoader(test_ds, batch_size=bs, shuffle=False, drop_last=False) + + # train_loader_no_shuffle = DataLoader(train_X, batch_size = bs, shuffle=False, drop_last=False) + # val_loader_no_shuffle = DataLoader(val_X, batch_size = bs, shuffle=False, drop_last=False) + + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = Model(gnn_features = 65, n_descriptors=n_descriptors, encoder_type=encoder_type, + n_genes=n_genes).to(device) + adam = torch.optim.Adam(model.parameters(), lr = lr ) + optimizer = adam + + + early_stopping = EarlyStopping(patience = n_epochs, verbose=True, chkpoint_name = ckpt_path) + criterion = nn.MSELoss() + + + # train the model + hist = {"train_rmse":[], "val_rmse":[]} + for epoch in range(0, n_epochs): + model.train() + loss_all = 0 + for data in train_loader: + data = data.to(device) + optimizer.zero_grad() + output = model(data) + output = output.reshape(-1,) + + loss = criterion(output, data.y) + loss.backward() + optimizer.step() + + + # train_rmse = gnn_utils.test_fn(train_loader, model, device) + val_rmse, _, _ = test_fn(val_loader, model, device) + early_stopping(val_rmse, model) + + if early_stopping.early_stop: + print("Early stopping") + break + + # hist["train_rmse"].append(train_rmse) + hist["val_rmse"].append(val_rmse) + # print(f'Epoch: {epoch}, Train_rmse: {train_rmse:.3}, Val_rmse: {val_rmse:.3}') + print(f'Epoch: {epoch}, Val_rmse: {val_rmse:.3}') + + # print(f"training completed at {datetime.datetime.now()}") + + model.load_state_dict(torch.load(ckpt_path)) + + test_rmse, true, pred = test_fn(test_loader, model, device) + test['true'] = true + test['pred'] = pred + + if args.feature_path: + test = test[['improve_sample_id', 'smiles', 'improve_chem_id', 'auc', 'true', 'pred']] + + test.to_csv( os.path.join(out_dir, 
'test_predictions.csv'), index=False ) + # r2_score(y_pred=pred, y_true=true) diff --git a/model4/transformer.py b/model4/transformer.py new file mode 100644 index 0000000..e9d0be2 --- /dev/null +++ b/model4/transformer.py @@ -0,0 +1,352 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import pickle + + +class MultiheadAttention(nn.Module): + + def __init__(self, input_dim, embed_dim, num_heads): + super().__init__() + assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads." + + self.embed_dim = embed_dim # 512 + self.num_heads = num_heads # 8 + self.head_dim = embed_dim // num_heads + + # Stack all weight matrices 1...h together for efficiency + # Note that in many implementations you see "bias=False" which is optional + self.qkv_proj = nn.Linear(input_dim, 3*embed_dim) + self.o_proj = nn.Linear(embed_dim, embed_dim) + + self._reset_parameters() + + def _reset_parameters(self): + # Original Transformer initialization, see PyTorch documentation + nn.init.xavier_uniform_(self.qkv_proj.weight) + self.qkv_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(self.o_proj.weight) + self.o_proj.bias.data.fill_(0) + + def forward(self, x, mask=None, attn_bias=None, return_attention=False): + batch_size, seq_length, _ = x.size() + # mask = x.eq(0) + # if mask is not None: + # mask = expand_mask(mask) + x = x * (1 - mask.unsqueeze(-1).type_as(x)) + qkv = self.qkv_proj(x) + + # Separate Q, K, V from linear output + qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3*self.head_dim) + qkv = qkv.permute(0, 2, 1, 3) # [Batch, Head, SeqLen, Dims] + q, k, v = qkv.chunk(3, dim=-1) + + # Determine value outputs + # values, attention = scaled_dot_product(q, k, v, mask=mask) + d_k = q.size()[-1] + attn_logits = torch.matmul(q, k.transpose(-2, -1)) + attn_logits = attn_logits * ( d_k**(-0.5) ) + attn_logits = attn_logits.masked_fill(mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float('-inf') ) + attn_logits = attn_logits.view(batch_size*self.num_heads, seq_length, seq_length) + + + # attention_mask = torch.randn(4,4,4) + + if attn_bias: + attn_bias = attn_bias.unsqueeze(1) + attn_bias = attn_bias.repeat(1, self.num_heads, seq_length, seq_length) + attn_bias = attn_bias.view(batch_size*self.num_heads, seq_length, seq_length) + attn_logits = attn_logits + attn_bias + + attn_logits = attn_logits.view(batch_size, self.num_heads, seq_length, seq_length) + attention = F.softmax(attn_logits, dim=-1) + + values = torch.matmul(attention, v) + values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims] + values = values.reshape(batch_size, seq_length, self.embed_dim) + + + # values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims] + # values = values.reshape(batch_size, seq_length, self.embed_dim) + o = self.o_proj(values) + + if return_attention: + return o, attention + else: + return o + + +class EncoderBlock(nn.Module): + + def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0): + """ + Inputs: + input_dim - Dimensionality of the input + num_heads - Number of heads to use in the attention block + dim_feedforward - Dimensionality of the hidden layer in the MLP + dropout - Dropout probability to use in the dropout layers + """ + super().__init__() + + # Attention layer + self.self_attn = MultiheadAttention(input_dim, input_dim, num_heads) + + # Two-layer MLP + self.linear_net = nn.Sequential( + nn.Linear(input_dim, dim_feedforward), + nn.Dropout(dropout), + nn.ReLU(inplace=True), + nn.Linear(dim_feedforward, input_dim) + ) + 
+ # Layers to apply in between the main layers + self.norm1 = nn.LayerNorm(input_dim) + self.norm2 = nn.LayerNorm(input_dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, mask=None, attn_bias=None): + # Attention part + attn_out = self.self_attn(x, mask=mask, attn_bias=attn_bias) + x = x + self.dropout(attn_out) + x = self.norm1(x) + + # MLP part + linear_out = self.linear_net(x) + x = x + self.dropout(linear_out) + x = self.norm2(x) + + return x + + +class TransformerEncoder(nn.Module): + + def __init__(self, num_layers, **block_args): + super().__init__() + self.layers = nn.ModuleList([EncoderBlock(**block_args) for _ in range(num_layers)]) + + def forward(self, x, mask=None, attn_bias=None): + for l in self.layers: + x = l(x, mask=mask, attn_bias=attn_bias) + return x + + def get_attention_maps(self, x, mask=None, attn_bias=None): + attention_maps = [] + for l in self.layers: + _, attn_map = l.self_attn(x, mask=mask, return_attention=True) + attention_maps.append(attn_map) + x = l(x) + return attention_maps +import math +from torch import Tensor + +# from Moformer +# class PositionalEncoding(nn.Module): + +# def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 2048): +# super().__init__() +# self.dropout = nn.Dropout(p=dropout) + +# position = torch.arange(max_len).unsqueeze(1) +# div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) +# pe = torch.zeros(max_len, 1, d_model) +# pe[:, 0, 0::2] = torch.sin(position * div_term) +# pe[:, 0, 1::2] = torch.cos(position * div_term) +# self.register_buffer('pe', pe) + +# def forward(self, x: Tensor) -> Tensor: +# """ +# Args: +# x: Tensor, shape [seq_len, batch_size, embedding_dim] +# """ +# x = x + self.pe[:x.size(0)] +# return self.dropout(x) + +# from UVDALC +class PositionalEncoding(nn.Module): + + def __init__(self, d_model, max_len=5000): + """ + Inputs + d_model - Hidden dimensionality of the input. + max_len - Maximum length of a sequence to expect. + """ + super().__init__() + + # Create matrix of [SeqLen, HiddenDim] representing the positional encoding for max_len inputs + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + + # register_buffer => Tensor which is not a parameter, but should be part of the modules state. + # Used for tensors that need to be on the same device as the module. + # persistent=False tells PyTorch to not add the buffer to the state dict (e.g. 
when we save the model) + self.register_buffer('pe', pe, persistent=False) + + def forward(self, x): + x = x + self.pe[:, :x.size(1)] + return x + +class TransformerModel(nn.Module): + + def __init__(self, args): + super().__init__() + num_layers=6 + input_dim=512 + model_dim=512 + dim_feedforward=2*512 + num_heads=8 + dropout=0 + num_classes=1 + input_dropout=0 + vocab_size = args['vocab_size'] + self.masked_token_train=args['masked_token_train'] + self.finetune = args['finetune'] + +# with open(dictionary_path, 'rb') as f: +# self.dictionary = pickle.load(f) + + self.embed = nn.Embedding(vocab_size, 512, 0) + + # self.input_net = nn.Sequential( + # nn.Dropout(input_dropout), + # nn.Linear(input_dim, hparams.model_dim) + # ) + + self.transformer = TransformerEncoder(num_layers=num_layers, + input_dim=model_dim, + dim_feedforward=2*model_dim, + num_heads=num_heads, + dropout=dropout) + + self.masked_model = MaskLMHead(embed_dim=512, + output_dim=vocab_size + ) + + + + self.regression_head = ClassificationHead(input_dim=512, + inner_dim=256, + num_classes=1) + + self.output_net = nn.Sequential( + nn.Linear(model_dim, model_dim), + nn.LayerNorm(model_dim), + nn.ReLU(inplace=True), + nn.Dropout(dropout), + nn.Linear(model_dim, num_classes) + ) + + + self.pos_encoder = PositionalEncoding(d_model=512, max_len=50) + + def forward(self, data ): + attn_bias=None + tokens = data.tokens + + + """ + Inputs: + x - Input features of shape [Batch, SeqLen, input_dim] + mask - Mask to apply on the attention outputs (optional) + add_positional_encoding - If True, we add the positional encoding to the input. + Might not be desired for some tasks. + """ + + # tokens = batch['tokens'] + + # masked_token_train=True + token_mask=None + if self.masked_token_train: + + rand = torch.rand(tokens.shape) + # where the random array is less than 0.15, we set true + mask_arr = rand < 0.15 + token_mask = mask_arr*(tokens !=1)*(tokens!=0)*(tokens!=2) + + tokens = tokens.masked_fill(token_mask.to(torch.bool), 4) + + + + + + padding_mask = tokens.eq(0) + x = self.embed(tokens) + x = self.pos_encoder(x) # adding positional encoding from Moformer paper https://github.com/zcao0420/MOFormer/blob/main/model/transformer.py + # x = self.input_net(x) + # if add_positional_encoding: + # x = self.positional_encoding(x) + + # encoder output + encoder_out = self.transformer(x, mask=padding_mask, attn_bias=None) + + if self.masked_token_train: + logits = self.masked_model(encoder_out, token_mask) + else: + logits = encoder_out + + + if self.finetune: + logits = self.regression_head(logits) + + + + # x = self.output_net(x) + return logits, token_mask, encoder_out[:,0] + + +class MaskLMHead(nn.Module): + """Head for masked language modeling.""" + + def __init__(self, embed_dim, output_dim): + super().__init__() + self.dense = nn.Linear(embed_dim, embed_dim) + self.activation_fn = F.gelu + # self.layer_norm = LayerNorm(embed_dim) + self.out = nn.Linear(embed_dim, output_dim) + + # if weight is None: + # weight = nn.Linear(embed_dim, output_dim, bias=False).weight + # self.weight = weight + # self.bias = nn.Parameter(torch.zeros(output_dim)) + + def forward(self, features, masked_tokens=None, **kwargs): + # Only project the masked tokens while training, + # saves both memory and computation + if masked_tokens is not None: + features = features[masked_tokens, :] + + x = self.dense(features) + x = self.activation_fn(x) + # x = self.layer_norm(x) + # project back to size of vocabulary with bias + x = self.out(x) + return x + + +class 
ClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim, + inner_dim, + num_classes, + pooler_dropout=0.0, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.activation_fn = F.relu + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take <s> token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = self.activation_fn(x) + x = self.dropout(x) + x = self.out_proj(x) + return x
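Each `run_*.py` driver shells out to `train.py` once per split. For reference, the sketch below mirrors a single GNN-encoder iteration of that loop for CCLE split 0 using the predefined splits (`--data_split_seed -10` selects the downloaded split files instead of random re-splitting); the output directory, GPU id, and epoch count are placeholder values, not project defaults.

```python
# Minimal usage sketch (placeholder paths/values): run train.py for one predefined
# CCLE split with the GNN drug encoder, as run_models.py does inside its loop.
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # placeholder GPU id

os.system(
    "python train.py --encoder_type gnn --out_dir GNNenc/example "
    "--data_version benchmark-data-pilot1 --data_path GNNenc/example/Data "
    "--data_split_seed -10 --data_split_id 0 --metric auc --data_type CCLE "
    "--epochs 100 --run_id 0"
)
```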
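When `masked_token_train` is enabled, `TransformerModel.forward` corrupts roughly 15% of the input tokens before embedding them: id 0 is the padding index, ids 1 and 2 appear to be reserved special tokens, and the selected positions are overwritten with mask id 4 so that `MaskLMHead` projects only the masked positions. The standalone sketch below reproduces that selection logic for illustration; it substitutes the `&` operator and explicit device placement for the elementwise products used in the model.

```python
# Illustrative sketch of the ~15% token corruption used for masked-token training.
# Assumes ids 0/1/2 are padding/special tokens and 4 is the mask token, as in
# TransformerModel.forward; not a drop-in replacement for the model code.
import torch

def corrupt_tokens(tokens: torch.Tensor, mask_rate: float = 0.15, mask_id: int = 4):
    """Replace ~mask_rate of the non-special tokens with mask_id.

    Returns the corrupted tokens and the boolean mask of replaced positions,
    which the masked-LM head uses to project only those positions.
    """
    rand = torch.rand(tokens.shape, device=tokens.device)
    token_mask = (rand < mask_rate) & (tokens != 0) & (tokens != 1) & (tokens != 2)
    return tokens.masked_fill(token_mask, mask_id), token_mask

if __name__ == "__main__":
    toks = torch.tensor([[1, 10, 11, 12, 13, 2, 0, 0]])  # hypothetical token ids
    corrupted, token_mask = corrupt_tokens(toks)
    print(corrupted)
    print(token_mask)
```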