-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathprint_tikz.py
192 lines (173 loc) · 8.79 KB
/
print_tikz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""
Visualize dependency trees in TikZ.
Takes a scores csv file like that generated by main.py,
and makes dependency diagrams for selected sentences.
"""
import glob
import os.path
import pandas as pd
import numpy as np
from argparse import ArgumentParser
from ast import literal_eval
from conll_data import CONLLReader, EXCLUDED_PUNCTUATION, EXCLUDED_PUNCTUATION_UPOS, CONLL_COLS
def is_edge_to_ignore(edge, observation):
    """Tell whether a gold edge should be left out of the comparison.

    `edge` is indexed as (head, dependent, ...), using 1-based word
    positions.  The edge is ignored when the dependent's UPOS tag is in
    EXCLUDED_PUNCTUATION_UPOS, or when the head index is 0.
    """
    head, dependent = edge[0], edge[1]
    # Word positions are 1-based while the observation columns are 0-based.
    # (Testing observation.FORM against EXCLUDED_PUNCTUATION is the
    # alternative criterion; the UPOS-based one is the one in use.)
    dependent_upos = observation.UPOS[dependent - 1]
    dependent_is_punct = dependent_upos in EXCLUDED_PUNCTUATION_UPOS
    head_is_root = head == 0
    return bool(dependent_is_punct or head_is_root)
def make_string_safe(input_string, replace_dict):
    """Return input_string with every key of replace_dict substituted.

    Substitutions are applied one after another, in the dict's iteration
    order, so text produced by an earlier substitution is visible to the
    later ones.
    """
    safe_string = input_string
    for unsafe, replacement in replace_dict.items():
        safe_string = safe_string.replace(unsafe, replacement)
    return safe_string
def make_tikz_string(
        predicted_edges, observation,
        label1='', label2='', label3=''):
    """Build the tikz-dependency source comparing predicted and gold edges.

    Gold edges are drawn above the sentence with their DEPREL labels.
    Predicted edges are drawn below it, unlabelled: blue when they match a
    gold edge (ignoring direction), red otherwise.  A small block of text
    nodes to the right of the sentence shows the three labels and the
    fraction of gold edges recovered.

    Args:
        predicted_edges: iterable of 0-based word-index pairs.
        observation: object exposing HEAD, ID, DEPREL and FORM sequences
            (and whatever is_edge_to_ignore reads from it).
        label1, label2, label3: strings printed next to the diagram, from
            top (label1) to bottom (label3).

    Returns:
        The TeX source of one `dependency` environment, as a string.
    """
    gold_edges_list = list(zip(map(int, observation.HEAD),
                               map(int, observation.ID),
                               observation.DEPREL))
    gold_edge_to_label = {(e[0], e[1]): e[2] for e in gold_edges_list
                          if not is_edge_to_ignore(e, observation)}
    # Edges are compared undirected: each pair is stored in sorted order.
    gold_edges_set = {tuple(sorted(e)) for e in gold_edge_to_label}
    # note converting to 1-indexing
    predicted_edges_set = {
        tuple(sorted((x[0] + 1, x[1] + 1))) for x in predicted_edges}
    # Sorted so the emitted file does not depend on set iteration order.
    correct_edges = sorted(gold_edges_set & predicted_edges_set)
    incorrect_edges = sorted(predicted_edges_set - gold_edges_set)

    num_correct = len(correct_edges)
    num_total = len(gold_edges_set)
    # np.nan rather than the np.NaN alias, which NumPy 2.0 removed.
    uuas = num_correct / float(num_total) if num_total != 0 else np.nan

    # replace non-TeXsafe characters... add as needed
    # The trailing {} ends the \textasciitilde control word, so a letter
    # following '~' cannot fuse into it.
    tex_replace = {'$': '\\$', '&': '$\\with$', '%': '\\%',
                   '~': '\\textasciitilde{}', '#': '\\#', '|': '{|}'}

    words = "\\& ".join(make_string_safe(w, tex_replace)
                        for w in observation.FORM)
    parts = [
        "\\begin{dependency}\n\t\\begin{deptext}\n\t\t",
        words + " \\\\" + '\n',
        "\t\\end{deptext}" + '\n',
    ]
    for (i_index, j_index), deprel in gold_edge_to_label.items():
        parts.append(
            f'\t\\depedge{{{i_index}}}{{{j_index}}}{{{deprel}}}\n')
    for i_index, j_index in correct_edges:
        parts.append(
            '\t\\depedge[hide label, edge below, '
            'edge style={-, blue, opacity=0.5}]'
            f'{{{i_index}}}{{{j_index}}}{{}}\n')
    for i_index, j_index in incorrect_edges:
        parts.append(
            '\t\\depedge[hide label, edge below, '
            'edge style={-, red, opacity=0.5}]'
            f'{{{i_index}}}{{{j_index}}}{{}}\n')
    # Plain (non-f) string: the doubled braces are emitted verbatim.
    parts.append("\t\\node (R) at (\\matrixref.east) {{}};\n")
    parts.append(
        f"\t\\node (R1) [right of = R] {{\\tiny\\textsf{{{label3}}}}};\n")
    parts.append(
        f"\t\\node (R2) at (R1.north) {{\\tiny\\textsf{{{label2}}}}};\n")
    parts.append(
        f"\t\\node (R3) at (R2.north) {{\\tiny\\textsf{{{label1}}}}};\n")
    parts.append("\t\\node (R4) at (R1.south) {\\tiny ")
    # A NaN score (no gold edges) is printed as "nan".
    parts.append(
        f"$ {num_correct}/{num_total} = {uuas*100:.0f}\\% $}};\n")
    parts.append("\\end{dependency}\n")
    return "".join(parts)
def write_tikz_files(
        outputdir, edges_df, sentence_indices,
        edge_type, output_suffix='', info_text='', index_info_text='',
        observations=None):
    """Write one .tikz file per requested sentence into outputdir.

    Args:
        outputdir: directory the .tikz files are written to.
        edges_df: table whose cell at [sentence_index, edge_type] holds the
            predicted edges as a Python-literal string.
        sentence_indices: iterable of sentence indices to render.
        edge_type: column of edges_df to read the predicted edges from.
        output_suffix: text used as the middle diagram label and as the
            last component of the file name.
        info_text: text used as the bottom diagram label and as the first
            component of the file name.
        index_info_text: text placed before the sentence index, both in
            the top diagram label and in the file name.
        observations: indexable collection of observations; when None, the
            module-level OBSERVATIONS set up by the script entry point is
            used.
    """
    if observations is None:
        # Fall back to the dataset loaded by the script entry point.
        observations = OBSERVATIONS
    for sentence_index in sentence_indices:
        # The cell stores the edge list as text; literal_eval parses it
        # without executing arbitrary code.
        predicted_edges = literal_eval(edges_df.at[sentence_index, edge_type])
        tikz_string = make_tikz_string(
            predicted_edges,
            observations[sentence_index],
            label1=index_info_text + ' ' + str(sentence_index),
            label2=output_suffix,
            label3=info_text)
        # Built from this function's own arguments rather than the CLI
        # globals, so the function behaves the same for any caller.
        tikz_name = (info_text + '_' + index_info_text
                     + str(sentence_index) + '_' + output_suffix + ".tikz")
        tikz_path = os.path.join(outputdir, tikz_name)
        print(f'writing tikz to {tikz_path}')
        # Explicit UTF-8 so non-ASCII tokens do not depend on the locale.
        with open(tikz_path, 'w', encoding='utf-8') as fout:
            fout.write(f"% dependencies for {outputdir}\n")
            fout.write(tikz_string)
# Script entry point: parse the command line, write one .tikz file per
# (edge type, sentence) pair, then write a dependencies.tex wrapper that
# \input's every .tikz file found in the output directory.
if __name__ == '__main__':
    ALL_EDGE_TYPES = [
        'projective.edges.sum', 'projective.edges.triu',
        'projective.edges.tril', 'projective.edges.none',
        'nonproj.edges.sum', 'nonproj.edges.triu',
        'nonproj.edges.tril', 'nonproj.edges.none'
    ]

    ARGP = ArgumentParser()
    # Required: without it the indices would be None and the run would
    # only fail later, when they are iterated.
    ARGP.add_argument('--sentence_indices', type=int, nargs='+',
                      required=True,
                      help='''sentence indices to plot dependencies for.
                      enter integer(s)''')
    ARGP.add_argument('--input_file',
                      default='scores.csv',
                      help='specify path/to/scores.csv')
    ARGP.add_argument('--output_dir',
                      default='',
                      help='''path to print tikz to.
                      If none, put in same place as input''')
    ARGP.add_argument('--conllx_file',
                      default='ptb3-wsj-data/ptb3-wsj-dev.conllx',
                      help='path/to/treebank.conllx: dependency file')
    ARGP.add_argument('--info',
                      default='',
                      help='model name or other info to print under sentence index')
    ARGP.add_argument('--index_info',
                      default='',
                      help='treebank name or other info to print before sentence index')
    ARGP.add_argument('--edge_types', type=str,
                      default=['projective.edges.sum'],
                      nargs='+',
                      help="""Edge type to plot against the gold.
                      Chose any subset of
                      ['projective.edges.sum',
                      'projective.edges.triu',
                      'projective.edges.tril',
                      'projective.edges.none',
                      'nonproj.edges.sum',
                      'nonproj.edges.triu',
                      'nonproj.edges.tril',
                      'nonproj.edges.none'],
                      or enter 'all' for all.""")
    CLI_ARGS = ARGP.parse_args()

    OBSERVATIONS = CONLLReader(CONLL_COLS).load_conll_dataset(
        CLI_ARGS.conllx_file)

    if CLI_ARGS.output_dir == '':
        OUTPUTDIR = os.path.dirname(CLI_ARGS.input_file)
    else:
        OUTPUTDIR = CLI_ARGS.output_dir
    # An empty OUTPUTDIR means the current directory; otherwise make sure
    # the target exists before any file is opened inside it.
    if OUTPUTDIR:
        os.makedirs(OUTPUTDIR, exist_ok=True)

    EDGES_DF = pd.read_csv(CLI_ARGS.input_file)

    if CLI_ARGS.edge_types == ['all']:
        print(CLI_ARGS.edge_types)
        CLI_ARGS.edge_types = list(ALL_EDGE_TYPES)

    for edge_type in CLI_ARGS.edge_types:
        # 'projective.edges.sum' -> label 'sum.projective'
        edgetype = edge_type.split(".")
        label = f'{edgetype[2]}.{edgetype[0]}'
        write_tikz_files(OUTPUTDIR, EDGES_DF,
                         CLI_ARGS.sentence_indices, edge_type,
                         output_suffix=label,
                         info_text=CLI_ARGS.info,
                         index_info_text=CLI_ARGS.index_info)

    TEX_FILEPATH = os.path.join(OUTPUTDIR, 'dependencies.tex')
    # Explicit UTF-8 so the wrapper does not depend on the locale encoding.
    with open(TEX_FILEPATH, mode='w', encoding='utf-8') as tex_file:
        print(f'writing TeX to {TEX_FILEPATH}')
        # NOTE(review): "Arial Unicode MSn" below differs from the
        # "Arial Unicode MS" used for the CJK font -- possibly a typo;
        # left unchanged, confirm against the installed fonts.
        tex_file.write(
            "\\documentclass[tikz]{standalone}\n"
            "\\usepackage{tikz,tikz-dependency}\n"
            "\\usepackage{cmll,xeCJK}\n"  # for typesetting '&' and CJK resp
            "\\setmainfont{Arial Unicode MSn}\n"
            "\\setsansfont{Arial Narrow}\n"
            "\\setCJKmainfont{Arial Unicode MS}\n"
            "\\pgfkeys{%\n/depgraph/edge unit distance=.75ex,%\n"
            "%/depgraph/edge horizontal padding=2,%\n"
            "/depgraph/reserved/edge style/.style = {\n->,% arrow properties\n"
            "semithick, solid, line cap=round, % line properties\n"
            "rounded corners=2,% make corners round\n},%\n"
            # Backslash escaped explicitly: a bare "\s" is an invalid
            # escape sequence and triggers a warning on recent Pythons.
            "/depgraph/reserved/label style/.style = {font=\\sffamily,\n"
            "% anchor = mid, draw, solid, black, rotate = 0,"
            "rounded corners = 2pt,%\nscale = .5,%\ntext height = 1.5ex,"
            "text depth = 0.25ex,% needed to center text vertically\n"
            "inner sep=.2ex,%\nouter sep = 0pt,%\ntext = black,%\n"
            "fill = white,% opacity = 0, text opacity = 0 "
            "% uncomment to hide all labels\n},%\n}\n"
            "\\begin{document}\n\n% % Put tikz dependencies here, like\n"
        )
        tex_file.write(f"% dependencies for {OUTPUTDIR}\n")
        # Every .tikz file in the directory is included, not only the ones
        # written by this run.
        TIKZFILES = glob.glob(os.path.join(OUTPUTDIR, '*.tikz'))
        TIKZFILES = [os.path.basename(x) for x in TIKZFILES]
        for tikzfile in sorted(TIKZFILES):
            tex_file.write(f"\\input{{{tikzfile}}}\n")
        tex_file.write("\n\\end{document}")