Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Nov 16, 2023
1 parent ddb4892 commit 510fdb8
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions phiml/math/_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,8 @@ def _op2(self, other, operator: Callable, native_function: Callable, op_name: st
indices = concat([self_indices, other_indices], 'sp_entries')
self_values = pack_dims(self._values, instance(self._indices), instance('sp_entries'))
other_values = pack_dims(other._values, instance(other._indices), instance('sp_entries'))
self_indices, self_values = duplicate_indices_for_batch_dim(self_indices, all_sparse_dims)
other_indices, other_values = duplicate_indices_for_batch_dim(other_indices, all_sparse_dims)
self_indices, self_values = with_sparsified_dim(self_indices, all_sparse_dims)
other_indices, other_values = with_sparsified_dim(other_indices, all_sparse_dims)
if op_symbol == '+':
values = concat([self_values, other_values], instance(self_values))
elif op_name == 'sub':
Expand Down Expand Up @@ -1085,16 +1085,26 @@ def sparsify_batch_dims(matrix: Tensor, value_dims: Shape):


def with_sparsified_dim(matrix: SparseCoordinateTensor, dims: Shape):
indices = matrix._indices
if indices.sparse_idx.item_names == dims.names:
return indices
missing = dims.without(indices.sparse_idx.item_names)
if not missing:
return indices[dims.name_list] # reorder components
offsets = [wrap([*idx.values()] + [0] * non_instance(matrix._indices).volume + [*idx.values()], channel(indices)) for idx in missing.meshgrid()]
offsets = stack(offsets, dims.as_instance())
indices = concat([arange(channel(sparse_idx=dims._more_dual())), matrix._indices, arange(channel(sparse_idx=dims))], 'sparse_idx', expand_values=True)
indices += offsets
return pack_dims(indices, instance, instance('sp_entries'))
components = []
for dim in dims:
if dim.name in indices.sparse_idx.item_names:
components.append(indices[[dim.name]])
else:
from ._ops import meshgrid
components.append(meshgrid(dim, stack_dim=channel('sparse_idx')))
indices = concat(components, 'sparse_idx', expand_values=True)
values = expand(matrix._values, dims)
entries_dims = instance(indices)
indices = pack_dims(indices, entries_dims, instance('sp_entries'))
values = pack_dims(values, entries_dims, instance('sp_entries'))
dense_shape = ...
return SparseCoordinateTensor(indices, values, dense_shape, matrix._can_contain_double_entries, False, matrix._default)


def sparse_sum(value: Tensor, dims: Shape):
Expand Down

0 comments on commit 510fdb8

Please sign in to comment.