diff --git a/phiml/math/_sparse.py b/phiml/math/_sparse.py
index 21646e8b..4f557e2b 100644
--- a/phiml/math/_sparse.py
+++ b/phiml/math/_sparse.py
@@ -6,7 +6,7 @@
 import scipy.sparse
 
 from ._shape import Shape, non_batch, merge_shapes, instance, batch, non_instance, shape, channel, spatial, DimFilter, concat_shapes, EMPTY_SHAPE, dual, DUAL_DIM, SPATIAL_DIM, \
-    non_channel
+    non_channel, DEBUG_CHECKS
 from ._magic_ops import concat, pack_dims, expand, rename_dims, stack, unpack_dim
 from ._tensors import Tensor, TensorStack, NativeTensor, cached, wrap
 from ..backend import choose_backend, NUMPY, Backend
@@ -103,6 +103,9 @@ def __init__(self, indices: Tensor, values: Tensor, dense_shape: Shape, can_cont
         assert set(indices.vector.item_names) == set(dense_shape.names), f"The 'vector' dimension of indices must list the dense dimensions {dense_shape} as item names but got {indices.vector.item_names}"
         assert indices.dtype.kind == int, f"indices must have dtype=int but got {indices.dtype}"
         assert instance(values) in instance(indices), f"All instance dimensions of values must exist in indices. values={values.shape}, indices={indices.shape}"
+        assert set(indices.shape.only(instance(values))) == set(instance(values)), f"indices and values must have an equal number of elements but got {instance(indices)} indices and {instance(values)} values"
+        if not instance(values) and (spatial(values) or dual(values)):
+            warnings.warn(f"You are creating a sparse tensor with only constant values {values.shape}. To have values vary along indices, add the corresponding instance dimension.", RuntimeWarning, stacklevel=3)
         self._shape = merge_shapes(dense_shape, batch(indices), non_instance(values))
         self._dense_shape = dense_shape
         self._indices = indices
@@ -110,6 +113,8 @@
         self._can_contain_double_entries = can_contain_double_entries
         self._indices_sorted = indices_sorted
         self._default = default
+        if DEBUG_CHECKS:
+            self.compress_rows()
 
     @property
     def shape(self) -> Shape:
@@ -226,6 +231,7 @@ def compress(self, dims: DimFilter):
         u_dims = self._dense_shape.without(c_dims)
         c_idx_packed, u_idx_packed = self._pack_indices(c_dims, u_dims)
         values = self._values
+        uncompressed_indices = self._indices
         if self._can_contain_double_entries:
             bi = self._indices.default_backend
             assert c_idx_packed.shape[0] == 1, f"sparse compress() not supported for batched indices"
@@ -239,20 +245,22 @@
             idx_packed = bi.unravel_index(u_idx, (c_dims.volume, u_dims.volume))
             c_idx_packed = idx_packed[None, :, 0]
             u_idx_packed = idx_packed[None, :, 1]
+            uncompressed_indices = bi.unravel_index(u_idx, c_dims.sizes + u_dims.sizes)
+            uncompressed_indices = wrap(uncompressed_indices, instance(self._indices).without_sizes(), channel(self._indices))
         # --- Use scipy.sparse.csr_matrix to reorder values ---
         idx = np.arange(1, c_idx_packed.shape[-1] + 1)  # start indexing at 1 since 0 might get removed
         scipy_csr = scipy.sparse.csr_matrix((idx, (c_idx_packed[0], u_idx_packed[0])), shape=(c_dims.volume, u_dims.volume))
         assert c_idx_packed.shape[1] == len(scipy_csr.data), "Failed to create CSR matrix because the CSR matrix contains fewer non-zero values than COO. This can happen when the `x` tensor is too small for the stencil."
         # --- Construct CompressedSparseMatrix ---
-        entries_dim = instance(values).name
+        entries_dim = instance(self._indices).name
+        perm = None
+        values = expand(values, instance(self._indices).without(instance(values)))
         if np.any(scipy_csr.data != idx):
             perm = {entries_dim: wrap(scipy_csr.data - 1, instance(entries_dim))}
             values = values[perm]  # Change order accordingly
-        else:
-            perm = None
         indices = wrap(scipy_csr.indices, instance(entries_dim))
         pointers = wrap(scipy_csr.indptr, instance('pointers'))
-        return CompressedSparseMatrix(indices, pointers, values, u_dims, c_dims, self._default, uncompressed_indices=self._indices, uncompressed_indices_perm=perm)
+        return CompressedSparseMatrix(indices, pointers, values, u_dims, c_dims, self._default, uncompressed_indices=uncompressed_indices, uncompressed_indices_perm=perm)
 
     def __pack_dims__(self, dims: Tuple[str, ...], packed_dim: Shape, pos: Union[int, None], **kwargs) -> 'Tensor':
         dims = self._shape.only(dims)
@@ -406,6 +414,8 @@
         assert not channel(pointers) and not spatial(pointers), f"channel and spatial dimensions not allowed on pointers but got {shape(pointers)}"
         assert uncompressed_dims.isdisjoint(compressed_dims), f"Dimensions cannot be compressed and uncompressed at the same time but got compressed={compressed_dims}, uncompressed={uncompressed_dims}"
         assert instance(pointers).size == compressed_dims.volume + 1
+        if uncompressed_indices is not None:
+            assert instance(uncompressed_indices) == instance(indices), f"Number of uncompressed indices {instance(uncompressed_indices)} does not match compressed indices {instance(indices)}"
         self._shape = merge_shapes(compressed_dims, uncompressed_dims, batch(indices), batch(pointers), non_instance(values))
         self._indices = indices
         self._pointers = rename_dims(pointers, instance, 'pointers')
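Why `compress()` stores 1-based positions instead of values: `scipy.sparse.csr_matrix` reorders COO data into row-major order and, as the comment above notes, might drop explicit zeros. So the code feeds it the positions `1..N` and reads the permutation back from `csr.data`. A minimal standalone sketch of that trick (plain numpy/scipy; array names are illustrative):

```python
import numpy as np
import scipy.sparse

# COO triplets in arbitrary order (no duplicates).
rows = np.array([1, 0, 1, 0])
cols = np.array([0, 2, 1, 0])
values = np.array([10., 20., 30., 40.])

# Store 1-based positions instead of the actual values so that no entry is an
# explicit zero, which scipy might silently drop during conversion.
idx = np.arange(1, len(rows) + 1)
csr = scipy.sparse.csr_matrix((idx, (rows, cols)), shape=(2, 3))

perm = csr.data - 1        # CSR slot -> position in the original COO arrays
csr_values = values[perm]  # values reordered to match csr.indices / csr.indptr
assert np.allclose(csr_values, [40., 20., 10., 30.])
```

This `perm` is what the code above applies to `values` and threads through as `uncompressed_indices_perm`, alongside the newly preserved `uncompressed_indices`.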
diff --git a/phiml/math/_trace.py b/phiml/math/_trace.py
index 16ad4979..11e15f6e 100644
--- a/phiml/math/_trace.py
+++ b/phiml/math/_trace.py
@@ -6,7 +6,7 @@
 from ..backend import choose_backend, NUMPY, Backend
 from ._ops import choose_backend_t, concat_tensor, scatter, zeros_like
-from ._shape import Shape, parse_dim_order, merge_shapes, spatial, instance, batch, concat_shapes, EMPTY_SHAPE, dual, channel, non_batch, primal
+from ._shape import Shape, parse_dim_order, merge_shapes, spatial, instance, batch, concat_shapes, EMPTY_SHAPE, dual, channel, non_batch, primal, non_channel, DEBUG_CHECKS
 from ._magic_ops import stack, expand, rename_dims, unpack_dim
 from ._tensors import Tensor, wrap, disassemble_tree, disassemble_tensors, assemble_tree, TensorStack, may_vary_along, discard_constant_dims, variable_shape
 from ._sparse import SparseCoordinateTensor, is_sparse, sparse_dims, same_sparsity_pattern, sparse_tensor
 
@@ -40,8 +40,9 @@ def __init__(self, source: TracerSource, values_by_shift: dict, shape: Shape, bi
         assert isinstance(renamed, dict)
         self._source = source
         self.val: Dict[Shape, Tensor] = simplify_add(values_by_shift)
-        for shift_ in self.val.keys():
+        for shift_, v in self.val.items():
             assert shift_.only(sorted(shift_.names), reorder=True) == shift_
+            assert v.shape.only(shape) == shape.only(v.shape), f"Tracer with shape {shape} must have matching values but got {v.shape}"  # values must match shape
         self._bias = bias
         self._shape = shape
         self._renamed = renamed  # new_name -> old_name
@@ -75,7 +76,7 @@
         new_shape = math.zeros(self._shape)[selection].shape
         return self.shift(starts, new_shape, lambda v: v[selection], lambda b: b[selection])
 
-    def shift(self, shifts: dict,
+    def shift(self, shifts: Dict[str, int],
              new_shape: Shape,
              val_fun: Callable,
              bias_fun: Callable = None):
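For orientation, `ShiftLinTracer.val` (now validated by the per-value shape assert above) maps each shift offset to the coefficients applied to the correspondingly shifted input. A plain-numpy sketch of the idea, using a periodic 1D Laplacian stencil for brevity; the sign convention and names here are illustrative, not the phiml API:

```python
import numpy as np

# ShiftLinTracer-style stencil: {shift: coefficients for x shifted by that offset}
# y[i] = sum over shifts s of coeff_s[i] * x[(i + s) % n]   (periodic for brevity)
n = 6
x = np.arange(float(n))
stencil = {-1: np.ones(n), 0: -2.0 * np.ones(n), 1: np.ones(n)}  # 1D Laplacian
y = sum(coeff * np.roll(x, -s) for s, coeff in stencil.items())

# The same operator written as an explicit (circulant) matrix:
matrix = sum(coeff[:, None] * np.roll(np.eye(n), s, axis=1) for s, coeff in stencil.items())
assert np.allclose(y, matrix @ x)
```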
@@ -124,10 +125,9 @@ def _op2(self, other: Tensor,
              operator: Callable,
              native_function: Callable,
              op_name: str = 'unknown',
-             op_symbol: str = '?') -> 'ShiftLinTracer':
+             op_symbol: str = '?') -> Tensor:
         if isinstance(other, SparseLinTracer):
-            sparse_self = self.to_sparse_tracer()
-            return sparse_self._op2(other, operator, native_function, op_name, op_symbol)
+            return to_sparse_tracer(self, other)._op2(other, operator, native_function, op_name, op_symbol)
         assert op_symbol in '+-*/', f"Unsupported operation encountered while tracing linear function: {native_function}"
         zeros_for_missing_self = op_name not in ['add', 'radd', 'rsub']  # perform `operator` where `self == 0`
         zeros_for_missing_other = op_name not in ['add', 'radd', 'sub']  # perform `operator` where `other == 0`
@@ -208,19 +208,22 @@ def __init__(self, source: TracerSource, diag, bias: Tensor, shape: Shape, selec
         assert bias.shape in shape
         assert selection is None or selection.dtype.kind == int
         assert diag.shape in shape
+        assert selection is None or selection.shape.volume > 0
         self._source = source
         self._diag = diag  # full matrix or diagonal elements only
         self._bias = bias  # matches self.shape
         self._selection = selection  # Can index one or multiple dimensions of the source. Must retain source dimensions.
         self._shape = shape
         self._renamed = renamed  # dims renamed before matrix mul. new_name -> old_name
+        if DEBUG_CHECKS:
+            if selection is not None:
+                assert selection.min >= 0, f"Negative selection indices: {selection}"
+                for dim in channel(selection).item_names[0]:
+                    assert selection[dim].max < self._source.shape.get_size(dim), f"Too large selection indices for source tensor {self._source.shape}: {selection}"
 
     def __repr__(self):
         return f"{self.__class__.__name__} {self._shape}"
 
-    def _get_matrix(self, sparsify_batch: bool):
-        raise NotImplementedError
-
     def _matmul(self, matrix: Tensor, mdims: Shape, ddims: Shape):
         shape = matrix.shape.without(mdims) & self._shape.without(ddims)
         matrix *= self._matrix
@@ -246,7 +249,7 @@
         return GatherLinTracer(self._source, diag, bias, shape, indices, renamed)
 
     def _scatter(self, base: Tensor, indices: Tensor) -> Tensor:
-        return to_sparse_tracer(self)._scatter(base, indices)
+        return to_sparse_tracer(self, None)._scatter(base, indices)
 
     def __neg__(self):
         return GatherLinTracer(self._source, -self._diag, -self._bias, self._shape, self._selection, self._renamed)
@@ -270,7 +273,7 @@
         if isinstance(other, GatherLinTracer):
             assert op_symbol in '+-', f"Non-linear operation '{op_symbol}' cannot be converted to matrix"
             if not math.always_close(self._selection, other._selection):
-                return to_sparse_tracer(self)._op2(other, operator, native_function, op_name, op_symbol)
+                return to_sparse_tracer(self, other)._op2(other, operator, native_function, op_name, op_symbol)
             diag = operator(self._diag, other._diag)
             bias = operator(self._bias, other._bias)
             return GatherLinTracer(self._source, diag, bias, self._shape, self._selection, self._renamed)
@@ -339,9 +342,7 @@ def __repr__(self):
         return f"{self.__class__.__name__} {self._shape}"
 
     def _get_matrix(self, sparsify_batch: bool):
-        if sparsify_batch:
-            raise NotImplementedError
-        return self._matrix
+        return self._matrix  # batch dims are currently always included in the matrix
 
     def _matmul(self, matrix: Tensor, mdims: Shape, ddims: Shape):
         raise NotImplementedError
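As the fields above suggest, a `GatherLinTracer` encodes `y = diag * x[selection] + bias`, and the new `DEBUG_CHECKS` block bounds-checks `selection` against the source shape. A plain-numpy analogue of that contract (illustrative names, not phiml API):

```python
import numpy as np

# GatherLinTracer contract: y = diag * x[selection] + bias
x = np.array([0., 1., 2., 3., 4.])   # source tensor
selection = np.array([2, 0, 3])      # gathered source positions
diag = np.array([10., 20., 30.])     # per-output scale factors
bias = np.array([1., 1., 1.])

# The bounds that the DEBUG_CHECKS block enforces:
assert selection.min() >= 0 and selection.max() < len(x)

y = diag * x[selection] + bias
assert np.allclose(y, [21., 1., 91.])
```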
@@ -386,7 +387,7 @@ def _op2(self, other,
         other = self._tensor(other)
         assert op_symbol in '+-*/', f"Unsupported operation encountered while tracing linear function: {native_function}"
         if other._is_tracer and not isinstance(other, SparseLinTracer):
-            other = to_sparse_tracer(other)
+            other = to_sparse_tracer(other, self)
         if isinstance(other, SparseLinTracer):
             assert op_symbol in '+-', f"Non-linear operation '{op_symbol}' cannot be converted to matrix"
             bias = operator(self._bias, other._bias)
@@ -433,14 +434,19 @@ def _natives(self) -> tuple:
 
 
 def concat_tracers(tracers: Sequence[Tensor], dim: str):
+    full_size = sum([t_.shape.get_size(dim) for t_ in tracers])
+    shape = merge_shapes([t.shape.with_dim_size(dim, full_size) for t in tracers])
+    tracer_count = len([t for t in tracers if t._is_tracer])
     if any(isinstance(t, SparseLinTracer) for t in tracers):
-        tracers = [to_sparse_tracer(t) if t._is_tracer else t for t in tracers]
+        # tracers = [to_sparse_tracer(t, any_tracer) if t._is_tracer else t for t in tracers]
         raise NotImplementedError
-    if any(isinstance(t, GatherLinTracer) for t in tracers):
+    elif any(isinstance(t, GatherLinTracer) for t in tracers):
         tracers = [to_gather_tracer(t) if t._is_tracer else t for t in tracers]
         any_tracer = [t for t in tracers if t._is_tracer][0]
-        o_dim = any_tracer._renamed.get(dim, dim)
+        src_dim = any_tracer._renamed.get(dim, dim)
         selection_dims = set(sum([channel(t._selection).item_names[0] for t in tracers if t._is_tracer and t._selection is not None], ()))
+        if not selection_dims:
+            selection_dims = {any_tracer._renamed.get(dim, dim)}
         selections = []
         diags = []
         biases = []
@@ -451,17 +457,36 @@
             biases.append(expand(t._bias, t.shape[dim]))
         else:  # constant
             mapped_shape = rename_dims(t.shape, tuple(any_tracer._renamed), [any_tracer._source.shape[o] for o in any_tracer._renamed.values()])
-            selections.append(math.zeros(instance(selection=mapped_shape.only(selection_dims).volume), channel(gather=list(selection_dims)), dtype=(int, 32)))
+            selections.append(math.zeros(instance(selection=mapped_shape.only(src_dim).volume), channel(gather=list(selection_dims)), dtype=(int, 32)))
             diags.append(math.zeros(t.shape[dim]))
             biases.append(expand(discard_constant_dims(t), t.shape[dim]))
         full_diag = concat_tensor(diags, dim)
         full_bias = math.concat(biases, dim, expand_values=True)
-        shape = merge_shapes([t.shape.with_dim_size(dim, full_bias.shape.get_size(dim)) for t in tracers])
         full_selection = concat_tensor(selections, 'selection')
         full_selection = unpack_dim(full_selection, 'selection', shape[dim])
         renamed = any_tracer._renamed
+        assert non_channel(full_selection).size == full_size
         return GatherLinTracer(tracers[0]._source, full_diag, full_bias, shape, full_selection, renamed)
-    raise NotImplementedError
+    else:  # only ShiftLinTracer + constants
+        if True:  # ToDo this block minimizes zeros but removes the ShiftLinTracer structure. May be slower in the long run...
+            tracers = [to_gather_tracer(t) if t._is_tracer else t for t in tracers]
+            return concat_tracers(tracers, dim)
+        # any_tracer = [t for t in tracers if t._is_tracer][0]
+        # biases = []
+        # aligned = []
+        # offset = 0
+        # for t in tracers:
+        #     if t._is_tracer:
+        #         assert isinstance(t, ShiftLinTracer)
+        #         t_aligned = t.shift({dim: offset}, shape, lambda v: math.pad(v, {dim: (offset, full_size - offset - t.shape.get_size(dim))}, 0), lambda b: b)
+        #         aligned.append(t_aligned)
+        #         biases.append(expand(t._bias, t.shape[dim]))
+        #     else:  # constant
+        #         biases.append(expand(discard_constant_dims(t), t.shape[dim]))
+        #     offset += t.shape.get_size(dim)
+        # full_tracer = sum(aligned[1:], aligned[0])
+        # full_bias = math.concat(biases, dim, expand_values=True)
+        # return ShiftLinTracer(any_tracer._source, full_tracer.val, shape, full_bias, any_tracer._renamed)
 
 
 class LinearTraceInProgress(Exception):
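The commented-out branch above would keep the `ShiftLinTracer` structure by shifting each part to its output offset and summing. The identity it relies on: concatenating the outputs of linear operators along a dimension equals padding each operator to the full output size at its offset and adding the results. A small dense sketch (plain numpy; `A` and `B` are illustrative):

```python
import numpy as np

A = np.array([[1., 2.], [3., 4.]])   # maps R^2 -> R^2
B = np.array([[5., 6.]])             # maps R^2 -> R^1
x = np.array([1., 1.])

full_size = A.shape[0] + B.shape[0]
A_pad = np.pad(A, ((0, full_size - A.shape[0]), (0, 0)))  # offset 0
B_pad = np.pad(B, ((A.shape[0], 0), (0, 0)))              # offset = A's output size

# Padding + summing reproduces the concatenated operator.
assert np.allclose((A_pad + B_pad) @ x, np.concatenate([A @ x, B @ x]))
```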
@@ -476,6 +501,7 @@ def matrix_from_function(f: Callable,
                          auto_compress=False,
                          sparsify_batch=None,
                          separate_independent=False,  # not fully implemented, requires auto_compress=False
+                         _return_raw_output=False,
                          **kwargs) -> Tuple[Tensor, Tensor]:
     """
     Trace a linear function and construct a matrix.
@@ -512,7 +538,7 @@
     tracer = ShiftLinTracer(src, {EMPTY_SHAPE: math.ones()}, tensors[0].shape, bias=math.zeros(dtype=tensors[0].dtype), renamed={d: d for d in tensors[0].shape.names})
     x_kwargs = assemble_tree(tree, [tracer])
     result = f(**x_kwargs, **aux_args)
-    _, result_tensors = disassemble_tree(result)
+    out_tree, result_tensors = disassemble_tree(result)
     assert len(result_tensors) == 1, f"Linear function output must be or contain a single Tensor but got {result}"
     tracer = result_tensors[0]._simplify()
     assert tracer._is_tracer, f"Tracing linear function '{f_name(f)}' failed. Make sure only linear operations are used. Output: {tracer.shape}"
@@ -524,17 +550,16 @@
         sparsify_batch = not target_backend.supports(Backend.sparse_coo_tensor_batched)
     if isinstance(tracer, SparseLinTracer):
         matrix, bias = tracer._get_matrix(sparsify_batch), tracer._bias
+    elif isinstance(tracer, GatherLinTracer):
+        matrix, bias = to_sparse_tracer(tracer, None)._get_matrix(sparsify_batch), tracer._bias
     else:
         matrix, bias = tracer_to_coo(tracer, sparsify_batch, separate_independent)
     # --- Compress ---
-    if not auto_compress:
-        return matrix, bias
-    if matrix.default_backend.supports(Backend.mul_csr_dense) and target_backend.supports(Backend.mul_csr_dense) and isinstance(matrix, SparseCoordinateTensor):
-        return matrix.compress_rows(), bias
+    if auto_compress and matrix.default_backend.supports(Backend.mul_csr_dense) and target_backend.supports(Backend.mul_csr_dense) and isinstance(matrix, SparseCoordinateTensor):
+        matrix = matrix.compress_rows()
     # elif backend.supports(Backend.mul_csc_dense):
     #     return matrix.compress_cols(), tracer._bias
-    else:
-        return matrix, bias
+    return (matrix, bias, (out_tree, result_tensors)) if _return_raw_output else (matrix, bias)
 
 
 def tracer_to_coo(tracer: Tensor, sparsify_batch: bool, separate_independent: bool):  # ToDo this may return compressed if function uses
@@ -623,7 +648,11 @@ def dependent_src_dims(tracer: Tensor) -> Shape:
         assert len(result) == len(names)
         return result
     elif isinstance(tracer, GatherLinTracer):
-        return tracer._source.shape.only(channel(tracer._selection).item_names[0])
+        dims = set()
+        if tracer._selection is not None:
+            dims.update(set(channel(tracer._selection).item_names[0]))
+        dims.update(set([tracer._renamed.get(d, d) for d in tracer._bias.shape.names + tracer._diag.shape.names]))
+        return tracer._source.shape.only(dims)
     elif isinstance(tracer, SparseLinTracer):
         return tracer._source.shape.only(sparse_dims(tracer._matrix).names)
 
@@ -641,9 +670,12 @@ def dependent_out_dims(tracer: Tensor) -> Shape:
         assert len(result) == len(names)
         return result
     elif isinstance(tracer, GatherLinTracer):
-        return tracer._selection.shape.non_channel
+        result = tracer._bias.shape & tracer._diag.shape
+        if tracer._selection is not None:
+            result &= tracer._selection.shape.non_channel
+        return result
     elif isinstance(tracer, SparseLinTracer):
-        raise NotImplementedError
+        return sparse_dims(tracer._matrix).only(tracer.shape)
 
 
 def pattern_dim_names(tracer) -> Set[str]:
@@ -660,21 +692,32 @@
 #     return set(dependent_src_dims(tracer).names)
 
 
-def to_sparse_tracer(tracer: Tensor) -> SparseLinTracer:
+def to_sparse_tracer(tracer: Tensor, ref: Optional[Tensor]) -> SparseLinTracer:
     assert tracer._is_tracer
     if isinstance(tracer, SparseLinTracer):
         return tracer
     if isinstance(tracer, ShiftLinTracer):
-        tracer = tracer._to_gather_tracer()
+        tracer = to_gather_tracer(tracer)
     assert isinstance(tracer, GatherLinTracer)
-    in_dims = dependent_src_dims(tracer).as_dual()
-    out_dims = dependent_out_dims(tracer)
-    cols = rename_dims(tracer._selection, channel, channel(vector=in_dims))
-    gather_dims = tracer._selection.shape.non_channel
+    if tracer._selection is None:
+        if ref is not None:
+            in_dims = dependent_src_dims(ref).as_dual()
+            out_dims = dependent_out_dims(ref)
+            cols = math.meshgrid(out_dims.as_instance())
+        else:
+            in_dims = dependent_src_dims(tracer).as_dual()
+            out_dims = dependent_out_dims(tracer)
+            cols = math.meshgrid(out_dims.as_instance())
+    else:
+        in_dims = dependent_src_dims(tracer).as_dual()
+        out_dims = dependent_out_dims(tracer)
+        cols = rename_dims(tracer._selection, non_channel, instance)
+        cols = rename_dims(cols, channel, channel(vector=in_dims))
+    gather_dims = cols.shape.non_channel
     rows = math.meshgrid(gather_dims, stack_dim=channel(vector=out_dims))
     indices = concat_tensor([rows, cols], 'vector')
     dense_shape = in_dims & out_dims
-    matrix = sparse_tensor(indices, tracer._diag, dense_shape, can_contain_double_entries=False, indices_sorted=False, format='coo', default=0)
+    matrix = sparse_tensor(indices, rename_dims(tracer._diag, indices.shape.non_channel.non_batch, instance), dense_shape, can_contain_double_entries=False, indices_sorted=False, format='coo', default=0)  # ToDo check renaming
     return SparseLinTracer(tracer._source, matrix, tracer._bias, tracer._shape)
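`to_sparse_tracer` turns the gather into an explicit COO matrix: one entry per output element, with the row taken from a meshgrid over the output dims, the column from the selection, and the diagonal factors as values. A plain scipy analogue of that construction (illustrative, 1D, bias omitted):

```python
import numpy as np
import scipy.sparse

sel = np.array([2, 0, 3])          # column = gathered source position
diag = np.array([10., 20., 30.])   # matrix values
n_in, n_out = 5, len(sel)
rows = np.arange(n_out)            # row = output position (flattened meshgrid)

matrix = scipy.sparse.coo_matrix((diag, (rows, sel)), shape=(n_out, n_in))
x = np.arange(float(n_in))
assert np.allclose(matrix @ x, diag * x[sel])  # reproduces y = diag * x[sel]
```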