From 0c04be91f0581cdb0feea30bdadce65b834b8901 Mon Sep 17 00:00:00 2001
From: holl-
Date: Tue, 28 Nov 2023 15:41:10 +0100
Subject: [PATCH] Tracing fixes, add unit test

---
 phiml/backend/_backend.py               |  9 ++++++---
 phiml/backend/_numpy_backend.py         |  4 ++--
 phiml/backend/tensorflow/_tf_backend.py |  2 +-
 phiml/backend/torch/_torch_backend.py   |  2 +-
 phiml/math/_sparse.py                   |  9 +++++----
 phiml/math/_trace.py                    | 17 +++++++++++++----
 tests/commit/math/test__trace.py        | 14 +++++++++++++-
 7 files changed, 41 insertions(+), 16 deletions(-)

diff --git a/phiml/backend/_backend.py b/phiml/backend/_backend.py
index 35df8d13..51166e3a 100644
--- a/phiml/backend/_backend.py
+++ b/phiml/backend/_backend.py
@@ -1251,11 +1251,14 @@ def mul_coo_dense(self, indices, values, shape, dense):
         """
         values, dense = self.auto_cast(values, dense)
         batch_size, nnz, channel_count = self.staticshape(values)
-        _, dense_rows, _, dense_cols = self.staticshape(dense)
+        batch_size_d, dense_rows, channel_count_d, dense_cols = self.staticshape(dense)
+        assert batch_size_d == batch_size
+        assert dense_rows == shape[1]
+        assert channel_count == channel_count_d
+        assert dense_cols == 1  # not implemented yet
         dense_formatted = self.reshape(dense, (batch_size, dense_rows, channel_count * dense_cols))
         dense_gathered = self.batched_gather_nd(dense_formatted, indices[:, :, 1:2])
         base_grid = self.zeros((batch_size, shape[0], channel_count), self.dtype(dense))
-        assert dense_cols == 1
         result = self.scatter(base_grid, indices[:, :, 0:1], values * dense_gathered, mode='add')
         return self.reshape(result, (batch_size, shape[0], channel_count, dense_cols))

@@ -1312,7 +1315,7 @@ def mul_csr_dense(self, column_indices, row_pointers, values, shape: Tuple[int,
             dense: (batch, dense_rows=sparse_cols, channels, dense_cols)

         Returns:
-            (batch, channels, dense_rows=sparse_cols, dense_cols)
+            (batch, dense_rows=sparse_cols, channels, dense_cols)
         """
         # if not self.supports(Backend.indexed_segment_sum):
         native_coo_indices = self.csr_to_coo(column_indices, row_pointers)
diff --git a/phiml/backend/_numpy_backend.py b/phiml/backend/_numpy_backend.py
index 0f2fcf1e..6d350792 100644
--- a/phiml/backend/_numpy_backend.py
+++ b/phiml/backend/_numpy_backend.py
@@ -457,8 +457,8 @@ def mul_csr_dense(self, column_indices, row_pointers, values, shape: tuple, dens
             b_result = []
             for c in range(channel_count):
                 mat = csr_matrix((values[b, :, c], column_indices[b], row_pointers[b]), shape=shape)
-                b_result.append(mat * dense[b, :, c, :])
-            result.append(np.stack(b_result))
+                b_result.append((mat * dense[b, :, c, :]))
+            result.append(np.stack(b_result, 1))
         return np.stack(result)

     def csc_matrix(self, column_pointers, row_indices, values, shape: tuple):
diff --git a/phiml/backend/tensorflow/_tf_backend.py b/phiml/backend/tensorflow/_tf_backend.py
index c1939814..2e3c14ed 100644
--- a/phiml/backend/tensorflow/_tf_backend.py
+++ b/phiml/backend/tensorflow/_tf_backend.py
@@ -770,7 +770,7 @@ def mul_coo_dense(self, indices, values, shape, dense):
                     b_result.append(tf.sparse.sparse_dense_matmul(matrix, dense[b, :, c, :]))
             except NotFoundError:  # These data types are probably not supported by TensorFlow
                 return Backend.mul_coo_dense(self, indices, values, shape, dense)
-            result.append(tf.stack(b_result))
+            result.append(tf.stack(b_result, 1))
         return tf.stack(result)

     def not_equal(self, x, y):
diff --git a/phiml/backend/torch/_torch_backend.py b/phiml/backend/torch/_torch_backend.py
index 2c3589aa..93110e9f 100644
--- a/phiml/backend/torch/_torch_backend.py
+++ b/phiml/backend/torch/_torch_backend.py
@@ -811,7 +811,7 @@ def mul_csr_dense(self, column_indices, row_pointers, values, shape: tuple, dens
             for c in range(channels):
                 matrix = torch.sparse_csr_tensor(row_pointers[b], column_indices[b], values[b, :, c], shape, device=values.device)
                 b_result.append(torch.sparse.mm(matrix, self.as_tensor(dense[b, :, c, :])))
-            result.append(torch.stack(b_result))
+            result.append(torch.stack(b_result, 1))
         return torch.stack(result)
         # if channel_count == 1:
         #     matrix = torch.sparse_csr_tensor(row_pointers, column_indices, values[:, :, 0], (batch_size, *shape), device=values.device)
diff --git a/phiml/math/_sparse.py b/phiml/math/_sparse.py
index 6a8a4906..8c264257 100644
--- a/phiml/math/_sparse.py
+++ b/phiml/math/_sparse.py
@@ -973,7 +973,7 @@ def dot_compressed_dense(compressed: CompressedSparseMatrix, cdims: Shape, dense
         rhs_channels = shape(dense).without(ddims).without(channels)
         dense_native = reshaped_native(dense, [ind_batch, ddims, channels, rhs_channels])
         result_native = backend.mul_csr_dense(native_indices, native_pointers, native_values, native_shape, dense_native)
-        result = reshaped_tensor(result_native, [ind_batch, channels, compressed._compressed_dims, rhs_channels])
+        result = reshaped_tensor(result_native, [ind_batch, compressed._compressed_dims, channels, rhs_channels])
         return result
     else:  # transposed matrix vector multiplication. This is inefficient
         raise NotImplementedError("Transposed sparse matrix multiplication not yet implemented")
@@ -1220,8 +1220,9 @@ def sparse_gather(matrix: Tensor, indices: Tensor):
         row_indices = matrix._indices[row_dims.name_list]
         col_indices = matrix._indices[col_dims.name_list]
         # --- Construct SciPy matrix for efficient slicing ---
-        np_rows = b.ravel_multi_index(row_indices.numpy('sp_entries,sparse_idx'), row_dims.sizes)
-        np_cols = b.ravel_multi_index(col_indices.numpy('sp_entries,sparse_idx'), row_dims.sizes)
+        from phiml.math import reshaped_numpy
+        np_rows = b.ravel_multi_index(reshaped_numpy(row_indices, [instance, channel]), row_dims.sizes)
+        np_cols = b.ravel_multi_index(reshaped_numpy(col_indices, [instance, channel]), col_dims.sizes)
         scipy_mat = csr_matrix((placeholders, (np_rows, np_cols)), shape=(row_dims.volume, col_dims.volume))
         if channel(indices).size > 1 or row_dims.rank > 1:
             raise NotImplementedError  # ravel indices
@@ -1229,7 +1230,7 @@
         lin_indices = unstack(indices, channel)[0].numpy()
         row_counts = scipy_mat.getnnz(axis=1)  # how many elements per matrix row
         lookup = scipy_mat[lin_indices, :].data - 1
-        lookup = expand(wrap(lookup, instance('sp_entries')), channel(sparse_idx='sp_entries'))
+        lookup = expand(wrap(lookup, instance('sp_entries')), channel(sparse_idx=instance(col_indices).name))
         # --- Perform resulting gather on tensors ---
         gathered_cols = col_indices[lookup]
         row_count_out = row_counts[lin_indices]  # row count for each i in indices
diff --git a/phiml/math/_trace.py b/phiml/math/_trace.py
index bfecc62c..0f915691 100644
--- a/phiml/math/_trace.py
+++ b/phiml/math/_trace.py
@@ -183,11 +183,17 @@ def _matmul(self, self_dims: Shape, matrix: Tensor, matrix_dims: Shape) -> Tenso
             return to_gather_tracer(self).matmul(self_dims, matrix, matrix_dims)
         raise NotImplementedError

+    def _upgrade_gather(self):
+        if len(self.val) > 1 or next(iter(self.val)):
+            return to_sparse_tracer(self, None)
+        else:
+            return to_gather_tracer(self)
+
     def _gather(self, indices: Tensor) -> Tensor:
-        return to_gather_tracer(self)._gather(indices)
+        return self._upgrade_gather()._gather(indices)

     def _scatter(self, base: Tensor, indices: Tensor) -> Tensor:
-        return to_gather_tracer(self)._scatter(base, indices)
+        return self._upgrade_gather()._scatter(base, indices)


 def simplify_add(val: dict) -> Dict[Shape, Tensor]:
@@ -425,7 +431,8 @@ def _op2(self, other,
             assert op_symbol in '+-', f"Non-linear operation '{op_symbol}' cannot be converted to matrix"
             bias = operator(self._bias, other._bias)
             matrix = operator(self._matrix, other._matrix)  # ToDo if other has no dependence on vector, it would also be in the output
-            return SparseLinTracer(self._source, matrix, bias, self._shape)
+            shape = self._shape & other._shape
+            return SparseLinTracer(self._source, matrix, bias, shape)
         else:
             # other = self._tensor(other)
             if op_symbol in '*/':
@@ -750,7 +757,9 @@ def to_sparse_tracer(tracer: Tensor, ref: Optional[Tensor]) -> SparseLinTracer:
     if isinstance(tracer, SparseLinTracer):
         return tracer
     if isinstance(tracer, ShiftLinTracer):
-        tracer = to_gather_tracer(tracer)
+        matrix, bias = tracer_to_coo(tracer, sparsify_batch=False, separate_independent=False)
+        matrix = rename_dims(matrix, dual, [n + '_src' for n in dual(matrix).as_batch().names])
+        return SparseLinTracer(tracer._source, matrix, bias, tracer.shape)
     assert isinstance(tracer, GatherLinTracer)
     if tracer._selection is None:
         if ref is not None:
diff --git a/tests/commit/math/test__trace.py b/tests/commit/math/test__trace.py
index c1f88b9f..777efe4a 100644
--- a/tests/commit/math/test__trace.py
+++ b/tests/commit/math/test__trace.py
@@ -2,7 +2,7 @@

 from phiml import math
 from phiml.backend._backend import init_installed_backends
-from phiml.math import expand, spatial, non_dual, extrapolation
+from phiml.math import expand, spatial, non_dual, extrapolation, vec, wrap, batch

 BACKENDS = init_installed_backends()

@@ -23,3 +23,15 @@ def diagonal(x):
                 if math.get_format(matrix) != 'dense':
                     matrix = matrix.compress(non_dual)
                 math.assert_close(f(x), matrix @ x)
+
+    def test_matrix_from_function_sparse(self):
+        def lin(x):
+            l, r = math.shift(x, (0, 1), padding=None, stack_dim=None)
+            x = l + r
+            y = x[vec(x=[1, 0])]
+            y_b = y * wrap([1, -2], batch('b'))
+            y = y + y_b
+            return y
+        test_x = wrap([3, 4, 5], spatial('x'))
+        matrix, bias = math.matrix_from_function(lin, test_x)
+        math.assert_close(lin(test_x), matrix @ test_x + bias)
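The stack-axis changes in the three backends all address the same layout bug: stack(b_result) put the
per-channel products on a new leading axis, yielding (batch, channels, rows, dense_cols), whereas the
Returns contract corrected above places channels third, as in (batch, dense_rows=sparse_cols, channels,
dense_cols). Stacking on axis 1 keeps the row axis first within each batch entry. A minimal standalone
sketch of that shape contract in plain NumPy/SciPy; sizes and names here are illustrative, not phiml code:

    import numpy as np
    from scipy.sparse import random as sparse_random

    # Shapes follow the mul_csr_dense docstring: dense rows equal the sparse matrix's columns.
    batch, rows, cols, channels, dense_cols = 2, 3, 4, 5, 1
    dense = np.random.rand(batch, cols, channels, dense_cols)
    result = []
    for b in range(batch):
        mat = sparse_random(rows, cols, density=0.5, format='csr')  # one sparse matrix per batch entry
        b_result = [mat @ dense[b, :, c, :] for c in range(channels)]  # each product: (rows, dense_cols)
        result.append(np.stack(b_result, 1))  # axis 1 -> (rows, channels, dense_cols); axis 0 would misplace channels
    result = np.stack(result)  # (batch, rows, channels, dense_cols)
    assert result.shape == (batch, rows, channels, dense_cols)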