# Dataset.py
import os
import numpy as np
import torch
import json
import glob
from torch_geometric.data import Dataset, Data
class PDBbind_Dataset(Dataset):
"""
A Dataset of protein-ligand interaction graphs. Takes as input a folder containing the graph.pth file for each complex in the dataset.
- Loads all graphs from the folder.
- If a data split dictionary is given, only the graphs listed under the key selected with the `dataset` parameter are loaded; if no dictionary is given, all graphs in the folder are processed.
- If a data dictionary with affinity labels is provided (data_dict[complex_id] = {'log_kd_ki': affinity}), the labels are min-max scaled and attached to the graphs; if no data_dict is given, all labels are set to 0 (e.g. for inference).
- For ablation studies, all protein nodes can be deleted from the graphs.
- To measure the contribution of atom features and edge features, either can be excluded from the graphs.
- A masternode can be included in the graph; it can be connected to all nodes, only to the ligand nodes, or only to the protein nodes.
- The edges between the masternode and the other nodes can be directed from the ligand/protein nodes to the masternode, from the masternode to the ligand/protein nodes, or undirected.
- If a masternode is included, ablation tests can also delete all ligand nodes from the graph.
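Illustrative usage (file names and embedding keys are placeholders, not shipped with the code):
    dataset = PDBbind_Dataset(root='graphs/',
                              protein_embeddings=['esm2_t6'],
                              ligand_embeddings=['ChemBERTa-77M-MLM'],
                              data_dict='affinity_dict.json',
                              data_split='data_split.json',
                              dataset='train')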
"""
def __init__(self,
root, # Path to the folder containing the graphs
# EMBEDDINGS TO BE INCLUDED
protein_embeddings, # List of all protein embeddings that should be included (have to be saved in the graph Data() object already)
ligand_embeddings, # List of all ligand embeddings that should be included (have to be saved in the graph Data() object already).
# WHICH GRAPHS AND LABELS TO INCLUDE
data_dict=None, # Path to dictionary containing the affinity labels of the complexes as dict[complex_id] = {'log_kd_ki': affinity}
data_split=None, # Filepath to dictionary (json file) containing the data split for the graphs in the folder
dataset=None, # If a split dict is given, which subset should be loaded ['train', 'test'] as defined in the data_split file
# ABLATION
delete_protein = False, # If protein nodes should be deleted from the graph (ablation study)
delete_ligand = False, # If ligand nodes should be deleted from the graph (ablation study)
edge_features=True, # If edge features should be included
atom_features=True, # If atom features should be included
# INCLUDE A MASTERNODE
masternode=False, # If a masternode (mn) should be included
masternode_connectivity = 'all', # If a mn is included, to which nodes it should be connected ('all', 'ligand', 'protein')
masternode_edges='undirected', # If the mn should be connected with undirected or directed edges ("undirected", "in", or "out")
):
super().__init__(root)
# Data Directory and Embeddings to be included
self.data_dir = root
self.protein_embeddings = protein_embeddings
self.ligand_embeddings = ligand_embeddings
# Ablation Studies
self.delete_protein = delete_protein
self.delete_ligand = delete_ligand
if data_dict is not None:
with open(data_dict, 'r', encoding='utf-8') as json_file:
self.data_dict = json.load(json_file)
self.labels = True
else:
self.labels = False
# Load the dictionary containing the data split for the dataset, if one is given
# If no splitting dict is given, include all graphs in the folder in the dataset
if data_split:
self.dataset = dataset
self.data_split = data_split
with open(self.data_split, 'r', encoding='utf-8') as json_file:
self.split_dict = json.load(json_file)
# Generate the list of filepaths to the graphs that should be loaded,
# including only the complexes that are included in the split dict and
# all their associated ligands
included_complexes = self.split_dict[self.dataset]
dataset_filepaths = []
for id in included_complexes:
search_pattern = os.path.join(self.data_dir, f"{id}*_graph.pth")
matching_files = glob.glob(search_pattern)
dataset_filepaths.extend(matching_files)
self.filepaths = dataset_filepaths
else:
self.filepaths = [file.path for file in os.scandir(self.data_dir) if file.name.endswith('graph.pth')]
print("-- Number of graphs loaded: ", len(self.filepaths))
#------------------------------------------------------------------------------------------
# Process all the graphs according to kwargs
# -------------------------------------------------------------------------------------------
self.input_data = {}
ind = 0
for file in self.filepaths:
grph = torch.load(file)
id = grph.id
pos = grph.pos
if self.labels:
min=0
max=16
try: # If the labels are saved with L00001 in the dictionary
pK = self.data_dict[id]['log_kd_ki']
except KeyError: # If the labels are saved without L00001 in the dictionary
id_short = id[:-17]
pK = self.data_dict[id_short]['log_kd_ki']
pK_scaled = (pK - min) / (max - min)
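# Labels are min-max scaled from the assumed pK range [0, 16] to [0, 1]; e.g. a pK of 8.0 becomes 0.5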
else: pK_scaled = 0
# --- AMINO ACID EMBEDDINGS ---
x = grph.x
# Append the amino acid embeddings to the feature matrices
for emb in self.protein_embeddings:
emb_tensor = grph[emb]
if emb_tensor is not None:
x = torch.concatenate((x, emb_tensor), axis=1)
# --- LIGAND EMBEDDINGS ---
ligand_embedding = None
# Concatenate all ligand embeddings into a single vector
for emb in self.ligand_embeddings:
emb_vector = grph[emb]
if emb_vector is None: print(f"Embedding {emb} not found for {id}")
else:
if ligand_embedding is None: ligand_embedding = emb_vector
else:
ligand_embedding = torch.concatenate((ligand_embedding, emb_vector), axis=1)
ligand_embedding = ligand_embedding.float()
# --- EDGE INDICES, EDGE ATTRIBUTES ---
# for convolution on 1) all edges, 2) only ligand edges and 3) only protein edges
edge_index = grph.edge_index
edge_index_lig = grph.edge_index_lig
edge_index_prot = grph.edge_index_prot
edge_attr = grph.edge_attr
edge_attr_lig = grph.edge_attr_lig
edge_attr_prot = grph.edge_attr_prot
# If a masternode should be included in the graph, add the corresponding edge_index
if masternode:
# Depending on the desired masternode connectivity (to all nodes, to ligand nodes or to protein nodes),
# choose the correct edge index master from the graph object
if masternode_connectivity == 'all': edge_index_master = grph.edge_index_master
elif masternode_connectivity == 'ligand': edge_index_master = grph.edge_index_master_lig
elif masternode_connectivity == 'protein': edge_index_master = grph.edge_index_master_prot
else: raise ValueError(f"Invalid value for masternode_connectivity: {masternode_connectivity}")
# By default, the edge_index_master contains directed edges for information flow from the
# ligand and protein nodes to the masternode ("in")
if masternode_edges == 'in':
pass
# For information flow from the masternode to the ligand and protein nodes ("out"), swap the rows
elif masternode_edges == 'out':
edge_index_master = edge_index_master[[1, 0], :]
# For information flow in both directions ("undirected"), swap the rows and append
elif masternode_edges == 'undirected':
edge_index_master = torch.concatenate((edge_index_master[:,:-1], edge_index_master[[1, 0], :]), dim=1)
else: raise ValueError(f"Invalid value for masternode_edges: {masternode_edges}")
# Append the updated edge_index_master to the edge_index containing the other edges in the graph
edge_index = torch.concatenate((edge_index, edge_index_master), dim=1)
edge_index_lig = torch.concatenate((edge_index_lig, edge_index_master), dim=1)
edge_index_prot = torch.concatenate((edge_index_prot, edge_index_master), dim=1)
# For each edge that has been added to connect the masternode, extend also the edge attribute
# matrix with a feature vector
mn_edge_attr = torch.tensor([0., 1., 0., # it's a mn connection
0., 0., 0.,0., # length is zero
0., 0., 0.,0.,0., # bondtype = None
0., # is not conjugated
0., # is not in ring
0., 0., 0., 0., 0., 0.], # No stereo
dtype=torch.float)
mn_edge_matrix = mn_edge_attr.repeat(edge_index_master.shape[1], 1)
edge_attr = torch.concatenate([edge_attr, mn_edge_matrix], axis=0)
edge_attr_lig = torch.concatenate([edge_attr_lig, mn_edge_matrix], axis=0)
edge_attr_prot = torch.concatenate([edge_attr_prot, mn_edge_matrix], axis=0)
# If NO masternode should be included in the graph, remove the masternode row (the last one) from x and pos
else:
x = x[:-1, :]
pos = pos[:-1, :]
# --- ABLATION STUDIES ---
# For ablation studies: Generate feature matrices without node/edge features
if not atom_features:
x = torch.concatenate((x[:, 0:9], x[:, 40:]), dim=1)
if not edge_features:
edge_attr = edge_attr[:, :7]
edge_attr_lig = edge_attr_lig[:, :7]
edge_attr_prot = edge_attr_prot[:, :7]
n_prot_nodes = grph.edge_index_master_prot.shape[1] - 1
n_lig_nodes = grph.edge_index_master_lig.shape[1] - 1
n_nodes = grph.edge_index_master.shape[1] - 1
# For ablation studies: Generate graphs with the protein nodes or the ligand nodes removed
if delete_protein and delete_ligand: raise ValueError('Cannot delete both protein and ligand nodes')
elif delete_protein and masternode:
# Remove all nodes that don't belong to the ligand from feature matrix, keep masternode
x = torch.concatenate( [x[:n_lig_nodes,:] , x[-1,:].view(1,-1)] )
# Remove all coordinates of nodes that don't belong to the ligand and keep masternode
pos = torch.concatenate( [pos[:n_lig_nodes,:] , pos[-1,:].view(1,-1)] )
# Keep only edges that are between ligand atoms or between ligand atoms and masternode
mask = ((edge_index < n_lig_nodes) | (edge_index == n_nodes)).all(dim=0)
edge_index = edge_index[:, mask]
edge_index[edge_index == n_nodes] = n_lig_nodes
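# After the protein nodes are dropped, the masternode (previously index n_nodes) directly follows the ligand nodes, so its index is remapped to n_lig_nodes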
edge_attr = edge_attr[mask, :]
train_graph = Data(x = x.float(),
edge_index=edge_index.long(),
edge_attr=edge_attr.float(),
y=torch.tensor(pK_scaled, dtype=torch.float),
n_nodes=torch.tensor([n_nodes, n_lig_nodes, n_prot_nodes], dtype=torch.long),
lig_emb=ligand_embedding
#,pos=pos
,id=id
)
elif delete_protein and not masternode:
# Remove all nodes that don't belong to the ligand from feature matrix
x = x[:n_lig_nodes,:]
# Remove all coordinates of nodes that don't belong to the ligand
pos = pos[:n_lig_nodes,:]
# Keep only edges that are between ligand atoms
mask = (edge_index < n_lig_nodes).all(dim=0)
edge_index = edge_index[:, mask]
edge_attr = edge_attr[mask, :]
train_graph = Data(x = x.float(),
edge_index=edge_index.long(),
edge_attr=edge_attr.float(),
y=torch.tensor(pK_scaled, dtype=torch.float),
n_nodes=torch.tensor([n_nodes, n_lig_nodes, n_prot_nodes], dtype=torch.long),
lig_emb=ligand_embedding
#,pos=pos
,id=id
)
elif delete_ligand and not masternode: raise ValueError('Cannot delete ligand nodes without masternode')
elif delete_ligand and masternode:
# Remove all ligand nodes from the feature matrix, keep protein nodes and masternode
x = x[n_lig_nodes:, :]
# Remove the coordinates of all ligand nodes, keep protein nodes and masternode (for visualization only)
grph.pos = grph.pos[n_lig_nodes:, :]
# Keep only edges that are between protein nodes and the masternode
mask = torch.all(edge_index >= n_lig_nodes, dim=0)
edge_index = edge_index[:, mask] - n_lig_nodes
edge_attr = edge_attr[mask, :]
train_graph = Data(x = x.float(),
edge_index=edge_index.long(),
edge_attr=edge_attr.float(),
y=torch.tensor(pK_scaled, dtype=torch.float),
n_nodes=torch.tensor([n_nodes, n_lig_nodes, n_prot_nodes], dtype=torch.long),
lig_emb=ligand_embedding
#,pos=pos
,id=id
)
# --- NO ABLATION - Complete Interaction Graphs ---
else:
train_graph = Data(x = x.float(),
edge_index=edge_index.long(),
edge_attr=edge_attr.float(),
# To do: For convolution on only ligand edges or only protein edges, the corresponding
# edge_index containing these edges would have to be passed as well. In such
# architectures the ablation studies are no longer possible, because the reduced
# graphs have no edge_index_lig or edge_index_prot.
#edge_index_lig=edge_index_lig,
#edge_index_prot=edge_index_prot,
#edge_attr_lig=edge_attr_lig,
#edge_attr_prot=edge_attr_prot,
y=torch.tensor(pK_scaled, dtype=torch.float),
n_nodes=torch.tensor([n_nodes, n_lig_nodes, n_prot_nodes], dtype=torch.long),
lig_emb=ligand_embedding
#,pos=pos
,id=id
)
self.input_data[ind] = train_graph
ind += 1
def len(self):
return len(self.input_data)
def get(self, idx):
graph = self.input_data[idx]
return graph
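# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original pipeline): it shows how the
# dataset could be wrapped in a PyTorch Geometric DataLoader. All paths, file
# names and embedding keys below are placeholders and have to be replaced with
# the actual files of a processed dataset.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch_geometric.loader import DataLoader

    dataset = PDBbind_Dataset(
        root="graphs/",                              # folder with the *_graph.pth files (placeholder)
        protein_embeddings=["esm2_t6"],              # embedding keys stored in the Data objects (placeholders)
        ligand_embeddings=["ChemBERTa-77M-MLM"],
        data_dict="affinity_dict.json",              # dict[complex_id] = {'log_kd_ki': affinity} (placeholder)
        data_split="data_split.json",                # json file with 'train'/'test' keys (placeholder)
        dataset="train",
        masternode=True,
        masternode_connectivity="all",
        masternode_edges="undirected",
    )

    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    for batch in loader:
        # Each batch is a batched Data object with x, edge_index, edge_attr, y, lig_emb, ...
        print(batch)
        break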