-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgodag_plot.py
339 lines (310 loc) · 13.2 KB
/
godag_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
"""Plot a GODagSmall."""
__copyright__ = "Copyright (C) 2016-2018, DV Klopfenstein, H Tang, All rights reserved."
__author__ = "DV Klopfenstein"
import sys
import os
import collections as cx
from collections import OrderedDict
from goatools.godag_obosm import OboToGoDagSmall
def plot_gos(fout_png, goids, obo_dag, *args, **kws):
    """Given GO ids and the obo_dag, create a plot of paths from GO ids.

    Args:
        fout_png: Output image path.
        goids: GO ids whose paths (up to the top of the DAG) are plotted.
        obo_dag: Full GO DAG. It is used to build the small DAG and is also
            handed to the plotter as ``obo`` so edges can be colored by
            relationship type.
        *args, **kws: Forwarded to GODagSmallPlot. The optional keyword
            ``engine`` (default 'pydot') selects the plotting engine.
    """
    engine = kws.get('engine', 'pydot')
    godagsmall = OboToGoDagSmall(goids=goids, obodag=obo_dag).godag
    # Keyword args are placed after *args (flake8-bugbear B026); the call
    # binds arguments exactly as before.
    godagplot = GODagSmallPlot(godagsmall, *args, obo=obo_dag, **kws)
    godagplot.plt(fout_png, engine)
def plot_goid2goobj(fout_png, goid2goobj, *args, **kws):
    """Plot the paths of the GO terms held in a {GO id: GO object} dict.

    Extra positional and keyword arguments are passed to GODagSmallPlot.
    The optional keyword ``engine`` (default 'pydot') picks the renderer.
    """
    dag_small = OboToGoDagSmall(goid2goobj=goid2goobj).godag
    plotter = GODagSmallPlot(dag_small, *args, **kws)
    plotter.plt(fout_png, kws.get('engine', 'pydot'))
def plot_results(fout_png, goea_results, *args, **kws):
    """Plot GOEA result GO terms and their paths up to the top of the DAG.

    If ``fout_png`` contains the placeholder ``{NS}``, the results are split
    by namespace (rec.NS, e.g. BP/MF/CC). One image is written per namespace,
    with the placeholder filled in. Otherwise a single image is written.
    """
    if "{NS}" in fout_png:
        # Group the records by namespace, then plot each group on its own page
        by_namespace = cx.defaultdict(list)
        for record in goea_results:
            by_namespace[record.NS].append(record)
        for namespace, records in by_namespace.items():
            plt_goea_results(fout_png.format(NS=namespace), records, *args, **kws)
        return
    plt_goea_results(fout_png, goea_results, *args, **kws)
def plt_goea_results(fout_png, goea_results, *args, **kws):
    """Write a single image, ``fout_png``, for the given GOEA results.

    The results define which GO terms are drawn. They are also passed to
    GODagSmallPlot (as ``goea_results``) for coloring and labels. The
    optional keyword ``engine`` (default 'pydot') picks the renderer.
    """
    dag_small = OboToGoDagSmall(goea_results=goea_results).godag
    plotter = GODagSmallPlot(dag_small, *args, goea_results=goea_results, **kws)
    plotter.plt(fout_png, kws.get('engine', 'pydot'))
class GODagPltVars(object):
    """Holds plotting parameters: colors, legend labels and text formats.

    All values are class attributes, shared by every instance unless an
    instance attribute shadows them.
    """

    # http://www.graphviz.org/doc/info/colors.html
    # Edge color for each GO relationship type
    rel2col = {
        'is_a': 'black',
        'part_of': 'blue',
        'regulates': 'gold',
        'positively_regulates': 'green',
        'negatively_regulates': 'red',
        'occurs_in': 'aquamarine4',
        'capable_of': 'dodgerblue',
        'capable_of_part_of': 'darkorange',
    }

    # GO-term fill color keyed by p-value upper bound (alpha), in ascending order
    alpha2col = OrderedDict([
        # GOEA GO terms that are significant
        (0.005, 'mistyrose'),
        (0.010, 'moccasin'),
        (0.050, 'lemonchiffon1'),
        # GOEA GO terms that are not significant
        (1.000, 'grey95'),
    ])

    # Fill colors for level-1 GO terms and for the user's source GO terms
    key2col = {
        'level_01': 'lightcyan',
        'go_sources': 'palegreen',
    }

    # Legend text for every color above: {color: label}
    total_color_dict = {v: k for k, v in rel2col.items()}
    for k, v in alpha2col.items():
        if k == 0.005:
            k = 'pvalue=(0, 0.005]'
        elif k == 0.010:
            k = 'pvalue=(0.005, 0.01]'
        elif k == 0.050:
            k = 'pvalue=(0.01, 0.05]'
        else:
            k = 'pvalue=(0.05, 1]'
        total_color_dict[v] = k
    for k, v in key2col.items():
        if k == "level_01":
            k = 'level=L01'
        total_color_dict[v] = k
    # Loop variables in a class body would otherwise remain as class
    # attributes (GODagPltVars.k / GODagPltVars.v); drop them.
    del k, v

    # Header line of a GO-term box, e.g. "GO0036464 L04 D06"
    fmthdr = "{GO} L{level:>02} D{depth:>02}"
    # Study line of a GO-term box, e.g. "24 genes"
    fmtres = "{study_count} genes"
    # study items per line on GO Terms:
    items_p_line = 5
class GODagSmallPlot(object):
    """Plot a graph contained in an object of type GODagSmall."""

    def __init__(self, godagsmall, *args, **kws):
        """Store the small DAG and the plotting options found in ``kws``.

        Recognized keywords: log, title, goea_results, go2nt, pval_name,
        id2symbol, study_items, alpha_str, GODagPltVars, items_p_line, dpi, obo.
        """
        self.args = args
        self.log = kws.get('log', sys.stdout)
        self.title = kws.get('title', None)
        # GOATOOLs results as objects: {GO id: result}, or None
        self.go2res = self._init_go2res(**kws)
        # Name of the result attribute that holds the p-value used for coloring
        self.pval_name = self._init_pval_name(**kws)
        # Gene Symbol names
        self.id2symbol = kws.get('id2symbol', {})
        self.study_items = kws.get('study_items', None)
        self.study_items_max = self._init_study_items_max()
        self.alpha_str = kws.get('alpha_str', None)
        self.pltvars = kws['GODagPltVars'] if 'GODagPltVars' in kws else GODagPltVars()
        if 'items_p_line' in kws:
            self.pltvars.items_p_line = kws['items_p_line']
        self.dpi = kws.get('dpi', 150)
        self.godag = godagsmall
        self.goid2color = self._init_goid2color()
        self.pydot = None
        # Optional full GO DAG; only used to look up non-is_a relationships
        self.obo = kws.get('obo', {})
        # Legend nodes {label: pydot Node}; rebuilt by _get_go2pydotnode
        self.legend_nodes = {}

    def _init_study_items_max(self):
        """User can limit the number of genes printed in a GO term.

        Returns ``study_items`` when it is an int other than True; otherwise
        returns None (no limit).
        """
        if self.study_items is None:
            return None
        if self.study_items is True:
            return None
        if isinstance(self.study_items, int):
            return self.study_items
        return None

    @staticmethod
    def _init_go2res(**kws):
        """Initialize GOEA results as {GO id: result}, or None if none were given."""
        if 'goea_results' in kws:
            return {res.GO: res for res in kws['goea_results']}
        if 'go2nt' in kws:
            return kws['go2nt']
        return None

    @staticmethod
    def _init_pval_name(**kws):
        """Initialize the p-value attribute name.

        An explicit ``pval_name`` wins. Otherwise the name is built as
        ``p_<fieldname>`` from the first entry of the first GOEA result's
        ``method_flds``. Returns None if neither is available.
        """
        if 'pval_name' in kws:
            return kws['pval_name']
        goea = kws.get('goea_results')
        if goea:
            return "p_{M}".format(M=goea[0].method_flds[0].fieldname)
        return None

    def _init_goid2color(self):
        """Set colors of GO terms: {GO id: fill color}.

        Priority, first match wins:
          1. p-value color from ``pltvars.alpha2col``;
          2. GO-source color;
          3. level-01 color.
        Terms without a color here are drawn white later.
        """
        goid2color = {}
        pval_name = self.pval_name
        # 1. Colors based on p-value override colors based on source GO.
        #    Without an attribute name there is no p-value to read
        #    (getattr(res, None) would raise TypeError).
        if self.go2res is not None and pval_name is not None:
            alpha2col = self.pltvars.alpha2col
            for goid, res in self.go2res.items():
                pval = getattr(res, pval_name, None)
                if pval is None:
                    continue
                # alpha2col is scanned in order and the first alpha with
                # pval <= alpha wins. Terms with study_count 0 never get a
                # p-value color.
                for alpha, color in alpha2col.items():
                    if pval <= alpha and res.study_count != 0:
                        goid2color[goid] = color
                        break
        # 2. GO source color
        color = self.pltvars.key2col['go_sources']
        for goid in self.godag.go_sources:
            goid2color.setdefault(goid, color)
        # 3. Level-01 GO color
        color = self.pltvars.key2col['level_01']
        for goid, goobj in self.godag.go2obj.items():
            if goobj.level == 1:
                goid2color.setdefault(goid, color)
        return goid2color

    def plt(self, fout_img, engine="pydot"):
        """Write the plot to ``fout_img``.

        Only engine='pydot' is implemented. 'pygraphviz' and any other value
        raise Exception.
        """
        if engine == "pydot":
            self._plt_pydot(fout_img)
        elif engine == "pygraphviz":
            raise Exception("TO BE IMPLEMENTED SOON: ENGINE pygraphvis")
        else:
            raise Exception("UNKNOWN ENGINE({E})".format(E=engine))

    # ----------------------------------------------------------------------------------
    # pydot
    def _plt_pydot(self, fout_img):
        """Plot using the pydot graphics engine.

        The image format is taken from the extension of ``fout_img``. A
        one-line summary is written to ``self.log``.
        """
        dag = self._get_pydot_graph()
        img_fmt = os.path.splitext(fout_img)[1][1:]
        dag.write(fout_img, format=img_fmt)
        self.log.write(" {GO_USR:>3} usr {GO_ALL:>3} GOs WROTE: {F}\n".format(
            F=fout_img,
            GO_USR=len(self.godag.go_sources),
            GO_ALL=len(self.godag.go2obj)))

    def _get_pydot_graph(self):
        """Given a DAG, return a pydot digraph object with a color legend."""
        pydot = self._get_pydot()
        # Initialize empty dag; labelloc='t' puts the title at the top
        dag = pydot.Dot(label=self.title, graph_type='digraph', dpi="{}".format(self.dpi),
                        compound='true', labelloc='t')
        # Initialize nodes and add them to the graph
        go2node = self._get_go2pydotnode()
        for node in go2node.values():
            dag.add_node(node)
        # Add edges to graph, colored by relationship type
        rel2col = self.pltvars.rel2col
        default_col = rel2col.get('is_a', 'black')
        for src, tgt in self.godag.get_edges():
            # Looked up per edge, so one edge's relationship never leaks
            # into the next edge's color
            rel = self._get_edge_relationship(src, tgt)
            dag.add_edge(pydot.Edge(
                go2node[tgt], go2node[src],
                shape="normal",
                color=rel2col.get(rel, default_col),
                dir="back"))  # invert arrow direction for obo dag convention
        # Legend: one node per fill color used in the plot
        subgraph = pydot.Cluster('Legend', label="Legend")
        dag.add_subgraph(subgraph)
        for node in self.legend_nodes.values():
            subgraph.add_node(node)
        return dag

    def _get_edge_relationship(self, src, tgt):
        """Return the relationship name for edge (src, tgt); 'is_a' by default.

        If the optional full DAG ``self.obo`` lists ``tgt`` under one of the
        ``relationship`` entries of ``src``, that relationship name is
        returned. The small DAG's edges appear to come from is_a parents
        only, so this may rarely match.
        """
        if src in self.obo and tgt in self.obo:
            # Terms loaded without relationships have no such attribute
            relationships = getattr(self.obo[src], 'relationship', None) or {}
            for rel_type, parents in relationships.items():
                if tgt in {parent.id for parent in parents}:
                    return rel_type
        return "is_a"

    def _get_go2pydotnode(self):
        """Create pydot Nodes: {GO id: Node}. Also rebuild ``self.legend_nodes``."""
        pydot = self._get_pydot()
        used_colors = []  # fill colors in first-use order, for the legend
        go2node = {}
        for goid, goobj in self.godag.go2obj.items():
            fillcolor = self.goid2color.get(goid, "white")
            if fillcolor not in used_colors:
                used_colors.append(fillcolor)
            go2node[goid] = pydot.Node(
                self._get_node_text(goid, goobj),
                shape="box",
                style="rounded, filled",
                fillcolor=fillcolor,
                color="mediumseagreen")
        # Legend nodes for every non-default color actually used. An unknown
        # (custom) color is labeled with its own name.
        color2label = self.pltvars.total_color_dict
        self.legend_nodes = {}
        for color in used_colors:
            if color == 'white':
                continue
            label = color2label.get(color, color)
            self.legend_nodes[label] = pydot.Node(
                label,
                shape="note",
                style="filled",
                fillcolor=color,
                color="mediumseagreen",
            )
        return go2node

    def _get_pydot(self):
        """Return pydot package. Load pydot, if necessary."""
        if self.pydot:
            return self.pydot
        self.pydot = __import__("pydot")
        return self.pydot

    # ----------------------------------------------------------------------------------
    # Methods for text printed inside GO terms
    def _get_node_text(self, goid, goobj):
        """Return a string to be printed in a GO term box."""
        txt = []
        # Header line: "GO0036464 L04 D06"
        txt.append(self.pltvars.fmthdr.format(
            GO=goobj.id.replace("GO:", "GO"),
            level=goobj.level,
            depth=goobj.depth))
        # GO name line: "cytoplamic ribonucleoprotein"; commas start new lines
        name = goobj.name.replace(",", "\n")
        txt.append(name)
        # study info line: "24 genes"
        study_txt = self._get_study_txt(goid)
        if study_txt is not None:
            txt.append(study_txt)
        return "\n".join(txt)

    def _get_study_txt(self, goid):
        """Get GO text from GOEA study, or None if the GO term has no result.

        If ``study_items`` is set, the study items themselves are listed.
        Otherwise only their count is shown.
        """
        if self.go2res is not None:
            res = self.go2res.get(goid, None)
            if res is not None:
                if self.study_items is not None:
                    return self._get_item_str(res)
                return self.pltvars.fmtres.format(
                    study_count=res.study_count)
        return None

    def _get_item_str(self, res):
        """Return the sorted study items (gene ids or symbols) of ``res``.

        Items are joined with ", ", with ``pltvars.items_p_line`` items per line:
          * no limit:            "7) Bcl2, Cd81, ..."
          * count <= limit:      "Bcl2, Cd81, ..."
          * count > limit:       "7 genes; Bcl2, Cd81..." (first ``limit`` items)
        """
        npl = self.pltvars.items_p_line  # Number of items Per Line
        prt_items = sorted(self.__get_genestr(itemid) for itemid in res.study_items)
        num_items = len(prt_items)
        max_items = self.study_items_max
        if max_items is None:
            return "{N}) {GENES}".format(N=num_items, GENES=self._join_items(prt_items, npl))
        if num_items <= max_items:
            return self._join_items(prt_items, npl)
        short_str = self._join_items(prt_items[:max_items], npl)
        return "".join(["{N} genes; ".format(N=num_items), short_str, "..."])

    @staticmethod
    def _join_items(items, npl):
        """Join items with ", " and start a new line after every ``npl`` items."""
        return "\n".join(", ".join(str(e) for e in items[i:i + npl])
                         for i in range(0, len(items), npl))

    def __get_genestr(self, itemid):
        """Given a geneid, return the string geneid or a gene symbol."""
        if self.id2symbol is not None:
            symbol = self.id2symbol.get(itemid, None)
            if symbol is not None:
                return symbol
        if isinstance(itemid, int):
            return str(itemid)
        return itemid
# Copyright (C) 2016-2018, DV Klopfenstein, H Tang, All rights reserved.