Skip to content

Commit

Permalink
Linear tracing improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Nov 12, 2023
1 parent 3ba1508 commit 393122c
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 43 deletions.
20 changes: 15 additions & 5 deletions phiml/math/_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import scipy.sparse

from ._shape import Shape, non_batch, merge_shapes, instance, batch, non_instance, shape, channel, spatial, DimFilter, concat_shapes, EMPTY_SHAPE, dual, DUAL_DIM, SPATIAL_DIM, \
non_channel
non_channel, DEBUG_CHECKS
from ._magic_ops import concat, pack_dims, expand, rename_dims, stack, unpack_dim
from ._tensors import Tensor, TensorStack, NativeTensor, cached, wrap
from ..backend import choose_backend, NUMPY, Backend
Expand Down Expand Up @@ -103,13 +103,18 @@ def __init__(self, indices: Tensor, values: Tensor, dense_shape: Shape, can_cont
assert set(indices.vector.item_names) == set(dense_shape.names), f"The 'vector' dimension of indices must list the dense dimensions {dense_shape} as item names but got {indices.vector.item_names}"
assert indices.dtype.kind == int, f"indices must have dtype=int but got {indices.dtype}"
assert instance(values) in instance(indices), f"All instance dimensions of values must exist in indices. values={values.shape}, indices={indices.shape}"
assert set(indices.shape.only(instance(values))) == set(instance(values)), f"indices and values must have equal number of elements but got {instance(indices)} indices and {instance(values)} values"
if not instance(values) and (spatial(values) or dual(values)):
warnings.warn(f"You are creating a sparse tensor with only constant values {values.shape}. To have values vary along indices, add the corresponding instance dimension.", RuntimeWarning, stacklevel=3)
self._shape = merge_shapes(dense_shape, batch(indices), non_instance(values))
self._dense_shape = dense_shape
self._indices = indices
self._values = values
self._can_contain_double_entries = can_contain_double_entries
self._indices_sorted = indices_sorted
self._default = default
if DEBUG_CHECKS:
self.compress_rows()

@property
def shape(self) -> Shape:
Expand Down Expand Up @@ -226,6 +231,7 @@ def compress(self, dims: DimFilter):
u_dims = self._dense_shape.without(c_dims)
c_idx_packed, u_idx_packed = self._pack_indices(c_dims, u_dims)
values = self._values
uncompressed_indices = self._indices
if self._can_contain_double_entries:
bi = self._indices.default_backend
assert c_idx_packed.shape[0] == 1, f"sparse compress() not supported for batched indices"
Expand All @@ -239,20 +245,22 @@ def compress(self, dims: DimFilter):
idx_packed = bi.unravel_index(u_idx, (c_dims.volume, u_dims.volume))
c_idx_packed = idx_packed[None, :, 0]
u_idx_packed = idx_packed[None, :, 1]
uncompressed_indices = bi.unravel_index(u_idx, c_dims.sizes + u_dims.sizes)
uncompressed_indices = wrap(uncompressed_indices, instance(self._indices).without_sizes(), channel(self._indices))
# --- Use scipy.sparse.csr_matrix to reorder values ---
idx = np.arange(1, c_idx_packed.shape[-1] + 1) # start indexing at 1 since 0 might get removed
scipy_csr = scipy.sparse.csr_matrix((idx, (c_idx_packed[0], u_idx_packed[0])), shape=(c_dims.volume, u_dims.volume))
assert c_idx_packed.shape[1] == len(scipy_csr.data), "Failed to create CSR matrix because the CSR matrix contains fewer non-zero values than COO. This can happen when the `x` tensor is too small for the stencil."
# --- Construct CompressedSparseMatrix ---
entries_dim = instance(values).name
entries_dim = instance(self._indices).name
perm = None
values = expand(values, instance(self._indices).without(instance(values)))
if np.any(scipy_csr.data != idx):
perm = {entries_dim: wrap(scipy_csr.data - 1, instance(entries_dim))}
values = values[perm] # Change order accordingly
else:
perm = None
indices = wrap(scipy_csr.indices, instance(entries_dim))
pointers = wrap(scipy_csr.indptr, instance('pointers'))
return CompressedSparseMatrix(indices, pointers, values, u_dims, c_dims, self._default, uncompressed_indices=self._indices, uncompressed_indices_perm=perm)
return CompressedSparseMatrix(indices, pointers, values, u_dims, c_dims, self._default, uncompressed_indices=uncompressed_indices, uncompressed_indices_perm=perm)

def __pack_dims__(self, dims: Tuple[str, ...], packed_dim: Shape, pos: Union[int, None], **kwargs) -> 'Tensor':
dims = self._shape.only(dims)
Expand Down Expand Up @@ -406,6 +414,8 @@ def __init__(self,
assert not channel(pointers) and not spatial(pointers), f"channel and spatial dimensions not allowed on pointers but got {shape(pointers)}"
assert uncompressed_dims.isdisjoint(compressed_dims), f"Dimensions cannot be compressed and uncompressed at the same time but got compressed={compressed_dims}, uncompressed={uncompressed_dims}"
assert instance(pointers).size == compressed_dims.volume + 1
if uncompressed_indices is not None:
assert instance(uncompressed_indices) == instance(indices), f"Number of uncompressed indices {instance(uncompressed_offset)} does not match compressed indices {instance(indices)}"
self._shape = merge_shapes(compressed_dims, uncompressed_dims, batch(indices), batch(pointers), non_instance(values))
self._indices = indices
self._pointers = rename_dims(pointers, instance, 'pointers')
Expand Down
Loading

0 comments on commit 393122c

Please sign in to comment.