diff --git a/phiml/math/_ops.py b/phiml/math/_ops.py
index 39ffb5bf..0c8cb206 100644
--- a/phiml/math/_ops.py
+++ b/phiml/math/_ops.py
@@ -2314,7 +2314,7 @@ def gather(values: Tensor, indices: Tensor, dims: Union[DimFilter, None] = None)
         indices = expand(indices, channel(gather=dims))
         if not channel(indices).item_names[0]:
             indices = indices._with_shape_replaced(indices.shape.with_dim_size(channel(indices), dims))
-        return values.to_sparse_tracer().gather(indices)
+        return values._gather(indices)
     treat_as_batch = non_channel(indices).only(values.shape).without(dims)
     batch_ = (values.shape.batch & indices.shape.batch).without(dims) & treat_as_batch
     channel_ = values.shape.without(dims).without(batch_)
diff --git a/phiml/math/_trace.py b/phiml/math/_trace.py
index d74d00c0..89b7d944 100644
--- a/phiml/math/_trace.py
+++ b/phiml/math/_trace.py
@@ -9,7 +9,7 @@
 from ._shape import Shape, parse_dim_order, merge_shapes, spatial, instance, batch, concat_shapes, EMPTY_SHAPE, dual, channel, non_batch, primal
 from ._magic_ops import stack, expand, rename_dims
 from ._tensors import Tensor, wrap, disassemble_tree, disassemble_tensors, assemble_tree, TensorStack, may_vary_along, discard_constant_dims
-from ._sparse import SparseCoordinateTensor, is_sparse, sparse_dims, same_sparsity_pattern
+from ._sparse import SparseCoordinateTensor, is_sparse, sparse_dims, same_sparsity_pattern, sparse_tensor
 from . import _ops as math
@@ -38,11 +38,11 @@ def __init__(self, source: TracerSource, values_by_shift: dict, shape: Shape, bi
         """
         assert isinstance(source, TracerSource)
         assert isinstance(renamed, dict)
-        self.source = source
+        self._source = source
         self.val: Dict[Shape, Tensor] = simplify_add(values_by_shift)
         for shift_ in self.val.keys():
             assert shift_.only(sorted(shift_.names), reorder=True) == shift_
-        self.bias = bias
+        self._bias = bias
         self._shape = shape
         self._renamed = renamed  # new_name -> old_name
@@ -50,34 +50,12 @@ def __repr__(self):
         return f"Linear tracer {self._shape}"
 
     @property
-    def dependent_dims(self) -> Set[str]:
-        """
-        Dimensions relevant to the linear operation.
-        This includes `pattern_dims` as well as dimensions along which only the values vary.
-        These dimensions cannot be parallelized trivially with a non-batched matrix.
-        """
-        bias_dims = [dim for dim in self.bias.shape.names if may_vary_along(self.bias, dim)]
-        return self.pattern_dim_names | set(sum([t.shape.names for t in self.val.values()], ())) | set(bias_dims)
-
-    @property
-    def pattern_dim_names(self) -> Set[str]:
-        """
-        Dimensions along which the sparse matrix contains off-diagonal elements.
-        These dimensions must be part of the sparse matrix and cannot be parallelized.
- """ - return set(sum([offset.names for offset in self.val], ())) - - @property - def pattern_dims(self) -> Shape: - return self.source.shape.only(self.pattern_dim_names) - - @property - def out_name_to_original(self) -> Dict[str, str]: + def _out_name_to_original(self) -> Dict[str, str]: return self._renamed @property def dtype(self): - return self.source.dtype + return self._source.dtype @property def shape(self): @@ -85,8 +63,8 @@ def shape(self): def _with_shape_replaced(self, new_shape: Shape): renamed = {new_dim: self._renamed[old_dim] for old_dim, new_dim in zip(self._shape.names, new_shape.names)} - bias = rename_dims(self.bias, self._shape, new_shape) - return ShiftLinTracer(self.source, self.val, new_shape, bias, renamed) + bias = rename_dims(self._bias, self._shape, new_shape) + return ShiftLinTracer(self._source, self.val, new_shape, bias, renamed) @property def _is_tracer(self) -> bool: @@ -123,19 +101,19 @@ def shift(self, shifts: dict, if delta: shift = shift._replace_single_size(dim, shift.get_size(dim) + delta) if dim in shift else shift._expand(spatial(**{dim: delta})) val[shift.only(sorted(shift.names), reorder=True)] = val_fun(values) - bias = bias_fun(self.bias) - return ShiftLinTracer(self.source, val, new_shape, bias, self._renamed) + bias = bias_fun(self._bias) + return ShiftLinTracer(self._source, val, new_shape, bias, self._renamed) def _unstack(self, dimension): raise NotImplementedError() def __neg__(self): - return ShiftLinTracer(self.source, {shift: -values for shift, values in self.val.items()}, self._shape, -self.bias, self._renamed) + return ShiftLinTracer(self._source, {shift: -values for shift, values in self.val.items()}, self._shape, -self._bias, self._renamed) def _op1(self, native_function): # __neg__ is the only proper linear op1 and is implemented above. if native_function.__name__ == 'isfinite': - test_output = self.apply(math.ones(self.source.shape, dtype=self.source.dtype)) + test_output = self.apply(math.ones(self._source.shape, dtype=self._source.dtype)) return math.is_finite(test_output) else: raise NotImplementedError('Only linear operations are supported') @@ -153,7 +131,7 @@ def _op2(self, other: Tensor, zeros_for_missing_other = op_name not in ['add', 'radd', 'sub'] # perform `operator` where `other == 0` if isinstance(other, ShiftLinTracer): - assert self.source is other.source, "Multiple linear tracers are not yet supported." + assert self._source is other._source, "Multiple linear tracers are not yet supported." 
             assert set(self._shape) == set(other._shape), f"Tracers have different shapes: {self._shape} and {other._shape}"
             values = {}
             for dim_shift in self.val.keys():
@@ -170,19 +148,19 @@ def _op2(self, other: Tensor,
                         values[dim_shift] = operator(math.zeros_like(other_values), other_values)
                     else:
                         values[dim_shift] = other_values
-            bias = operator(self.bias, other.bias)
-            return ShiftLinTracer(self.source, values, self._shape, bias, self._renamed)
+            bias = operator(self._bias, other._bias)
+            return ShiftLinTracer(self._source, values, self._shape, bias, self._renamed)
         else:
             other = self._tensor(other)
             if op_symbol in '*/':
                 values = {}
                 for dim_shift, val in self.val.items():
                     values[dim_shift] = operator(val, other)
-                bias = operator(self.bias, other)
-                return ShiftLinTracer(self.source, values, self._shape & other.shape, bias, self._renamed)
+                bias = operator(self._bias, other)
+                return ShiftLinTracer(self._source, values, self._shape & other.shape, bias, self._renamed)
             elif op_symbol in '+-':
-                bias = operator(self.bias, other)
-                return ShiftLinTracer(self.source, self.val, self._shape & other.shape, bias, self._renamed)
+                bias = operator(self._bias, other)
+                return ShiftLinTracer(self._source, self.val, self._shape & other.shape, bias, self._renamed)
             else:
                 raise ValueError(f"Unsupported operation encountered while tracing linear function: {native_function}")
@@ -190,21 +168,19 @@ def _natives(self) -> tuple:
         """
        This function should only be used to determine the compatible backends, this tensor should be regarded as not available.
         """
-        return sum([v._natives() for v in self.val.values()], ()) + self.bias._natives()
+        return sum([v._natives() for v in self.val.values()], ()) + self._bias._natives()
 
     def _spec_dict(self) -> dict:
         raise LinearTraceInProgress(self)
 
-    def to_sparse_tracer(self) -> 'SparseLinTracer':
-        if len(self.val) > 1 or next(iter(self.val)):
-            raise NotImplementedError(f"Converting off-diagonal elements to sparse tracer not supported")
-        return SparseLinTracer(self.source, self.val[EMPTY_SHAPE], self.bias, self.shape, None, self._renamed)
-
-    def matmul(self, matrix: Tensor, sdims: Shape, ddims: Shape):
+    def _matmul(self, matrix: Tensor, sdims: Shape, ddims: Shape):
         if is_sparse(matrix):
-            return self.to_sparse_tracer().matmul(matrix, sdims, ddims)
+            return to_gather_tracer(self)._matmul(matrix, sdims, ddims)
         raise NotImplementedError
 
+    def _gather(self, indices: Tensor):
+        return to_gather_tracer(self)._gather(indices)
+
 
 def simplify_add(val: dict) -> Dict[Shape, Tensor]:
     result = {}
@@ -217,18 +193,138 @@ def simplify_add(val: dict) -> Dict[Shape, Tensor]:
     return result
 
 
+class GatherLinTracer(Tensor):
+    """
+    Represents the operation `source[selection] * diag + bias`.
+    """
+
+    def __init__(self, source: TracerSource, diag, bias: Tensor, shape: Shape, selection: Optional[Tensor], renamed: Dict[str, str]):
+        assert isinstance(diag, Tensor)
+        self._source = source
+        self._diag = diag  # full matrix or diagonal elements only
+        self._bias = bias  # matches self.shape
+        self._selection = selection  # indices of the source tensor to be selected before multiplication by matrix. Can index one or multiple dimensions of the source.
+        self._shape = shape
+        self._renamed = renamed  # dims renamed before matrix mul. new_name -> old_name
+        if selection is not None:
+            assert 'x' not in channel(selection).item_names[0]
+        assert bias.shape in self.shape
+        assert bias.shape.volume < 1000  # ToDo only for debug
+        assert bias.rank <= 1  # ToDo only for debug
+
+    def __repr__(self):
+        return f"Gather/diag linear tracer {self._shape}"
+
+    def _get_matrix(self, sparsify_batch: bool):
+        raise NotImplementedError
+
+    def _matmul(self, matrix: Tensor, mdims: Shape, ddims: Shape):
+        shape = matrix.shape.without(mdims) & self._shape.without(ddims)
+        matrix *= self._diag
+        matrix = rename_dims(matrix, mdims, rename_dims(ddims, [*self._renamed.keys()], [*self._renamed.values()]).as_dual())
+        renamed = {n: o for n, o in self._renamed.items() if n not in ddims}
+        return SparseLinTracer(self._source, matrix, self._bias, shape, renamed)  # ToDo self._selection is not yet baked into the matrix
+
+    def _gather(self, indices: Tensor):
+        """
+        Args:
+            indices: has 1 channel and 1 non-channel/non-instance
+        """
+        dims = channel(indices).item_names[0]
+        shape = self.shape.without(dims) & indices.shape.non_channel
+        renamed = {n: o for n, o in self._renamed.items() if n not in dims}
+        bias = expand(self._bias, self.shape.only(dims))[indices]
+        if self._selection is not None:
+            indices = self._selection[indices]
+        old_sel_dims = [self._renamed.get(d, d) for d in channel(indices).item_names[0]]
+        indices_shape = indices.shape.with_dim_size(channel(indices), old_sel_dims)
+        indices = indices._with_shape_replaced(indices_shape)
+        return GatherLinTracer(self._source, self._diag, bias, shape, indices, renamed)
+
+    def __neg__(self):
+        return GatherLinTracer(self._source, -self._diag, -self._bias, self._shape, self._selection, self._renamed)
+
+    def _op1(self, native_function):
+        # __neg__ is the only proper linear op1 and is implemented above.
+        if native_function.__name__ == 'isfinite':
+            finite = math.is_finite(self._source) & math.all_(math.is_finite(self._diag), self._source.shape)
+            raise NotImplementedError
+        else:
+            raise NotImplementedError('Only linear operations are supported')
+
+    def _op2(self, other: Tensor,
+             operator: Callable,
+             native_function: Callable,
+             op_name: str = 'unknown',
+             op_symbol: str = '?') -> Tensor:
+        assert op_symbol in '+-*/', f"Unsupported operation encountered while tracing linear function: {native_function}"
+        if isinstance(other, ShiftLinTracer):
+            other = to_gather_tracer(other)
+        if isinstance(other, GatherLinTracer):
+            assert op_symbol in '+-', f"Non-linear operation '{op_symbol}' cannot be converted to matrix"
+            if not math.always_close(self._selection, other._selection):
+                return to_sparse_tracer(self)._op2(to_sparse_tracer(other), operator, native_function, op_name, op_symbol)
+            diag = operator(self._diag, other._diag)
+            bias = operator(self._bias, other._bias)
+            return GatherLinTracer(self._source, diag, bias, self._shape, self._selection, self._renamed)
+        if isinstance(other, SparseLinTracer):
+            return NotImplemented
+        else:
+            other = self._tensor(other)
+            if op_symbol in '*/':
+                diag = operator(self._diag, other)
+                bias = operator(self._bias, other)
+                return GatherLinTracer(self._source, diag, bias, self._shape & other.shape, self._selection, self._renamed)
+            elif op_symbol in '+-':
+                bias = operator(self._bias, other)
+                return GatherLinTracer(self._source, self._diag, bias, self._shape & other.shape, self._selection, self._renamed)
+            else:
+                raise ValueError(f"Unsupported operation encountered while tracing linear function: {native_function}")
+
+    @property
+    def _is_tracer(self) -> bool:
+        return True
+
+    @property
+    def shape(self):
+        return self._shape
+
+    def _with_shape_replaced(self, new_shape: Shape):
+        renamed = dict(self._renamed)
+        renamed.update({n: self._renamed.get(o, o) for n, o in zip(new_shape.names, self._shape.names)})
+        return GatherLinTracer(self._source, self._diag, self._bias, new_shape, self._selection, renamed)
+
+    @property
+    def _out_name_to_original(self) -> Dict[str, str]:
+        return self._renamed
+
+    def _natives(self) -> tuple:
+        """
+        This function should only be used to determine the compatible backends, this tensor should be regarded as not available.
+        """
+        return self._diag._natives()
+
+    def _get_selection(self, selection_dims, list_dim: Shape = instance('selection'), index_dim: Shape = channel('gather')):
+        original_dims = [self._renamed.get(d, d) for d in selection_dims]
+        if self._selection is not None:
+            assert selection_dims == set(channel(self._selection).item_names[0])
+            return rename_dims(self._selection, non_batch(self._selection).non_channel, list_dim)
+        else:
+            sel_src_shape = self._source.shape.only(original_dims)
+            return expand(math.range_tensor(list_dim.with_size(sel_src_shape.volume)), index_dim.with_size(sel_src_shape.names))
+
+
 class SparseLinTracer(Tensor):
 
-    def __init__(self, source: TracerSource, matrix, bias: Tensor, shape: Shape, src_selection: Optional[Tensor], renamed: Dict[str, str]):
+    def __init__(self, source: TracerSource, matrix, bias: Tensor, shape: Shape, renamed: Dict[str, str]):
         assert isinstance(matrix, Tensor)
-        self.source = source
+        self._source = source
         self._matrix = matrix  # full matrix or diagonal elements only
-        self.bias = bias  # should always match self.shape
-        self._src_selection = src_selection  # indices of the source tensor to be selected before multiplication by matrix. Can index one or multiple dimensions of the source.
+        self._bias = bias  # should always match self.shape
         self._shape = shape
         self._renamed = renamed  # dims renamed before matrix mul. new_name -> old_name
-        if src_selection is not None:
-            assert 'x' not in channel(src_selection).item_names[0]
         assert bias.shape in self.shape
         assert bias.shape.volume < 1000  # ToDo only for debug
         assert bias.rank <= 1  # ToDo only for debug
@@ -236,22 +332,22 @@ def __init__(self, source: TracerSource, matrix, bias: Tensor, shape: Shape, src
 
     def __repr__(self):
         return f"Sparse linear tracer {self._shape}"
 
-    def get_matrix(self, sparsify_batch: bool):
+    def _get_matrix(self, sparsify_batch: bool):
         if not is_sparse(self._matrix):
             raise NotImplementedError  # ToDo build diagonal matrix
         matrix = rename_dims(self._matrix, primal, self._shape)
         return matrix
 
-    def matmul(self, matrix: Tensor, mdims: Shape, ddims: Shape):
+    def _matmul(self, matrix: Tensor, mdims: Shape, ddims: Shape):
         if not is_sparse(self._matrix):
             shape = matrix.shape.without(mdims) & self._shape.without(ddims)
             matrix *= self._matrix
             matrix = rename_dims(matrix, mdims, rename_dims(ddims, [*self._renamed.keys()], [*self._renamed.values()]).as_dual())
             renamed = {n: o for n, o in self._renamed.items() if n not in ddims}
-            return SparseLinTracer(self.source, matrix, self.bias, shape, self._src_selection, renamed)
+            return SparseLinTracer(self._source, matrix, self._bias, shape, renamed)
         raise NotImplementedError
 
-    def gather(self, indices: Tensor):
+    def _gather(self, indices: Tensor):
         """
         Args:
             indices: has 1 channel and 1 non-channel/non-instance
         """
@@ -259,21 +355,21 @@ def gather(self, indices: Tensor):
         dims = channel(indices).item_names[0]
         shape = self.shape.without(dims) & indices.shape.non_channel
         renamed = {n: o for n, o in self._renamed.items() if n not in dims}
-        bias = expand(self.bias, self.shape.only(dims))[indices]
-        if self._src_selection is not None:
-            indices = self._src_selection[indices]
-        old_sel_dims = [self._renamed.get(d, d) for d in channel(indices).item_names[0]]
-        indices_shape = indices.shape.with_dim_size(channel(indices), old_sel_dims)
-        indices = indices._with_shape_replaced(indices_shape)
-        return SparseLinTracer(self.source, self._matrix, bias, shape, indices, renamed)
+        bias = expand(self._bias, self.shape.only(dims))[indices]
+        raise NotImplementedError  # ToDo gather the matching rows of self._matrix, then return SparseLinTracer(self._source, matrix, bias, shape, renamed)
 
     def __neg__(self):
-        return SparseLinTracer(self.source, -self._matrix, -self.bias, self._shape, self._renamed)
+        return SparseLinTracer(self._source, -self._matrix, -self._bias, self._shape, self._renamed)
 
     def _op1(self, native_function):
         # __neg__ is the only proper linear op1 and is implemented above.
         if native_function.__name__ == 'isfinite':
-            finite = math.is_finite(self.source) & math.all_(math.is_finite(self._matrix), self.source.shape)
+            finite = math.is_finite(self._source) & math.all_(math.is_finite(self._matrix), self._source.shape)
             raise NotImplementedError
         else:
             raise NotImplementedError('Only linear operations are supported')
@@ -293,17 +389,17 @@ def _op2(self, other: Tensor,
                 raise NotImplementedError("Tracing consecutive sparse matrix multiplications not yet supported")
             assert op_symbol in '+-', f"Non-linear operation '{op_symbol}' cannot be converted to matrix"
             matrix = operator(self._matrix, other._matrix)
-            bias = operator(self.bias, other.bias)
-            return SparseLinTracer(self.source, matrix, bias, self._shape, self._renamed)
+            bias = operator(self._bias, other._bias)
+            return SparseLinTracer(self._source, matrix, bias, self._shape, self._renamed)
         else:
             other = self._tensor(other)
             if op_symbol in '*/':
                 matrix = operator(self._matrix, other)
-                bias = operator(self.bias, other)
-                return SparseLinTracer(self.source, matrix, bias, self._shape & other.shape, self._renamed)
+                bias = operator(self._bias, other)
+                return SparseLinTracer(self._source, matrix, bias, self._shape & other.shape, self._renamed)
             elif op_symbol in '+-':
-                bias = operator(self.bias, other)
-                return SparseLinTracer(self.source, self._matrix, bias, self._shape & other.shape, self._renamed)
+                bias = operator(self._bias, other)
+                return SparseLinTracer(self._source, self._matrix, bias, self._shape & other.shape, self._renamed)
             else:
                 raise ValueError(f"Unsupported operation encountered while tracing linear function: {native_function}")
@@ -311,9 +407,6 @@ def _op2(self, other: Tensor,
     def _is_tracer(self) -> bool:
         return True
 
-    def to_sparse_tracer(self) -> 'SparseLinTracer':
-        return self
-
     @property
     def shape(self):
         return self._shape
@@ -321,27 +414,10 @@ def shape(self):
 
     def _with_shape_replaced(self, new_shape: Shape):
         renamed = dict(self._renamed)
         renamed.update({n: self._renamed.get(o, o) for n, o in zip(new_shape.names, self._shape.names)})
-        return SparseLinTracer(self.source, self._matrix, self.bias, new_shape, renamed)
+        return SparseLinTracer(self._source, self._matrix, self._bias, new_shape, renamed)
 
     @property
-    def dependent_dims(self) -> Set[str]:
-        """
-        Dimensions relevant to the linear operation.
-        This includes `pattern_dims` as well as dimensions along which only the values vary.
-        These dimensions cannot be parallelized trivially with a non-batched matrix.
-        """
-        return set(sparse_dims(self._matrix).names)
-
-    @property
-    def pattern_dim_names(self) -> Set[str]:
-        """
-        Dimensions along which the sparse matrix contains off-diagonal elements.
-        These dimensions must be part of the sparse matrix and cannot be parallelized.
- """ - return self.dependent_dims - - @property - def out_name_to_original(self) -> Dict[str, str]: + def _out_name_to_original(self) -> Dict[str, str]: return self._renamed def _natives(self) -> tuple: @@ -350,32 +426,35 @@ def _natives(self) -> tuple: """ return self._matrix._natives() - def get_src_selection(self, selection_dims, list_dim: Shape = instance('selection'), index_dim: Shape = channel('gather')): + def _get_selection(self, selection_dims, list_dim: Shape = instance('selection'), index_dim: Shape = channel('_gather')): original_dims = [self._renamed.get(d, d) for d in selection_dims] - if self._src_selection is not None: - assert selection_dims == set(channel(self._src_selection).item_names[0]) - return rename_dims(self._src_selection, non_batch(self._src_selection).non_channel, list_dim) + if self._selection is not None: + assert selection_dims == set(channel(self._selection).item_names[0]) + return rename_dims(self._selection, non_batch(self._selection).non_channel, list_dim) else: - sel_src_shape = self.source.shape.only(original_dims) + sel_src_shape = self._source.shape.only(original_dims) return expand(math.range_tensor(list_dim.with_size(sel_src_shape.volume)), index_dim.with_size(sel_src_shape.names)) def concat_tracers(tracers: Sequence[Tensor], dim: str): - tracers = [t.to_sparse_tracer() if t._is_tracer else t for t in tracers] - any_tracer = [t for t in tracers if t._is_tracer][0] - o_dim = any_tracer._renamed.get(dim, dim) - if all(not t._is_tracer or not is_sparse(t._matrix) for t in tracers): - selection_dims = set(sum([channel(t._src_selection).item_names[0] for t in tracers if t._is_tracer and t._src_selection is not None], ())) + if any(isinstance(t, SparseLinTracer) for t in tracers): + tracers = [to_sparse_tracer(t) if t._is_tracer else t for t in tracers] + raise NotImplementedError + if any(isinstance(t, GatherLinTracer) for t in tracers): + tracers = [to_gather_tracer(t) if t._is_tracer else t for t in tracers] + any_tracer = [t for t in tracers if t._is_tracer][0] + o_dim = any_tracer._renamed.get(dim, dim) + selection_dims = set(sum([channel(t._selection).item_names[0] for t in tracers if t._is_tracer and t._selection is not None], ())) selections = [] diags = [] biases = [] for t in tracers: - if isinstance(t, SparseLinTracer): - selections.append(t.get_src_selection(selection_dims)) - diags.append(expand(t._matrix, t.shape[dim])) - biases.append(expand(t.bias, t.shape[dim])) + if t._is_tracer: + selections.append(t._get_selection(selection_dims)) + diags.append(expand(t._diag, t.shape[dim])) + biases.append(expand(t._bias, t.shape[dim])) else: # constant - mapped_shape = rename_dims(t.shape, tuple(any_tracer._renamed), [any_tracer.source.shape[o] for o in any_tracer._renamed.values()]) + mapped_shape = rename_dims(t.shape, tuple(any_tracer._renamed), [any_tracer._source.shape[o] for o in any_tracer._renamed.values()]) selections.append(math.zeros(instance(selection=mapped_shape.only(selection_dims).volume), channel(gather=list(selection_dims)))) diags.append(math.zeros(t.shape[dim])) biases.append(expand(discard_constant_dims(t), t.shape[dim])) @@ -384,7 +463,7 @@ def concat_tracers(tracers: Sequence[Tensor], dim: str): full_bias = math.concat(biases, dim, expand_values=True) shape = merge_shapes([t.shape.with_dim_size(dim, full_bias.shape.get_size(dim)) for t in tracers]) renamed = any_tracer._renamed - return SparseLinTracer(tracers[0].source, full_diag, full_bias, shape, full_selection, renamed) + return GatherLinTracer(tracers[0]._source, 
full_diag, full_bias, shape, full_selection, renamed) raise NotImplementedError @@ -447,7 +526,7 @@ def matrix_from_function(f: Callable, else: sparsify_batch = not target_backend.supports(Backend.sparse_coo_tensor_batched) if isinstance(tracer, SparseLinTracer): - matrix, bias = tracer.get_matrix(sparsify_batch), tracer.bias + matrix, bias = tracer._get_matrix(sparsify_batch), tracer._bias else: matrix, bias = tracer_to_coo(tracer, sparsify_batch, separate_independent) # --- Compress --- @@ -456,7 +535,7 @@ def matrix_from_function(f: Callable, if matrix.default_backend.supports(Backend.mul_csr_dense) and target_backend.supports(Backend.mul_csr_dense) and isinstance(matrix, SparseCoordinateTensor): return matrix.compress_rows(), bias # elif backend.supports(Backend.mul_csc_dense): - # return matrix.compress_cols(), tracer.bias + # return matrix.compress_cols(), tracer._bias else: return matrix, bias @@ -481,13 +560,14 @@ def tracer_to_coo(tracer: Tensor, sparsify_batch: bool, separate_independent: bo elif not tracer._is_tracer: # This part of the output is independent of the input return expand(0, tracer.shape), tracer assert isinstance(tracer, ShiftLinTracer), f"Tracing linear function returned an unsupported construct: {type(tracer)}" - assert batch(tracer.pattern_dims).is_empty, f"Batch dimensions may not be sliced in linear operations but got pattern for {batch(tracer.pattern_dims)}" + pattern_dims = tracer._source.shape.only(tracer.pattern_dim_names) + assert batch(pattern_dims).is_empty, f"Batch dimensions may not be sliced in linear operations but got pattern for {batch(pattern_dims)}" out_shape, src_shape, typed_src_shape, missing_dims, sliced_src_shape = matrix_dims_for_tracer(tracer, sparsify_batch) - out_shape_original = rename_dims(out_shape, [*tracer.out_name_to_original.keys()], [*tracer.out_name_to_original.values()]) + out_shape_original = rename_dims(out_shape, [*tracer._out_name_to_original.keys()], [*tracer._out_name_to_original.values()]) batch_val = merge_shapes(*tracer.val.values()).without(out_shape) if non_batch(out_shape).is_empty: assert len(tracer.val) == 1 and non_batch(tracer.val[EMPTY_SHAPE]) == EMPTY_SHAPE - return tracer.val[EMPTY_SHAPE], tracer.bias + return tracer.val[EMPTY_SHAPE], tracer._bias out_indices = [] src_indices = [] values = [] @@ -518,15 +598,69 @@ def tracer_to_coo(tracer: Tensor, sparsify_batch: bool, separate_independent: bo values = math.reshaped_tensor(backend.concat(values, axis=-1), [batch_val, instance('entries')], convert=False) dense_shape = concat_shapes((sliced_src_shape if separate_independent else src_shape) & out_shape) matrix = SparseCoordinateTensor(indices, values, dense_shape, can_contain_double_entries=False, indices_sorted=False, default=0) - return matrix, tracer.bias + return matrix, tracer._bias def matrix_dims_for_tracer(tracer: Union[ShiftLinTracer, SparseLinTracer], sparsify_batch: bool): - renamed_src_names = [o for n, o in tracer.out_name_to_original.items() if n != o] - removed_dims = tracer.source.shape.without(tracer.shape).without(renamed_src_names) # these were sliced off - ignored_dims = tracer.source.shape.without(tracer.shape.only(tracer.dependent_dims) if sparsify_batch else tracer.pattern_dim_names).without(removed_dims).without(renamed_src_names) # these will be parallelized and not added to the matrix + renamed_src_names = [o for n, o in tracer._out_name_to_original.items() if n != o] + removed_dims = tracer._source.shape.without(tracer.shape).without(renamed_src_names) # these were sliced off + 
ignored_dims = tracer._source.shape.without(tracer.shape.only(tracer.dependent_dims) if sparsify_batch else tracer.pattern_dim_names).without(removed_dims).without(renamed_src_names) # these will be parallelized and not added to the matrix out_shape = tracer.shape.without(ignored_dims) - typed_src_shape = tracer.source.shape.without(ignored_dims) + typed_src_shape = tracer._source.shape.without(ignored_dims) src_shape = typed_src_shape.as_dual() sliced_src_shape = src_shape.without(removed_dims.as_dual()) return out_shape, src_shape, typed_src_shape, removed_dims, sliced_src_shape + + +def dependent_dims(tracer: Tensor) -> Set[str]: + """ + Dimensions relevant to the linear operation. + This includes `pattern_dims` as well as dimensions along which only the values vary. + These dimensions cannot be parallelized trivially with a non-batched matrix. + """ + if isinstance(tracer, ShiftLinTracer): + bias_dims = [dim for dim in tracer._bias.shape.names if may_vary_along(tracer._bias, dim)] + return tracer.pattern_dim_names | set(sum([t.shape.names for t in tracer.val.values()], ())) | set(bias_dims) + elif isinstance(tracer, GatherLinTracer): + return set(channel(tracer._selection).item_names[0]) + elif isinstance(tracer, SparseLinTracer): + return set(sparse_dims(tracer._matrix).names) + + +def pattern_dim_names(tracer) -> Set[str]: + """ + Dimensions along which the sparse matrix contains off-diagonal elements. + These dimensions must be part of the sparse matrix and cannot be parallelized. + """ + if isinstance(tracer, ShiftLinTracer): + return set(sum([offset.names for offset in tracer.val], ())) + elif isinstance(tracer, GatherLinTracer): + return dependent_dims(tracer) + elif isinstance(tracer, SparseLinTracer): + return dependent_dims(tracer) + + +def to_sparse_tracer(tracer: Tensor) -> SparseLinTracer: + assert tracer._is_tracer + if isinstance(tracer, SparseLinTracer): + return tracer + if isinstance(tracer, ShiftLinTracer): + tracer = tracer._to_gather_tracer() + assert isinstance(tracer, GatherLinTracer) + rows = math.arange(...) + cols = tracer._selection + indices = stack({}, channel('index')) + dense_shape = ... + matrix = sparse_tensor(indices, tracer._diag, dense_shape, can_contain_double_entries=False, indices_sorted=False, format='coo', default=0) + return SparseLinTracer(tracer._source, matrix, tracer._bias, tracer._shape, tracer._renamed) + + +def to_gather_tracer(t: Tensor) -> GatherLinTracer: + if isinstance(t, GatherLinTracer): + return t + if isinstance(t, SparseLinTracer): + raise AssertionError + assert isinstance(t, ShiftLinTracer) + if len(t.val) > 1 or next(iter(t.val)): + raise NotImplementedError(f"Converting off-diagonal elements to sparse tracer not supported") + return GatherLinTracer(t._source, t.val[EMPTY_SHAPE], t._bias, t._shape, None, t._renamed)
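
Round-trip sketch for the tracing entry point this patch refactors. For a linear `f`, `matrix_from_function` (defined in this file) returns a pair such that `f(x) == matrix @ x + bias`. The example below is a minimal sketch, not part of the patch; `shift`, `random_uniform` and `assert_close` are assumed from the public `phiml.math` API.

```python
from phiml import math
from phiml.math import spatial

def f(x):
    # periodic forward difference: linear, with one off-diagonal shift
    return math.shift(x, (1,), dims='x', padding=math.extrapolation.PERIODIC)[0] - x

matrix, bias = math.matrix_from_function(f, math.zeros(spatial(x=4)))
x = math.random_uniform(spatial(x=4))
math.assert_close(f(x), matrix @ x + bias)  # dual dims of matrix contract against x
```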
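The contract recorded by the new `GatherLinTracer` is stated in its docstring: `source[selection] * diag + bias`. The following eager evaluation of that contract is for exposition only; `apply_gather_tracer` is a hypothetical helper, not part of the patch.

```python
from typing import Optional
from phiml.math import Tensor

def apply_gather_tracer(source: Tensor, selection: Optional[Tensor], diag: Tensor, bias: Tensor) -> Tensor:
    """Eagerly evaluate the operation a GatherLinTracer represents."""
    gathered = source[selection] if selection is not None else source  # gather first
    return gathered * diag + bias  # then per-element scale and offset
```

Composing a further gather with indices `i` then amounts to replacing `selection` by `selection[i]`, which is exactly what `GatherLinTracer._gather` does.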
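`to_sparse_tracer` still carries placeholder index construction (`math.arange(...)`, `dense_shape = ...`). The COO pattern it presumably targets, one matrix row per output element with the column taken from the selection and the value from the diagonal, can be sketched in plain NumPy; this is an assumption about intent, kept deliberately generic.

```python
import numpy as np

def diag_gather_to_coo(selection: np.ndarray, diag: np.ndarray):
    """COO triplets for the matrix of y = src[selection] * diag."""
    rows = np.arange(len(selection))  # one row per output element
    cols = selection                  # column = gathered source index
    return rows, cols, diag           # values are the diagonal entries

rows, cols, vals = diag_gather_to_coo(np.array([2, 0, 1]), np.array([1.0, -1.0, 0.5]))
```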