Skip to content

Commit

Permalink
Fix PyTorch traces for variable-size buffers
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Jan 4, 2025
1 parent 8da3ae3 commit 71f73dd
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions phiml/math/_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,9 +762,9 @@ def _native_csr_components(self, invalid='clamp', get_values=True):
assert invalid in ['clamp', 'discard', 'keep']
ind_batch = batch(self._indices) & batch(self._pointers)
channels = non_instance(self._values).without(ind_batch)
native_indices = self._indices.native([ind_batch, instance])
native_pointers = self._pointers.native([ind_batch, instance])
native_values = self._values.native([ind_batch, instance, channels]) if get_values else None
native_indices = self._indices._reshaped_native([ind_batch, instance(self._indices).without_sizes()]) # allow variable instance size (PyTorch tracing)
native_pointers = self._pointers._reshaped_native([ind_batch, instance(self._pointers)])
native_values = self._values._reshaped_native([ind_batch, instance(self._values).without_sizes(), channels]) if get_values else None
native_shape = self._compressed_dims.volume, self._uncompressed_dims.volume
if self._uncompressed_offset is not None:
native_indices -= self._uncompressed_offset
Expand Down
4 changes: 2 additions & 2 deletions phiml/math/_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,7 +1228,7 @@ def __init__(self, native_tensor, names: Sequence[str], expanded_shape: Shape, b
for s_dim in dim.size.shape.names:
assert s_dim in expanded_shape.names, f"Dimension {dim} varies along {s_dim} but {s_dim} is not part of the Shape {self}"
assert choose_backend(native_tensor) == backend
assert expanded_shape.is_uniform
assert expanded_shape.is_uniform, expanded_shape
shape_sizes = [expanded_shape.get_size(n) for n in names]
assert backend.staticshape(native_tensor) == tuple(shape_sizes), f"Shape {expanded_shape} at {names} does not match native tensor with shape {backend.staticshape(native_tensor)}"

Expand Down Expand Up @@ -1264,7 +1264,7 @@ def _reshaped_native(self, groups: Sequence[Shape]):
native = self._backend.transpose(self._native, perm) # this will cast automatically
native = native[tuple(slices)]
native = self._backend.tile(native, tile)
native = self._backend.reshape(native, [g.volume for g in groups])
native = self._backend.reshape(native, [g.volume if g.well_defined else -1 for g in groups])
return native

def _transposed_native(self, order: Sequence[str], force_expand: bool):
Expand Down

0 comments on commit 71f73dd

Please sign in to comment.