-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathdeepgozero_data.py
145 lines (126 loc) · 4.66 KB
/
deepgozero_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import logging
import os
from collections import Counter, deque

import click as ck
import numpy as np
import pandas as pd

from utils import Ontology, FUNC_DICT, NAMESPACES, MOLECULAR_FUNCTION, BIOLOGICAL_PROCESS, CELLULAR_COMPONENT
logging.basicConfig(level=logging.INFO)
def _collect_ontology_terms(df, go, ont):
    """Count GO terms and InterPro entries for one ontology key.

    Args:
        df: protein DataFrame; each row must expose ``prop_annotations`` and
            ``interpros`` iterables.
        go: loaded ``Ontology`` used to look up each term's namespace.
        ont: ontology key ('mf', 'bp' or 'cc') into ``NAMESPACES``/``FUNC_DICT``.

    Returns:
        (index, terms, interpros): row positions of proteins with at least one
        annotation in the ontology's namespace, the distinct terms seen in that
        namespace (minus the ontology's root term from ``FUNC_DICT``), and the
        distinct InterPro entries seen across *all* rows.
    """
    namespace = NAMESPACES[ont]
    term_counts = Counter()
    ipr_counts = Counter()
    index = []
    for i, row in enumerate(df.itertuples()):
        annotated = False
        for term in row.prop_annotations:
            if go.get_namespace(term) == namespace:
                term_counts[term] += 1
                annotated = True
        for ipr in row.interpros:
            ipr_counts[ipr] += 1
        if annotated:
            index.append(i)
    # Remove the top term; Counter's `del` does not raise if it is absent.
    del term_counts[FUNC_DICT[ont]]
    return index, list(term_counts), list(ipr_counts)


def _load_similar_pairs(sim_file, proteins, min_score=0.5):
    """Read tab-separated similarity hits and keep the relevant pairs.

    Each non-blank line is ``p1<TAB>p2<TAB>score...``; the score (third
    column) is divided by 100 (presumably Diamond percent identity — confirm).
    Self-hits, pairs scoring below ``min_score`` and pairs where either
    protein is not in ``proteins`` are dropped. Passing ``min_score=0.0``
    keeps every listed pair, which gives a stricter ("hard") split.

    Returns:
        List of ``(p1, p2)`` tuples in file order.
    """
    pairs = []
    with open(sim_file) as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank (e.g. trailing) lines
            it = line.strip().split('\t')
            p1, p2, score = it[0], it[1], float(it[2]) / 100.0
            if p1 == p2 or score < min_score:
                continue
            if p1 not in proteins or p2 not in proteins:
                continue
            pairs.append((p1, p2))
    return pairs


def _group_similar_proteins(proteins, pairs):
    """Partition ``proteins`` into connected components of the pair graph.

    Only pairs with both ends in ``proteins`` form edges. Components are found
    with an iterative depth-first search started from each not-yet-visited
    protein in input order, so the group order and member order are
    deterministic for a given input.

    Returns:
        List of groups (lists of protein ids); every protein appears once.
    """
    members = set(proteins)
    adjacency = {}
    for p1, p2 in pairs:
        if p1 in members and p2 in members:
            adjacency.setdefault(p1, []).append(p2)
            adjacency.setdefault(p2, []).append(p1)

    visited = set()
    groups = []
    for start in proteins:
        if start in visited:
            continue
        visited.add(start)
        stack = [start]
        group = []
        while stack:
            prot = stack.pop()
            group.append(prot)
            for neighbour in adjacency.get(prot, ()):
                if neighbour not in visited:
                    visited.add(neighbour)
                    stack.append(neighbour)
        groups.append(group)
    return groups


def _split_groups(groups, prot_idx, seed=0, frac=0.9):
    """Shuffle groups and turn them into train/valid/test row positions.

    After shuffling with a fixed seed, the first ``frac`` of the groups form
    the train+valid pool and the rest the test set; the pool is split again
    with the same fraction into train and valid. With the defaults this is
    roughly 81% / 9% / 10% of the *groups* (not of the proteins).

    Returns:
        Three integer numpy arrays of row positions, looked up in ``prot_idx``.
    """
    order = np.arange(len(groups))
    # A private RandomState seeded with `seed` produces the same permutation
    # as np.random.seed(seed) + np.random.shuffle, without touching the
    # global NumPy RNG.
    np.random.RandomState(seed).shuffle(order)
    train_n = int(len(groups) * frac)
    valid_n = int(train_n * frac)

    def rows(selected):
        # dtype=int keeps empty splits usable as .iloc indexers.
        return np.array(
            [prot_idx[prot] for g in selected for prot in groups[g]], dtype=int)

    return (rows(order[:valid_n]),
            rows(order[valid_n:train_n]),
            rows(order[train_n:]))


@ck.command()
@ck.option(
    '--go-file', '-gf', default='data/go.obo',
    help='Gene Ontology file in OBO Format')
@ck.option(
    '--data-file', '-df', default='data/swissprot_exp.pkl',
    help='Uniprot KB, generated with uni2pandas.py')
@ck.option(
    '--sim-file', '-sf', default='data/swissprot_exp.sim',
    help='Sequence similarity generated with Diamond')
def main(go_file, data_file, sim_file):
    """Build per-ontology term lists and similarity-aware data splits.

    For each ontology key ('mf', 'bp', 'cc') this keeps the proteins with at
    least one annotation in that namespace, writes the term list to
    data/{ont}/terms.pkl, groups proteins connected by similarity hits
    (score/100 >= 0.5) so that similar proteins never straddle splits, and
    writes data/{ont}/{train,valid,test}_data_hard.pkl.
    """
    go = Ontology(go_file, with_rels=True)
    logging.info('GO loaded')
    df = pd.read_pickle(data_file)
    print("DATA FILE", len(df))

    # The similarity file does not depend on the ontology, so parse it once
    # (restricted to proteins present in df) and filter per ontology below.
    logging.info('Loading sequence similarities')
    similar_pairs = _load_similar_pairs(sim_file, set(df['proteins'].values))

    logging.info('Processing annotations')
    for ont in ['mf', 'bp', 'cc']:
        os.makedirs(f'data/{ont}', exist_ok=True)
        index, terms, interpros = _collect_ontology_terms(df, go, ont)
        tdf = df.iloc[index]
        print(f'Number of {ont} terms {len(terms)}')
        print(f'Number of {ont} iprs {len(interpros)}')
        print(f'Number of {ont} proteins {len(tdf)}')
        pd.DataFrame({'gos': terms}).to_pickle(f'data/{ont}/terms.pkl')

        # Split train/valid/test by similarity groups.
        proteins = tdf['proteins']
        prot_idx = {prot: i for i, prot in enumerate(proteins)}
        if len(prot_idx) != len(proteins):
            # Duplicate ids map to their last row; earlier rows are dropped.
            logging.warning(
                '%s: %d duplicate protein ids; only the last row of each is kept',
                ont, len(proteins) - len(prot_idx))
        groups = _group_similar_proteins(proteins, similar_pairs)
        print(len(proteins), len(groups), sum(len(g) for g in groups))

        train_index, valid_index, test_index = _split_groups(groups, prot_idx)
        train_df = tdf.iloc[train_index]
        train_df.to_pickle(f'data/{ont}/train_data_hard.pkl')
        valid_df = tdf.iloc[valid_index]
        valid_df.to_pickle(f'data/{ont}/valid_data_hard.pkl')
        test_df = tdf.iloc[test_index]
        test_df.to_pickle(f'data/{ont}/test_data_hard.pkl')
        print(f'Train/Valid/Test proteins for {ont} {len(train_df)}/{len(valid_df)}/{len(test_df)}')
if __name__ == '__main__':
    # Run the click command; click parses sys.argv itself.
    main()