Skip to content

Commit

Permalink
Fix stored_indices() for compact sparse
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Holl committed Nov 25, 2024
1 parent 252771e commit e714411
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions phiml/math/_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,8 +904,8 @@ def to_coo(self):
from ._ops import arange
rows = arange(self._uncompressed_dims)
rows = expand(rows, self._compact_dims)
rows = pack_dims(rows, [*self._compact_dims.names, *self._uncompressed_dims.names], instance('entries'))
cols = pack_dims(self._indices, [*self._compact_dims.names, *self._uncompressed_dims.names], instance('entries'))
rows = pack_dims(rows, [*self._uncompressed_dims.names, *self._compact_dims.names], instance('entries'))
cols = pack_dims(self._indices, [*self._uncompressed_dims.names, *self._compact_dims.names], instance('entries'))
indices = stack([rows, cols], channel(sparse_idx=[*self._uncompressed_dims.names, *self._compressed_dims.names]))
values = pack_dims(self._values, [*self._compact_dims.names, *self._uncompressed_dims.names], instance('entries'))
return SparseCoordinateTensor(indices, values, self._compressed_dims & self._uncompressed_dims, False, True, self._indices_constant, self._matrix_rank)
Expand Down Expand Up @@ -1175,6 +1175,9 @@ def to_format(x: Tensor, format: str):
return x.to_cs()
else:
return to_format(x.to_coo(), format)
elif isinstance(x, TensorStack):
converted = [to_format(t, format) for t in x._tensors]
return TensorStack(converted, x._stack_dim)
else: # dense to sparse
from ._ops import nonzero
indices = nonzero(rename_dims(x, channel, instance))
Expand Down Expand Up @@ -1322,10 +1325,14 @@ def stored_indices(x: Tensor, list_dim=instance('entries'), index_dim=channel('i
if isinstance(x, TensorStack):
if x.is_cached or not x.requires_broadcast:
return stored_indices(cached(x))
raise NotImplementedError
return stack([stored_indices(t, list_dim) for t in x._tensors], x._stack_dim) # ToDo add index for stack dim
if x._stack_dim.batch_rank:
return stack([stored_indices(t, list_dim, index_dim, invalid) for t in x._tensors], x._stack_dim)
raise NotImplementedError # ToDo add index for stack dim
elif isinstance(x, CompressedSparseMatrix):
return rename_dims(x._coo_indices(invalid, stack_dim=index_dim), instance, list_dim)
elif isinstance(x, CompactSparseTensor):
# col = pack_dims(x._indices, x._compressed_dims + x._uncompressed_dims, list_dim)
x = to_format(x, 'coo')
if isinstance(x, SparseCoordinateTensor):
if x._can_contain_double_entries:
warnings.warn(f"stored_values of sparse tensor {x.shape} may contain multiple values for the same position.")
Expand Down

0 comments on commit e714411

Please sign in to comment.