Skip to content

Commit

Permalink
allow custom inputs for verify_module and infer shapes for all module…
Browse files Browse the repository at this point in the history
…s in onnx compile
  • Loading branch information
LPanosTT committed Nov 13, 2024
1 parent 5e8317d commit c6b3d07
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 45 deletions.
69 changes: 41 additions & 28 deletions tests/torch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self):
def forward(self, x):
return torch.abs(x)

verify_module(Basic(), [(256, 256)])
verify_module(Basic(), input_shapes=[(256, 256)])


def test_add():
Expand All @@ -29,7 +29,7 @@ def __init__(self):
def forward(self, x, y):
return torch.add(x, y)

verify_module(Basic(), [(256, 256)] * 2)
verify_module(Basic(), input_shapes=[(256, 256)] * 2)


def test_concat_dim0():
Expand All @@ -40,7 +40,7 @@ def __init__(self):
def forward(self, x, y):
return torch.cat((x, y), dim=0)

verify_module(Basic(), [(32, 32), (64, 32)])
verify_module(Basic(), input_shapes=[(32, 32), (64, 32)])


def test_concat_dim1():
Expand All @@ -51,7 +51,7 @@ def __init__(self):
def forward(self, x, y):
return torch.cat((x, y), dim=1)

verify_module(Basic(), [(32, 32), (32, 64)])
verify_module(Basic(), input_shapes=[(32, 32), (32, 64)])


def test_concat_dim2():
Expand All @@ -62,7 +62,7 @@ def __init__(self):
def forward(self, x, y):
return torch.cat((x, y), dim=2)

verify_module(Basic(), [(32, 32, 32), (32, 32, 64)])
verify_module(Basic(), input_shapes=[(32, 32, 32), (32, 32, 64)])


def test_concat_dim3():
Expand All @@ -73,7 +73,7 @@ def __init__(self):
def forward(self, x, y):
return torch.cat((x, y), dim=3)

verify_module(Basic(), [(32, 32, 32, 32), (32, 32, 32, 64)])
verify_module(Basic(), input_shapes=[(32, 32, 32, 32), (32, 32, 32, 64)])


@pytest.mark.skip(
Expand All @@ -87,7 +87,7 @@ def __init__(self):
def forward(self, x):
return torch.tensor([1.0, 1.0, 1.0, 1.0])

verify_module(Basic(), [(1, 1)])
verify_module(Basic(), input_shapes=[(1, 1)])


def test_convert():
Expand All @@ -105,11 +105,18 @@ def __init__(self):
def forward(self, x):
return x.to(torch.int32)

verify_module(Basic_toFloat(), [(4, 4)], input_data_types=[torch.int32])
verify_module(Basic_toFloat(), [(4, 4)], input_data_types=[torch.float32])
verify_module(Basic_toInt(), [(4, 4)], input_data_types=[torch.int32])
verify_module(
Basic_toInt(), [(4, 4)], input_data_types=[torch.float32], input_range=(0, 60)
Basic_toFloat(), input_shapes=[(4, 4)], input_data_types=[torch.int32]
)
verify_module(
Basic_toFloat(), input_shapes=[(4, 4)], input_data_types=[torch.float32]
)
verify_module(Basic_toInt(), input_shapes=[(4, 4)], input_data_types=[torch.int32])
verify_module(
Basic_toInt(),
input_shapes=[(4, 4)],
input_data_types=[torch.float32],
input_range=(0, 60),
)


Expand All @@ -121,7 +128,7 @@ def __init__(self):
def forward(self, x, y):
return x / y

verify_module(Basic(), [(2, 2), (2, 2)], required_atol=5e-2)
verify_module(Basic(), input_shapes=[(2, 2), (2, 2)], required_atol=5e-2)


def test_exp():
Expand All @@ -132,7 +139,7 @@ def __init__(self):
def forward(self, x):
return torch.exp(x)

verify_module(Basic(), [(2, 2)], required_atol=3e-2)
verify_module(Basic(), input_shapes=[(2, 2)], required_atol=3e-2)


def test_linear():
Expand All @@ -147,7 +154,7 @@ def forward(self, x):
x = self.linear_b(x)
return x

verify_module(Basic(), [(32, 32)])
verify_module(Basic(), input_shapes=[(32, 32)])


from torch_mlir import fx
Expand All @@ -166,7 +173,7 @@ def forward(self, x):
x = self.linear_a(x)
return x

verify_module(Basic(), [(32, 32)])
verify_module(Basic(), input_shapes=[(32, 32)])


def test_maximum():
Expand All @@ -177,7 +184,7 @@ def __init__(self):
def forward(self, x, y):
return torch.maximum(x, y)

verify_module(Basic(), [(32, 32), (32, 32)], input_range=(-6, 6))
verify_module(Basic(), input_shapes=[(32, 32), (32, 32)], input_range=(-6, 6))


def test_multiply():
Expand All @@ -188,7 +195,7 @@ def __init__(self):
def forward(self, x, y):
return x * y

verify_module(Basic(), [(32, 32), (32, 32)])
verify_module(Basic(), input_shapes=[(32, 32), (32, 32)])


def test_negate():
Expand All @@ -199,7 +206,7 @@ def __init__(self):
def forward(self, x):
return -x

verify_module(Basic(), [(32, 32)], input_range=(-6, 6))
verify_module(Basic(), input_shapes=[(32, 32)], input_range=(-6, 6))


@pytest.mark.skip("keepdim=False is not supported")
Expand All @@ -211,7 +218,7 @@ def __init__(self):
def forward(self, x):
return torch.max(x)

verify_module(Basic(), [(32, 32)], input_range=(-6, 6))
verify_module(Basic(), input_shapes=[(32, 32)], input_range=(-6, 6))


@pytest.mark.skip("keepdim=False is not supported")
Expand All @@ -223,7 +230,7 @@ def __init__(self):
def forward(self, x):
return torch.sum(x)

verify_module(Basic(), [(32, 32)], input_range=(-6, 6))
verify_module(Basic(), input_shapes=[(32, 32)], input_range=(-6, 6))


def test_relu():
Expand All @@ -236,7 +243,7 @@ def __init__(self):
def forward(self, x):
return torch.relu(x)

verify_module(Basic(), [(32, 32)])
verify_module(Basic(), input_shapes=[(32, 32)])


def test_rsqrt():
Expand All @@ -247,7 +254,9 @@ def __init__(self):
def forward(self, x):
return torch.rsqrt(x)

verify_module(Basic(), [(32, 32)], required_atol=3e-2, input_range=(0.1, 1))
verify_module(
Basic(), input_shapes=[(32, 32)], required_atol=3e-2, input_range=(0.1, 1)
)


def test_sqrt():
Expand All @@ -258,7 +267,9 @@ def __init__(self):
def forward(self, x):
return torch.sqrt(x)

verify_module(Basic(), [(32, 32)], required_atol=3e-2, input_range=(0.1, 1))
verify_module(
Basic(), input_shapes=[(32, 32)], required_atol=3e-2, input_range=(0.1, 1)
)


dim0_cases = []
Expand Down Expand Up @@ -306,7 +317,7 @@ def forward(self, a):

shape = [10, 10, 10, 10]
shape[dim] = 128
verify_module(Basic(), [shape])
verify_module(Basic(), input_shapes=[shape])


def test_subtract():
Expand All @@ -317,7 +328,7 @@ def __init__(self):
def forward(self, x, y):
return x - y

verify_module(Basic(), [(32, 32), (32, 32)], input_range=(-6, 6))
verify_module(Basic(), input_shapes=[(32, 32), (32, 32)], input_range=(-6, 6))


def test_transpose_2d():
Expand All @@ -328,7 +339,7 @@ def __init__(self):
def forward(self, x):
return torch.transpose(x, 0, 1)

verify_module(Basic(), [(4, 8)], input_range=(-6, 6))
verify_module(Basic(), input_shapes=[(4, 8)], input_range=(-6, 6))


@pytest.mark.skip("TTNN does not support transpose for higher ranks/dimensions.")
Expand All @@ -340,7 +351,7 @@ def __init__(self):
def forward(self, x):
return torch.transpose(x, 0, 1)

verify_module(Basic(), [(4, 8, 4)], input_range=(-6, 6))
verify_module(Basic(), input_shapes=[(4, 8, 4)], input_range=(-6, 6))


def test_multiple_ops():
Expand All @@ -356,4 +367,6 @@ def forward(self, x):

cc = CompilerConfig()
cc.compile_depth = tt_torch.tools.utils.CompileDepth.EXECUTE_OP_BY_OP
verify_module(Basic(), [(256, 256)], compiler_config=cc, do_assert=False)
verify_module(
Basic(), input_shapes=[(256, 256)], compiler_config=cc, do_assert=False
)
3 changes: 3 additions & 0 deletions tt_torch/onnx_compile/onnx_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@


def compile_onnx(module: onnx.ModelProto):
# Infer onnx shapes incase that information is missing
module = onnx.shape_inference.infer_shapes(module)

context = Context()
torch_dialect.register_dialect(context)
module_info = onnx_importer.ModelInfo(module)
Expand Down
42 changes: 25 additions & 17 deletions tt_torch/tools/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

def _verify_torch_module(
mod,
inputs,
input_shapes,
input_data_types,
required_pcc,
Expand All @@ -21,15 +22,17 @@ def _verify_torch_module(
do_assert,
):
tt_mod = torch.compile(mod, backend=backend, options=compiler_config)
if inputs is None:
if all([dtype.is_floating_point for dtype in input_data_types]):
low, high = input_range
# Uniformly distribute random numbers within the input_range
inputs = [(low - high) * torch.rand(shape) + high for shape in input_shapes]
else:
inputs = [
torch.randint(0, 1000, shape, dtype=torch.int32)
for shape in input_shapes
]

if all([dtype.is_floating_point for dtype in input_data_types]):
low, high = input_range
# Uniformly distribute random numbers within the input_range
inputs = [(low - high) * torch.rand(shape) + high for shape in input_shapes]
else:
inputs = [
torch.randint(0, 1000, shape, dtype=torch.int32) for shape in input_shapes
]
ret = tt_mod(*inputs)
golden = mod(*inputs)

Expand All @@ -51,6 +54,7 @@ def _verify_torch_module(

def _verify_onnx_module(
filename,
inputs,
input_data_types,
required_pcc,
required_atol,
Expand All @@ -61,15 +65,16 @@ def _verify_onnx_module(

sess = InferenceSession(filename)
input_shapes = [nodearg.shape for nodearg in sess.get_inputs()]

if all([dtype.is_floating_point for dtype in input_data_types]):
low, high = input_range
# Uniformly distribute random numbers within the input_range
inputs = [(low - high) * torch.rand(shape) + high for shape in input_shapes]
else:
inputs = [
torch.randint(0, 1000, shape, dtype=torch.int32) for shape in input_shapes
]
if inputs is None:
if all([dtype.is_floating_point for dtype in input_data_types]):
low, high = input_range
# Uniformly distribute random numbers within the input_range
inputs = [(low - high) * torch.rand(shape) + high for shape in input_shapes]
else:
inputs = [
torch.randint(0, 1000, shape, dtype=torch.int64)
for shape in input_shapes
]

inputs_dict = {
nodearg.name: input.numpy().astype(np.float32)
Expand Down Expand Up @@ -113,6 +118,7 @@ def _verify_onnx_module(

def verify_module(
mod,
inputs=None,
input_shapes=None,
input_data_types=[torch.float32],
required_pcc=0.99,
Expand All @@ -127,6 +133,7 @@ def verify_module(
), "Verifying a torch module requires that you provide input_shapes"
_verify_torch_module(
mod,
inputs,
input_shapes,
input_data_types,
required_pcc,
Expand All @@ -141,6 +148,7 @@ def verify_module(
), "When verifying an ONNX module, input_shapes must be None as they are inferred from the ONNX model"
_verify_onnx_module(
mod,
inputs,
input_data_types,
required_pcc,
required_atol,
Expand Down

0 comments on commit c6b3d07

Please sign in to comment.