Trace dot products, sparse improvements
holl- committed Nov 13, 2023
1 parent 15f9620 commit 07c174e
Showing 3 changed files with 57 additions and 12 deletions.
phiml/math/_ops.py (4 additions, 0 deletions)
@@ -1797,6 +1797,10 @@ def dot(x: Tensor,
                 return y._op2(x, lambda vy, vx: dot(vx, x_dims, vy, y_dims), None, 'dot', '@')
         else:
             return sparse_dot(x, x_dims, y, y_dims)
+    if x._is_tracer:
+        return x._matmul(x_dims, y, y_dims)
+    if y._is_tracer:
+        return y._matmul(y_dims, x, x_dims)
     x_native = x.native(x.shape)
     y_native = y.native(y.shape)
     backend = choose_backend(x_native, y_native)
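
Note on the new tracer branches: they let `dot` participate in linear-function tracing instead of failing at `x.native()`. An illustrative usage sketch follows (it uses the public phiml API, but this exact snippet is not part of the commit and its behavior here is assumed):

```python
from phiml import math
from phiml.math import spatial, wrap

# A linear function whose argument appears inside dot(). During tracing,
# dot() now delegates to the tracer's _matmul instead of materializing x.
@math.jit_compile_linear
def f(x):
    w = wrap([1., 2., 4.], spatial('x'))
    return math.dot(w, 'x', x, 'x')

print(f(wrap([1., 1., 1.], spatial('x'))))  # traces once, then applies the matrix
```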
phiml/math/_sparse.py (30 additions, 2 deletions)
@@ -240,8 +240,14 @@ def compress(self, dims: DimFilter):
         num_entries = u_idx.shape[-1]
         if num_entries < instance(self._values).size:
             b = self.default_backend
-            values = b.bincount(u_ptr, weights=self._values.native(), bins=num_entries)
-            values = wrap(values, instance(self._values).without_sizes())
+            if non_instance(self._values):
+                from ._ops import reshaped_native
+                batched_values = reshaped_native(self._values, [non_instance, instance])
+                values = b.batched_bincount(u_ptr[None, :], weights=batched_values, bins=num_entries)
+                values = wrap(values, non_instance(self._values), instance(self._values).without_sizes())
+            else:
+                values = b.bincount(u_ptr, weights=self._values.native(), bins=num_entries)
+                values = wrap(values, instance(self._values).without_sizes())
             idx_packed = bi.unravel_index(u_idx, (c_dims.volume, u_dims.volume))
             c_idx_packed = idx_packed[None, :, 0]
             u_idx_packed = idx_packed[None, :, 1]
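
For reference, a minimal NumPy stand-in for what the batched branch computes (plain `np.bincount` per batch row; `batched_bincount` is the backend method, not reproduced here):

```python
import numpy as np

# u_ptr maps each original COO entry to its compressed slot; values carry one
# extra non-instance (batch/channel) dim, so duplicates are summed per row.
u_ptr = np.array([0, 1, 0, 2])
values = np.array([[1., 2., 3., 4.],
                   [10., 20., 30., 40.]])
num_entries = 3
compressed = np.stack([np.bincount(u_ptr, weights=row, minlength=num_entries)
                       for row in values])
print(compressed)  # [[ 4.  2.  4.], [40. 20. 40.]]
```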
@@ -1020,3 +1026,25 @@ def sparse_dot(x: Tensor, x_dims: Shape, y: Tensor, y_dims: Shape):
             raise NotImplementedError("sparse-sparse multiplication not yet supported")
         return dot_coordinate_dense(y, y_dims, x, x_dims)
     raise NotImplementedError
+
+
+def include_value_dims_in_pattern(matrix: Tensor, value_dims: Shape):
+    """
+    Bakes non-instance dimensions of the non-zero values in `matrix` into the sparsity pattern.
+    The number of values stays unchanged but the number of indices increases.
+    """
+    assert instance(value_dims).is_empty
+    if isinstance(matrix, SparseCoordinateTensor):
+        from ._ops import arange
+        dims = matrix._values.shape.only(value_dims)
+        indices = concat([arange(channel(vector=dims.as_dual())), arange(channel(vector=dims)), matrix._indices], 'vector', expand_values=True)
+        all_indices = [indices]
+        for idx in dims.meshgrid():
+            if any(i != 0 for i in idx.values()):
+                offset = wrap([*idx.values()] * 2 + [0] * non_instance(matrix._indices).volume, channel(indices))
+                all_indices.append(indices + offset)
+        indices = concat(all_indices, instance(indices))
+        values = pack_dims(matrix._values, concat_shapes(dims, instance(matrix._values)), instance(matrix._values))
+        dense_shape = dims.as_dual() & dims & matrix._dense_shape
+        return SparseCoordinateTensor(indices, values, dense_shape, matrix._can_contain_double_entries, matrix._indices_sorted, matrix._default)
+    raise NotImplementedError
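
A minimal NumPy sketch of the baking step (illustrative only, not phiml internals): each nonzero carrying a length-k value vector becomes k nonzeros on the diagonal of a new dual/primal axis pair, so the values are only reordered while the index list grows k-fold:

```python
import numpy as np

indices = np.array([[0, 1],    # (row, col) of entry 0
                    [1, 0]])   # (row, col) of entry 1
values = np.array([[1., 2.],   # each entry has k=2 value components
                   [3., 4.]])
k = values.shape[1]
# Entry (r, c) with component i becomes index (i, i, r, c): the component
# index fills both the new dual axis and the new primal axis.
new_indices = np.concatenate([
    np.concatenate([np.full((len(indices), 2), i), indices], axis=1)
    for i in range(k)
])
new_values = values.T.reshape(-1)   # component-major, matching new_indices
print(new_indices.shape, new_values)  # (4, 4) [1. 3. 2. 4.]
```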
phiml/math/_trace.py (23 additions, 10 deletions)
@@ -7,9 +7,9 @@
 from ..backend import choose_backend, NUMPY, Backend
 from ._ops import choose_backend_t, concat_tensor, scatter, zeros_like
 from ._shape import Shape, parse_dim_order, merge_shapes, spatial, instance, batch, concat_shapes, EMPTY_SHAPE, dual, channel, non_batch, primal, non_channel, DEBUG_CHECKS
-from ._magic_ops import stack, expand, rename_dims, unpack_dim
+from ._magic_ops import stack, expand, rename_dims, unpack_dim, unstack
 from ._tensors import Tensor, wrap, disassemble_tree, disassemble_tensors, assemble_tree, TensorStack, may_vary_along, discard_constant_dims, variable_shape
-from ._sparse import SparseCoordinateTensor, is_sparse, sparse_dims, same_sparsity_pattern, sparse_tensor
+from ._sparse import SparseCoordinateTensor, is_sparse, sparse_dims, same_sparsity_pattern, sparse_tensor, include_value_dims_in_pattern
 from . import _ops as math
 from ..backend._dtype import combine_types
@@ -175,9 +175,9 @@ def _natives(self) -> tuple:
     def _spec_dict(self) -> dict:
         raise LinearTraceInProgress(self)
 
-    def _matmul(self, matrix: Tensor, sdims: Shape, ddims: Shape):
+    def _matmul(self, self_dims: Shape, matrix: Tensor, matrix_dims: Shape) -> Tensor:
         if is_sparse(matrix):
-            return to_gather_tracer(self)._matmul(matrix, sdims, ddims)
+            return to_gather_tracer(self)._matmul(self_dims, matrix, matrix_dims)
         raise NotImplementedError
 
     def _gather(self, indices: Tensor) -> Tensor:
@@ -224,11 +224,11 @@ def __init__(self, source: TracerSource, diag, bias: Tensor, shape: Shape, selec
     def __repr__(self):
         return f"{self.__class__.__name__} {self._shape}"
 
-    def _matmul(self, matrix: Tensor, mdims: Shape, ddims: Shape):
-        shape = matrix.shape.without(mdims) & self._shape.without(ddims)
+    def _matmul(self, self_dims: Shape, matrix: Tensor, matrix_dims: Shape) -> Tensor:
+        shape = matrix.shape.without(matrix_dims) & self._shape.without(self_dims)
         matrix *= self._matrix
-        matrix = rename_dims(matrix, mdims, rename_dims(ddims, [*self._renamed.keys()], [*self._renamed.values()]).as_dual())
-        renamed = {n: o for n, o in self._renamed.items() if n not in ddims}
+        matrix = rename_dims(matrix, matrix_dims, rename_dims(self_dims, [*self._renamed.keys()], [*self._renamed.values()]).as_dual())
+        renamed = {n: o for n, o in self._renamed.items() if n not in self_dims}
         return SparseLinTracer(self._source, matrix, self._bias, shape, self._selection, renamed)
 
     def _gather(self, indices: Tensor):
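
The identity behind this rewrite, checked in plain NumPy (illustrative): contracting a matrix with an elementwise (diagonal) map folds the dense factor into the tracer's matrix, with the contracted dim becoming the dual dim of the result:

```python
import numpy as np

# If f(x) = d * x elementwise, then m contracted with f(x) equals the
# combined matrix (m * d) applied to x, since m @ np.diag(d) == m * d.
d = np.array([1., 2., 3.])
m = np.arange(12.).reshape(4, 3)
x = np.array([0.5, -1., 2.])
assert np.allclose(m @ (d * x), (m * d) @ x)
```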
@@ -344,8 +344,21 @@ def __repr__(self):
     def _get_matrix(self, sparsify_batch: bool):
         return self._matrix  # batch dims are currently always included in the matrix
 
-    def _matmul(self, matrix: Tensor, mdims: Shape, ddims: Shape):
-        raise NotImplementedError
+    def _matmul(self, self_dims: Shape, matrix: Tensor, matrix_dims: Shape) -> Tensor:
+        reduced_dims = dependent_out_dims(self).only(self_dims)  # these dimensions are already part of the matrix and must be summed out (not yet supported)
+        new_dims = self._shape.only(self_dims).without(reduced_dims)  # these dimensions should be added to the matrix input
+        shape = self._shape.without(self_dims) & matrix.shape.without(matrix_dims)
+        if reduced_dims:
+            raise NotImplementedError
+        new_sparse = self._matrix * matrix
+        new_sparse = include_value_dims_in_pattern(new_sparse, new_dims)
+        reduced_sparse = math.sum_(new_sparse, new_dims)
+        return SparseLinTracer(self._source, reduced_sparse, self._bias, shape)
+        # for dim in new_dims:
+        #     matrix_slices = unstack(matrix, matrix_dims)
+        #     new_tracers = [self._matrix * s for s in matrix_slices]  # duplicate tracer entries for each slice
+        #     new_tracer = stack(new_tracers, self_dims)
+        #     new_tracer = math.sum_(new_tracer, self_dims)
 
     def _gather(self, indices: Tensor):
         """
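Why multiplying the traced matrix by the dense factor and summing over `new_dims` yields the matrix of the composed map, checked in plain NumPy (illustrative):

```python
import numpy as np

# Let f(x) = A @ x with an extra output component dim d, and
# g(x) = dot(w, d, f(x), d). The matrix of g is sum_d w[d] * A[d],
# mirroring `self._matrix * matrix` followed by `math.sum_(..., new_dims)`.
A = np.arange(24.).reshape(2, 3, 4)   # d=2 components, 3 outputs, 4 inputs
w = np.array([0.5, 2.0])
x = np.array([1., 0., -1., 2.])
g_direct = np.einsum('d,dij,j->i', w, A, x)
M = np.einsum('d,dij->ij', w, A)      # reduced matrix of g
assert np.allclose(g_direct, M @ x)
```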
