Skip to content

Commit

Permalink
Tracing fixes, add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Nov 28, 2023
1 parent 2b1bcba commit 0c04be9
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 16 deletions.
9 changes: 6 additions & 3 deletions phiml/backend/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1251,11 +1251,14 @@ def mul_coo_dense(self, indices, values, shape, dense):
"""
values, dense = self.auto_cast(values, dense)
batch_size, nnz, channel_count = self.staticshape(values)
_, dense_rows, _, dense_cols = self.staticshape(dense)
batch_size_d, dense_rows, channel_count_d, dense_cols = self.staticshape(dense)
assert batch_size_d == batch_size
assert dense_rows == shape[1]
assert channel_count == channel_count_d
assert dense_cols == 1 # not implemented yet
dense_formatted = self.reshape(dense, (batch_size, dense_rows, channel_count * dense_cols))
dense_gathered = self.batched_gather_nd(dense_formatted, indices[:, :, 1:2])
base_grid = self.zeros((batch_size, shape[0], channel_count), self.dtype(dense))
assert dense_cols == 1
result = self.scatter(base_grid, indices[:, :, 0:1], values * dense_gathered, mode='add')
return self.reshape(result, (batch_size, shape[0], channel_count, dense_cols))

Expand Down Expand Up @@ -1312,7 +1315,7 @@ def mul_csr_dense(self, column_indices, row_pointers, values, shape: Tuple[int,
dense: (batch, dense_rows=sparse_cols, channels, dense_cols)
Returns:
(batch, channels, dense_rows=sparse_cols, dense_cols)
(batch, dense_rows=sparse_cols, channels, dense_cols)
"""
# if not self.supports(Backend.indexed_segment_sum):
native_coo_indices = self.csr_to_coo(column_indices, row_pointers)
Expand Down
4 changes: 2 additions & 2 deletions phiml/backend/_numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,8 +457,8 @@ def mul_csr_dense(self, column_indices, row_pointers, values, shape: tuple, dens
b_result = []
for c in range(channel_count):
mat = csr_matrix((values[b, :, c], column_indices[b], row_pointers[b]), shape=shape)
b_result.append(mat * dense[b, :, c, :])
result.append(np.stack(b_result))
b_result.append((mat * dense[b, :, c, :]))
result.append(np.stack(b_result, 1))
return np.stack(result)

def csc_matrix(self, column_pointers, row_indices, values, shape: tuple):
Expand Down
2 changes: 1 addition & 1 deletion phiml/backend/tensorflow/_tf_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ def mul_coo_dense(self, indices, values, shape, dense):
b_result.append(tf.sparse.sparse_dense_matmul(matrix, dense[b, :, c, :]))
except NotFoundError: # These data types are probably not supported by TensorFlow
return Backend.mul_coo_dense(self, indices, values, shape, dense)
result.append(tf.stack(b_result))
result.append(tf.stack(b_result, 1))
return tf.stack(result)

def not_equal(self, x, y):
Expand Down
2 changes: 1 addition & 1 deletion phiml/backend/torch/_torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,7 @@ def mul_csr_dense(self, column_indices, row_pointers, values, shape: tuple, dens
for c in range(channels):
matrix = torch.sparse_csr_tensor(row_pointers[b], column_indices[b], values[b, :, c], shape, device=values.device)
b_result.append(torch.sparse.mm(matrix, self.as_tensor(dense[b, :, c, :])))
result.append(torch.stack(b_result))
result.append(torch.stack(b_result, 1))
return torch.stack(result)
# if channel_count == 1:
# matrix = torch.sparse_csr_tensor(row_pointers, column_indices, values[:, :, 0], (batch_size, *shape), device=values.device)
Expand Down
9 changes: 5 additions & 4 deletions phiml/math/_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,7 @@ def dot_compressed_dense(compressed: CompressedSparseMatrix, cdims: Shape, dense
rhs_channels = shape(dense).without(ddims).without(channels)
dense_native = reshaped_native(dense, [ind_batch, ddims, channels, rhs_channels])
result_native = backend.mul_csr_dense(native_indices, native_pointers, native_values, native_shape, dense_native)
result = reshaped_tensor(result_native, [ind_batch, channels, compressed._compressed_dims, rhs_channels])
result = reshaped_tensor(result_native, [ind_batch, compressed._compressed_dims, channels, rhs_channels])
return result
else: # transposed matrix vector multiplication. This is inefficient
raise NotImplementedError("Transposed sparse matrix multiplication not yet implemented")
Expand Down Expand Up @@ -1220,16 +1220,17 @@ def sparse_gather(matrix: Tensor, indices: Tensor):
row_indices = matrix._indices[row_dims.name_list]
col_indices = matrix._indices[col_dims.name_list]
# --- Construct SciPy matrix for efficient slicing ---
np_rows = b.ravel_multi_index(row_indices.numpy('sp_entries,sparse_idx'), row_dims.sizes)
np_cols = b.ravel_multi_index(col_indices.numpy('sp_entries,sparse_idx'), row_dims.sizes)
from phiml.math import reshaped_numpy
np_rows = b.ravel_multi_index(reshaped_numpy(row_indices, [instance, channel]), row_dims.sizes)
np_cols = b.ravel_multi_index(reshaped_numpy(col_indices, [instance, channel]), col_dims.sizes)
scipy_mat = csr_matrix((placeholders, (np_rows, np_cols)), shape=(row_dims.volume, col_dims.volume))
if channel(indices).size > 1 or row_dims.rank > 1:
raise NotImplementedError # ravel indices
else:
lin_indices = unstack(indices, channel)[0].numpy()
row_counts = scipy_mat.getnnz(axis=1) # how many elements per matrix row
lookup = scipy_mat[lin_indices, :].data - 1
lookup = expand(wrap(lookup, instance('sp_entries')), channel(sparse_idx='sp_entries'))
lookup = expand(wrap(lookup, instance('sp_entries')), channel(sparse_idx=instance(col_indices).name))
# --- Perform resulting gather on tensors ---
gathered_cols = col_indices[lookup]
row_count_out = row_counts[lin_indices] # row count for each i in indices
Expand Down
17 changes: 13 additions & 4 deletions phiml/math/_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,17 @@ def _matmul(self, self_dims: Shape, matrix: Tensor, matrix_dims: Shape) -> Tenso
return to_gather_tracer(self).matmul(self_dims, matrix, matrix_dims)
raise NotImplementedError

def _upgrade_gather(self):
if len(self.val) > 1 or next(iter(self.val)):
return to_sparse_tracer(self, None)
else:
return to_gather_tracer(self)

def _gather(self, indices: Tensor) -> Tensor:
return to_gather_tracer(self)._gather(indices)
return self._upgrade_gather()._gather(indices)

def _scatter(self, base: Tensor, indices: Tensor) -> Tensor:
return to_gather_tracer(self)._scatter(base, indices)
return self._upgrade_gather()._scatter(base, indices)


def simplify_add(val: dict) -> Dict[Shape, Tensor]:
Expand Down Expand Up @@ -425,7 +431,8 @@ def _op2(self, other,
assert op_symbol in '+-', f"Non-linear operation '{op_symbol}' cannot be converted to matrix"
bias = operator(self._bias, other._bias)
matrix = operator(self._matrix, other._matrix) # ToDo if other has no dependence on vector, it would also be in the output
return SparseLinTracer(self._source, matrix, bias, self._shape)
shape = self._shape & other._shape
return SparseLinTracer(self._source, matrix, bias, shape)
else:
# other = self._tensor(other)
if op_symbol in '*/':
Expand Down Expand Up @@ -750,7 +757,9 @@ def to_sparse_tracer(tracer: Tensor, ref: Optional[Tensor]) -> SparseLinTracer:
if isinstance(tracer, SparseLinTracer):
return tracer
if isinstance(tracer, ShiftLinTracer):
tracer = to_gather_tracer(tracer)
matrix, bias = tracer_to_coo(tracer, sparsify_batch=False, separate_independent=False)
matrix = rename_dims(matrix, dual, [n + '_src' for n in dual(matrix).as_batch().names])
return SparseLinTracer(tracer._source, matrix, bias, tracer.shape)
assert isinstance(tracer, GatherLinTracer)
if tracer._selection is None:
if ref is not None:
Expand Down
14 changes: 13 additions & 1 deletion tests/commit/math/test__trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from phiml import math
from phiml.backend._backend import init_installed_backends
from phiml.math import expand, spatial, non_dual, extrapolation
from phiml.math import expand, spatial, non_dual, extrapolation, vec, wrap, batch

BACKENDS = init_installed_backends()

Expand All @@ -23,3 +23,15 @@ def diagonal(x):
if math.get_format(matrix) != 'dense':
matrix = matrix.compress(non_dual)
math.assert_close(f(x), matrix @ x)

def test_matrix_from_function_sparse(self):
def lin(x):
l, r = math.shift(x, (0, 1), padding=None, stack_dim=None)
x = l + r
y = x[vec(x=[1, 0])]
y_b = y * wrap([1, -2], batch('b'))
y = y + y_b
return y
test_x = wrap([3, 4, 5], spatial('x'))
matrix, bias = math.matrix_from_function(lin, test_x)
math.assert_close(lin(test_x), matrix @ test_x + bias)

0 comments on commit 0c04be9

Please sign in to comment.