forked from ivy-llc/ivy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_dependencies.py
116 lines (109 loc) · 4.54 KB
/
test_dependencies.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Assert All Dependencies are Importable and Correctly Versioned #
# ---------------------------------------------------------------#
import argparse
import importlib
import os
import re

import termcolor
# Shared report state: test_imports() updates these, main() prints / raises them.
# Flags recording whether any error or warning has been seen so far.
ERROR = WARN = False
# Accumulated text for the error report, the warning report and the full log.
# Strings are immutable and every update rebinds with `+=`, so sharing the
# initial '\n' object between the three names is safe.
ERROR_MSG = WARN_MSG = PRINT_MSG = '\n'
def parse(str_in):
    """Extract the module name and pinned version from one requirements line.

    The module name comes from an explicit ``mod_name=<name>`` marker when the
    line contains one. Otherwise it is the leading project name, which ends
    at the first whitespace, version operator character (``<``, ``>``, ``=``,
    ``!``, ``~``), extras bracket ``[``, ``;`` or ``,``. The version is the
    text after the last ``==``, cut at the next whitespace, ``,``, ``;`` or
    ``#``.

    :param str_in: a single line of a requirements-style text file.
    :return: ``(mod_name, version)``. ``version`` is None when the line has no
        ``==`` pin. ``mod_name`` is ``''`` for blank or comment-only lines, so
        callers can skip them.
    """
    line = str_in.strip()
    # Blank lines and full-line comments carry no module to import.
    if not line or line.startswith('#'):
        return '', None
    if 'mod_name=' in line:
        mod_name = re.split(r'[\s,]', line.split('mod_name=')[-1], maxsplit=1)[0]
    else:
        # Stop at any specifier/extras/marker delimiter so that e.g.
        # 'numpy>=1.19' -> 'numpy' and 'jax[cpu]==0.2' -> 'jax'.
        mod_name = re.split(r'[\s<>=!~;,\[]', line, maxsplit=1)[0]
    if '==' in line:
        version = re.split(r'[\s,;#]', line.split('==')[-1], maxsplit=1)[0]
    else:
        version = None
    return mod_name, version
def _detect_version(mod):
    """Return the version string advertised by ``mod``, or None if unavailable.

    ``mod.__version__`` is preferred, falling back to ``mod.VERSION``. String
    values are returned unchanged. Tuples and lists such as ``(1, 2, 3)`` are
    joined with dots, and any other value is converted with ``str``.
    """
    for attr in ('__version__', 'VERSION'):
        # noinspection PyBroadException
        try:
            value = getattr(mod, attr)
        except AttributeError:
            continue
        except Exception:
            # Attribute access can run arbitrary module code (e.g. a
            # module-level __getattr__); treat any failure as undetectable.
            return None
        if value is None:
            return None
        if isinstance(value, str):
            return value
        if isinstance(value, (tuple, list)):
            return '.'.join(str(part) for part in value)
        return str(value)
    return None


def _replace_pinned_version(line, expected, detected):
    """Return ``line`` with its ``==expected`` pin changed to ``==detected``.

    Only the last ``==expected`` occurrence is rewritten, which is the one
    ``parse`` reads the version from. Identical text elsewhere on the line is
    left alone. If no ``==`` pin is found, fall back to a plain substring
    replacement.
    """
    head, sep, tail = line.rpartition('==' + expected)
    if sep:
        return head + '==' + detected + tail
    return line.replace(expected, detected)


def test_imports(fname, assert_version, update_versions):
    """Import every module listed in ``fname`` and compare its version.

    Each line is parsed with ``parse``. Lines yielding an empty module name
    (e.g. blank lines) are skipped. Findings are appended to the module-level
    ``PRINT_MSG`` / ``ERROR_MSG`` / ``WARN_MSG`` buffers, and the ``ERROR`` /
    ``WARN`` flags are set as problems are found. Nothing is returned.

    :param fname: path of the requirements-style text file to check.
    :param assert_version: if True, a detected version that differs from the
        pinned one is recorded as an error; otherwise as a warning.
    :param update_versions: if True, rewrite ``fname`` in place, replacing
        each mismatched pinned version with the detected one.
    """
    global ERROR
    global ERROR_MSG
    global WARN
    global WARN_MSG
    global PRINT_MSG
    versions_to_update = dict()
    msg = '\nasserting imports work for: {}\n\n'.format(fname)
    PRINT_MSG += msg
    ERROR_MSG += msg
    WARN_MSG += msg
    with open(fname, 'r') as f:
        file_lines = f.readlines()
    mod_names_n_versions = [parse(req) for req in file_lines]
    for line_num, (mod_name, expected_version) in enumerate(mod_names_n_versions):
        if not mod_name:
            # Nothing to import; importlib.import_module('') would raise
            # ValueError and report a bogus import failure.
            continue
        # Importing third-party code can fail in arbitrary ways, so catch broadly.
        # noinspection PyBroadException
        try:
            mod = importlib.import_module(mod_name)
        except Exception as e:
            ERROR = True
            msg = '{} could not be imported: {}\n'.format(mod_name, e)
            ERROR_MSG += msg
            PRINT_MSG += msg
            continue
        # None (rather than skipping the module) lets the
        # "unable to detect version" warnings below actually fire.
        detected_version = _detect_version(mod)
        if detected_version and expected_version:
            if detected_version == expected_version:
                msg = '{} detected correct version: {}\n'.format(mod_name, detected_version)
            else:
                msg = 'expected version {} for module {}, but detected version {}\n'.format(
                    expected_version, mod_name, detected_version)
                versions_to_update[line_num] = {'expected': expected_version, 'detected': detected_version}
                if assert_version:
                    ERROR = True
                    ERROR_MSG += msg
                else:
                    WARN = True
                    WARN_MSG += msg
            PRINT_MSG += msg
        else:
            if detected_version:
                msg = '{} detected version: {}, but no expected version provided\n'.format(mod_name, detected_version)
            elif expected_version:
                msg = '{} expected version: {}, but unable to detect version\n'.format(mod_name, expected_version)
            else:
                msg = 'no expected version provided, and unable to detect version for {}\n'.format(mod_name)
            WARN = True
            PRINT_MSG += msg
            WARN_MSG += msg
    # Skip the rewrite entirely when there is nothing to change.
    if not update_versions or not versions_to_update:
        return
    for line_num, versions in versions_to_update.items():
        file_lines[line_num] = _replace_pinned_version(
            file_lines[line_num], versions['expected'], versions['detected'])
    with open(fname, 'w') as f:
        f.writelines(file_lines)
def main(filepaths, assert_matching_versions, update_versions):
    """Check every file listed in ``filepaths`` and report the results.

    :param filepaths: comma separated file paths. Spaces are ignored, and
        empty entries (e.g. from a trailing comma) are skipped.
    :param assert_matching_versions: if True, version mismatches are treated
        as errors rather than warnings.
    :param update_versions: if True, mismatched pinned versions are rewritten
        in the files to the detected ones.
    :raises AssertionError: if a listed path is not an existing file. All
        paths are validated before any file is checked or rewritten.
    :raises Exception: if any error was recorded; the message is the
        accumulated error report.
    """
    paths = [path for path in filepaths.replace(' ', '').split(',') if path]
    # Validate everything first so a typo in a later path cannot abort the run
    # after earlier files were already rewritten. (Under `python -O` the assert
    # is stripped; test_imports' open() would then raise FileNotFoundError.)
    for filepath in paths:
        assert os.path.isfile(filepath), 'file to check not found: {}'.format(filepath)
    for filepath in paths:
        test_imports(filepath, assert_version=assert_matching_versions, update_versions=update_versions)
    print(PRINT_MSG)
    if WARN:
        print(termcolor.colored('WARNING\n' + WARN_MSG, 'red'))
    if ERROR:
        raise Exception(ERROR_MSG)
if __name__ == '__main__':
    # Command-line entry point: parse the arguments and delegate to main().
    parser = argparse.ArgumentParser(
        description='Assert that all dependencies are importable and correctly versioned.')
    parser.add_argument('-fp', '--filepaths', type=str, required=True,
                        help='Comma separated filepaths of all text files to check. Spaces are ignored.')
    # Adjacent string literals are concatenated, so the separating space must
    # be written explicitly.
    parser.add_argument('-amv', '--assert_matching_versions', action='store_true',
                        help='Whether to assert that all module versions match those listed in the '
                             'requirements.txt and optional.txt files.')
    parser.add_argument('-uv', '--update_versions', action='store_true',
                        help='Whether to update the versions in the installation files.')
    parsed_args = parser.parse_args()
    main(parsed_args.filepaths, parsed_args.assert_matching_versions, parsed_args.update_versions)