From 3ba15087da9d41e33cc0fe6937d2714e29ce050a Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Sun, 12 Nov 2023 00:36:18 +0100 Subject: [PATCH] Implement sparse COO features * compression with duplicate entries * COO-COO addition / subtraction --- phiml/math/_sparse.py | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/phiml/math/_sparse.py b/phiml/math/_sparse.py index 07d89f3b..21646e8b 100644 --- a/phiml/math/_sparse.py +++ b/phiml/math/_sparse.py @@ -102,6 +102,7 @@ def __init__(self, indices: Tensor, values: Tensor, dense_shape: Shape, can_cont assert 'vector' in indices.shape, f"indices must have a vector dimension but got {indices.shape}" assert set(indices.vector.item_names) == set(dense_shape.names), f"The 'vector' dimension of indices must list the dense dimensions {dense_shape} as item names but got {indices.vector.item_names}" assert indices.dtype.kind == int, f"indices must have dtype=int but got {indices.dtype}" + assert instance(values) in instance(indices), f"All instance dimensions of values must exist in indices. values={values.shape}, indices={indices.shape}" self._shape = merge_shapes(dense_shape, batch(indices), non_instance(values)) self._dense_shape = dense_shape self._indices = indices @@ -224,14 +225,31 @@ def compress(self, dims: DimFilter): c_dims = self._shape.only(dims, reorder=True) u_dims = self._dense_shape.without(c_dims) c_idx_packed, u_idx_packed = self._pack_indices(c_dims, u_dims) + values = self._values + if self._can_contain_double_entries: + bi = self._indices.default_backend + assert c_idx_packed.shape[0] == 1, f"sparse compress() not supported for batched indices" + lin_idx = bi.ravel_multi_index(bi.stack([c_idx_packed, u_idx_packed], -1)[0], (c_dims.volume, u_dims.volume)) + u_idx, u_ptr = bi.unique(lin_idx, return_inverse=True, return_counts=False, axis=-1) + num_entries = u_idx.shape[-1] + if num_entries < instance(self._values).size: + b = self.default_backend + values = b.bincount(u_ptr, weights=self._values.native(), bins=num_entries) + values = wrap(values, instance(self._values).without_sizes()) + idx_packed = bi.unravel_index(u_idx, (c_dims.volume, u_dims.volume)) + c_idx_packed = idx_packed[None, :, 0] + u_idx_packed = idx_packed[None, :, 1] # --- Use scipy.sparse.csr_matrix to reorder values --- idx = np.arange(1, c_idx_packed.shape[-1] + 1) # start indexing at 1 since 0 might get removed scipy_csr = scipy.sparse.csr_matrix((idx, (c_idx_packed[0], u_idx_packed[0])), shape=(c_dims.volume, u_dims.volume)) assert c_idx_packed.shape[1] == len(scipy_csr.data), "Failed to create CSR matrix because the CSR matrix contains fewer non-zero values than COO. This can happen when the `x` tensor is too small for the stencil." # --- Construct CompressedSparseMatrix --- - entries_dim = instance(self._values).name - perm = {entries_dim: wrap(scipy_csr.data - 1, instance(entries_dim))} - values = self._values[perm] # Change order accordingly + entries_dim = instance(values).name + if np.any(scipy_csr.data != idx): + perm = {entries_dim: wrap(scipy_csr.data - 1, instance(entries_dim))} + values = values[perm] # Change order accordingly + else: + perm = None indices = wrap(scipy_csr.indices, instance(entries_dim)) pointers = wrap(scipy_csr.indptr, instance('pointers')) return CompressedSparseMatrix(indices, pointers, values, u_dims, c_dims, self._default, uncompressed_indices=self._indices, uncompressed_indices_perm=perm) @@ -274,13 +292,16 @@ def _op2(self, other, operator: Callable, native_function: Callable, op_name: st return self._with_values(operator(self._values, other._values)) else: assert op_name in ['add', 'radd', 'sub', 'rsub'] - indices = concat([self._indices, other._indices], instance(self._indices)) + other_indices = rename_dims(other._indices, instance, instance(self._indices)) + indices = concat([self._indices, other_indices], instance(self._indices)) + self_values = expand(self._values, instance(self._indices)) + other_values = rename_dims(expand(other._values, instance(other._indices)), instance, instance(self_values)) if op_symbol == '+': - values = concat([self._values, other._values], instance(self._values)) + values = concat([self_values, other_values], instance(self_values)) elif op_name == 'sub': - values = concat([self._values, -other._values], instance(self._values)) + values = concat([self_values, -other_values], instance(self_values)) else: # op_name == 'rsub': - values = concat([-self._values, other._values], instance(self._values)) + values = concat([-self_values, other_values], instance(self_values)) return SparseCoordinateTensor(indices, values, self._dense_shape, can_contain_double_entries=True, indices_sorted=False, default=self._default) else: # other is dense if self._dense_shape in other.shape: # all dims dense -> convert to dense