-
Notifications
You must be signed in to change notification settings - Fork 38
/
Copy pathtest_roundtrip.py
89 lines (78 loc) · 4.05 KB
/
test_roundtrip.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
""" Tests pyMLIR in a parse->dump->parse round-trip. """
from mlir import parse_string
from mlir.dialects.func import func
def test_toy_roundtrip():
"""
Create MLIR code without extra whitespace and check that it can parse
and dump the same way.
"""
code = '''module {
func.func @toy_func(%arg0: tensor<2x3xf64>) -> tensor<3x2xf64> {
%0 = "toy.transpose"(%arg0) {inplace = true} : (tensor<2x3xf64>) -> tensor<3x2xf64>
return %0 : tensor<3x2xf64>
}
}'''
module = parse_string(code)
dump = module.dump()
assert dump == code
def test_function_no_args():
"""
Test round-tripping a function with no arguments.
"""
code = '''module {
func.func @toy_func() -> index {
%0 = constant 0 : index
return %0 : index
}
}'''
module = parse_string(code)
dump = module.dump()
assert dump == code
def test_affine_expr_roundtrip():
"""
Create affine maps, semi-affine maps, and integer sets, checking for
correct parsing.
"""
code = '''#map0 = (d0, d1) -> (d0, d1)
#map1 = (d0) -> (d0)
#map2 = () -> (0)
#map3 = () -> (10)
#map4 = (d0, d1, d2) -> (d0, d1 + d2 + 5)
#map5 = (d0, d1, d2) -> (d0 + d1, d2)
#map6 = (d0, d1)[s0] -> (d0, d1 + s0 + 7)
#map7 = (d0, d1)[s0] -> (d0 + s0, d1)
#map8 = (d0, d1) -> (d0 + d1 + 11)
#map9 = (d0, d1)[s0] -> (d0, (d1 + s0) mod 9 + 7)
#map10 = (d0, d1)[s0] -> ((d0 + s0) floordiv 3, d1)
#samap0 = (d0)[s0] -> (d0 floordiv (s0 + 1))
#samap1 = (d0)[s0] -> (d0 floordiv s0)
#samap2 = (d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)
#set0 = (d0) : (1 == 0)
#set1 = (d0, d1)[s0] : ()
#set2 = (d0, d1)[s0, s1] : (d0 >= 0, -d0 + s0 - 1 >= 0, d1 >= 0, -d1 + s1 - 1 >= 0)
#set3 = (d0, d1, d2) : (d0 - d2 * 4 == 0, d0 + d1 * 8 - 9 >= 0, -d0 - d1 * 8 + 11 >= 0)
#set4 = (d0, d1, d2, d3, d4, d5) : (d0 * 1089234 + d1 * 203472 + 82342 >= 0, d0 * -55 + d1 * 24 + d2 * 238 - d3 * 234 - 9743 >= 0, d0 * -5445 - d1 * 284 + d2 * 23 + d3 * 34 - 5943 >= 0, d0 * -5445 + d1 * 284 + d2 * 238 - d3 * 34 >= 0, d0 * 445 + d1 * 284 + d2 * 238 + d3 * 39 >= 0, d0 * -545 + d1 * 214 + d2 * 218 - d3 * 94 >= 0, d0 * 44 - d1 * 184 - d2 * 231 + d3 * 14 >= 0, d0 * -45 + d1 * 284 + d2 * 138 - d3 * 39 >= 0, d0 * 154 - d1 * 84 + d2 * 238 - d3 * 34 >= 0, d0 * 54 - d1 * 284 - d2 * 223 + d3 * 384 >= 0, d0 * -55 + d1 * 284 + d2 * 23 + d3 * 34 >= 0, d0 * 54 - d1 * 84 + d2 * 28 - d3 * 34 >= 0, d0 * 54 - d1 * 24 - d2 * 23 + d3 * 34 >= 0, d0 * -55 + d1 * 24 + d2 * 23 + d3 * 4 >= 0, d0 * 15 - d1 * 84 + d2 * 238 - d3 * 3 >= 0, d0 * 5 - d1 * 24 - d2 * 223 + d3 * 84 >= 0, d0 * -5 + d1 * 284 + d2 * 23 - d3 * 4 >= 0, d0 * 14 + d2 * 4 + 7234 >= 0, d0 * -174 - d2 * 534 + 9834 >= 0, d0 * 194 - d2 * 954 + 9234 >= 0, d0 * 47 - d2 * 534 + 9734 >= 0, d0 * -194 - d2 * 934 + 984 >= 0, d0 * -947 - d2 * 953 + 234 >= 0, d0 * 184 - d2 * 884 + 884 >= 0, d0 * -174 + d2 * 834 + 234 >= 0, d0 * 844 + d2 * 634 + 9874 >= 0, d2 * -797 - d3 * 79 + 257 >= 0, d0 * 2039 + d2 * 793 - d3 * 99 - d4 * 24 + d5 * 234 >= 0, d2 * 78 - d5 * 788 + 257 >= 0, d3 - (d5 + d0 * 97) floordiv 423 >= 0, ((d0 + (d3 mod 5) floordiv 2342) * 234) mod 2309 + (d0 + d3 * 2038) floordiv 208 >= 0, ((((d0 + d3 * 2300) * 239) floordiv 2342) mod 2309) mod 239423 == 0, d0 + d3 mod 2642 + (((((d3 + d0 * 2) mod 1247) mod 2038) mod 2390) mod 2039) floordiv 55 >= 0)
'''
module = parse_string(code)
assert module.dump() == code
def test_loop_dialect_roundtrip():
src = """module {
func.func @for(%outer: index, %A: memref<?xf32>, %B: memref<?xf32>, %C: memref<?xf32>, %result: memref<?xf32>) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%d0 = dim %A , %c0 : memref<?xf32>
%b0 = affine.min affine_map<()[s0, s1] -> (1024, s0 - s1)> ()[%d0, %outer]
scf.for %i0 = %c0 to %b0 step %c1 {
%B_elem = load %B [ %i0 ] : memref<?xf32>
%C_elem = load %C [ %i0 ] : memref<?xf32>
%sum_elem = addf %B_elem , %C_elem : f32
store %sum_elem , %result [ %i0 ] : memref<?xf32>
}
return
}
}"""
assert parse_string(src).dump() == src
if __name__ == '__main__':
test_toy_roundtrip()
test_affine_expr_roundtrip()
test_loop_dialect_roundtrip()