Skip to content

Commit

Permalink
Implement sparse-sparse multiplication
Browse files Browse the repository at this point in the history
* Fix PyTorch conversion errors
* Add unit test
  • Loading branch information
holl- committed Feb 20, 2024
1 parent 1971004 commit a3c27af
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
32 changes: 25 additions & 7 deletions phiml/math/_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,8 @@ def compress(self, dims: DimFilter):
uncompressed_indices = bi.unravel_index(u_idx, c_dims.sizes + u_dims.sizes)
uncompressed_indices = wrap(uncompressed_indices, instance('sp_entries'), channel(self._indices))
# --- Use scipy.sparse.csr_matrix to reorder values ---
c_idx_packed = choose_backend(c_idx_packed).numpy(c_idx_packed)
u_idx_packed = choose_backend(u_idx_packed).numpy(u_idx_packed)
idx = np.arange(1, c_idx_packed.shape[-1] + 1) # start indexing at 1 since 0 might get removed
scipy_csr = csr_matrix((idx, (c_idx_packed[0], u_idx_packed[0])), shape=(c_dims.volume, u_dims.volume))
assert c_idx_packed.shape[1] == len(scipy_csr.data), "Failed to create CSR matrix because the CSR matrix contains fewer non-zero values than COO. This can happen when the `x` tensor is too small for the stencil."
Expand Down Expand Up @@ -641,7 +643,7 @@ def _with_shape_replaced(self, new_shape: Shape):
values = self._values._with_shape_replaced(self._values.shape.replace(self._shape, new_shape))
indices = self._indices._with_shape_replaced(self._indices.shape.replace(self._shape, new_shape))
pointers = self._pointers._with_shape_replaced(self._pointers.shape.replace(self._shape, new_shape))
uncompressed_indices = self._uncompressed_indices._with_shape_replaced(self._uncompressed_indices.shape.replace(self._shape, new_shape)) if self._uncompressed_indices is not None else None
uncompressed_indices = self._uncompressed_indices._with_shape_replaced(self._uncompressed_indices.shape.replace(self._shape, new_shape, replace_item_names=channel)) if self._uncompressed_indices is not None else None
return CompressedSparseMatrix(indices, pointers, values, self._uncompressed_dims.replace(self._shape, new_shape), self._compressed_dims.replace(self._shape, new_shape), self._default, self._uncompressed_offset, uncompressed_indices, self._uncompressed_indices_perm)

def _native_csr_components(self, invalid='clamp', get_values=True):
Expand Down Expand Up @@ -1026,6 +1028,22 @@ def dot_coordinate_dense(sparse: SparseCoordinateTensor, sdims: Shape, dense: Te
return result


def dot_sparse_sparse(a: Tensor, a_dims: Shape, b: Tensor, b_dims: Shape):
b = to_format(b, 'coo')
assert a_dims.rank == b_dims.rank
remaining_a = sparse_dims(a).without(a_dims)
remaining_b = sparse_dims(b).without(b_dims)
list_dim = instance(b._values)

a_gathered = a[{a_dim.name: b._indices[b_dim.name] for a_dim, b_dim in zip(a_dims, b_dims)}]
values = a_gathered * b._values # for each value in B, we have all
i = values._indices[remaining_a]
j = b._indices[remaining_b][{instance: values._indices[list_dim.name]}]
indices = concat([i, j], 'sparse_idx')
values = values._values
return SparseCoordinateTensor(indices, values, channel(a) & dual(b), can_contain_double_entries=True, indices_sorted=False, default=a._default)


def native_matrix(value: Tensor, target_backend: Backend):
target_backend = target_backend or value.default_backend
cols = dual(value)
Expand Down Expand Up @@ -1078,20 +1096,20 @@ def sparse_dot(x: Tensor, x_dims: Shape, y: Tensor, y_dims: Shape):
if isinstance(x, CompressedSparseMatrix):
if isinstance(y, (CompressedSparseMatrix, SparseCoordinateTensor)):
if x_dims.only(sparse_dims(x)) and y_dims.only(sparse_dims(y)):
raise NotImplementedError("sparse-sparse multiplication not yet supported")
return dot_sparse_sparse(x, x_dims, y, y_dims)
raise NotImplementedError
return dot_compressed_dense(x, x_dims, y, y_dims)
elif isinstance(y, CompressedSparseMatrix):
if isinstance(x, (CompressedSparseMatrix, SparseCoordinateTensor)):
raise NotImplementedError("sparse-sparse multiplication not yet supported")
return dot_sparse_sparse(x, x_dims, y, y_dims)
return dot_compressed_dense(y, y_dims, x, x_dims)
if isinstance(x, SparseCoordinateTensor):
if isinstance(y, (CompressedSparseMatrix, SparseCoordinateTensor)):
raise NotImplementedError("sparse-sparse multiplication not yet supported")
return dot_sparse_sparse(x, x_dims, y, y_dims)
return dot_coordinate_dense(x, x_dims, y, y_dims)
elif isinstance(y, SparseCoordinateTensor):
if isinstance(x, (CompressedSparseMatrix, SparseCoordinateTensor)):
raise NotImplementedError("sparse-sparse multiplication not yet supported")
return dot_sparse_sparse(x, x_dims, y, y_dims)
return dot_coordinate_dense(y, y_dims, x, x_dims)
raise NotImplementedError

Expand Down Expand Up @@ -1267,8 +1285,8 @@ def sparse_gather(matrix: Tensor, indices: Tensor):
col_indices = matrix._indices[col_dims.name_list]
# --- Construct SciPy matrix for efficient slicing ---
from ._ops import reshaped_numpy, reshaped_tensor
np_rows = b.ravel_multi_index(reshaped_numpy(row_indices, [instance, channel]), row_dims.sizes)
np_cols = b.ravel_multi_index(reshaped_numpy(col_indices, [instance, channel]), col_dims.sizes)
np_rows = NUMPY.ravel_multi_index(reshaped_numpy(row_indices, [instance, channel]), row_dims.sizes)
np_cols = NUMPY.ravel_multi_index(reshaped_numpy(col_indices, [instance, channel]), col_dims.sizes)
scipy_mat = csr_matrix((placeholders, (np_rows, np_cols)), shape=(row_dims.volume, col_dims.volume))
if channel(indices).size > 1 or row_dims.rank > 1:
raise NotImplementedError # ravel indices
Expand Down
7 changes: 7 additions & 0 deletions tests/commit/math/test__sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,10 @@ def test_reduce(self):
math.assert_close([1, -1], math.min(matrix, dual)) # the 0 is not part of the matrix anymore
math.assert_close([-1, 2], math.max(matrix, channel))
math.assert_close([1, 2], math.max(matrix, dual))

def test_sparse_sparse_mul(self):
expected = wrap([[9, 1], [-3, 3]], channel('in'), dual('out'))
for format in ['coo', 'csr', 'csc', 'dense']:
a = math.to_format(tensor([[1, 2], [3, 0]], channel('in'), dual('red')), format)
b = math.to_format(tensor([[-1, 1], [5, 0]], channel('red'), dual('out')), format)
math.assert_close(expected, a @ b, msg=format)

0 comments on commit a3c27af

Please sign in to comment.