Skip to content

Commit

Permalink
[Op] Enforce int64 output shape in CallTIR (tlc-pack#322)
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterJH5574 authored and junrushao committed Jan 25, 2023
1 parent d079eee commit e4eb457
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,11 @@ def _create_shape(shape: List[Union[int, PrimExpr]]) -> ShapeExpr:
for x in shape:
if isinstance(x, int):
shape_array.append(tvm.tir.IntImm("int64", x))
elif isinstance(x, tvm.tir.IntImm):
shape_array.append(x if x.dtype == "int64" else tvm.tir.IntImm("int64", x.value))
elif isinstance(x, PrimExpr):
# TODO: enforce all shapes are i64
# if x.dtype != "int64":
# raise TypeError("Expect int64 dtype for shape")
if x.dtype != "int64":
raise TypeError("Expect int64 dtype for shape")
shape_array.append(x)
else:
raise TypeError("Expect int or PrimExpr for shape")
Expand Down

0 comments on commit e4eb457

Please sign in to comment.