diff --git a/phiml/math/_trace.py b/phiml/math/_trace.py
index 41ca6db..8249710 100644
--- a/phiml/math/_trace.py
+++ b/phiml/math/_trace.py
@@ -460,7 +460,7 @@ def __init__(self, source: TracerSource, matrix: SparseCoordinateTensor, bias: T
     def __repr__(self):
         return f"{self.__class__.__name__} {self._shape}"
 
-    def _get_matrix(self, sparsify_batch: bool):
+    def _get_matrix(self, sparsify: Shape):
         in_dims = [d for d in self._matrix.shape if d.name.endswith('_src')]
         renamed = [d.name[:-4] for d in self._matrix.shape if d.name.endswith('_src')]
         return rename_dims(self._matrix, in_dims, renamed)
@@ -731,12 +731,13 @@ def matrix_from_function(f: Callable,
         sparsify_batch = not target_backend.supports(Backend.csr_matrix_batched)
     else:
         sparsify_batch = not target_backend.supports(Backend.sparse_coo_tensor_batched)
+    sparsify = tracer.shape if sparsify_batch else EMPTY_SHAPE
     if isinstance(tracer, SparseLinTracer):
-        matrix, bias = tracer._get_matrix(sparsify_batch), tracer._bias
+        matrix, bias = tracer._get_matrix(sparsify), tracer._bias
     elif isinstance(tracer, GatherLinTracer):
-        matrix, bias = to_sparse_tracer(tracer, None)._get_matrix(sparsify_batch), tracer._bias
+        matrix, bias = to_sparse_tracer(tracer, None)._get_matrix(sparsify), tracer._bias
     else:
-        matrix, bias = tracer_to_coo(tracer, sparsify_batch, separate_independent)
+        matrix, bias = tracer_to_coo(tracer, sparsify, separate_independent)
     # --- Compress ---
     if auto_compress and matrix.backend.supports(Backend.mul_csr_dense) and target_backend.supports(Backend.mul_csr_dense) and isinstance(matrix, SparseCoordinateTensor):
         matrix = matrix.compress_rows()
@@ -745,12 +746,9 @@ def matrix_from_function(f: Callable,
     return (matrix, bias, (out_tree, result_tensors)) if _return_raw_output else (matrix, bias)
 
 
-def tracer_to_coo(tracer: Tensor, sparsify_batch: bool, separate_independent: bool):  # ToDo this may return compressed if function uses
-    # if isinstance(tracer, CollapsedTensor):
-    #     tracer = tracer._cached if tracer.is_cached else tracer._inner  # ignore collapsed dimensions. Alternatively, we could expand the result
-    #     return tracer_to_coo(tracer, sparsify_batch, separate_independent)
+def tracer_to_coo(tracer: Tensor, sparsify: Shape, separate_independent: bool):
     if isinstance(tracer, TensorStack):  # This indicates separable solves
-        matrices, biases = zip(*[tracer_to_coo(t, sparsify_batch, separate_independent) for t in tracer._tensors])
+        matrices, biases = zip(*[tracer_to_coo(t, sparsify - tracer._stack_dim, separate_independent) for t in tracer._tensors])
         bias = stack(biases, tracer._stack_dim)
         if not separate_independent:
             indices = [math.concat_tensor([m._indices, expand(i, instance(m._indices), channel(sparse_idx=tracer._stack_dim.name))], 'sparse_idx') for i, m in enumerate(matrices)]
@@ -767,7 +765,7 @@ def tracer_to_coo(tracer: Tensor, sparsify_batch: bool, separate_independent: bo
     assert isinstance(tracer, ShiftLinTracer), f"Tracing linear function returned an unsupported construct: {type(tracer)}"
     pattern_dims = tracer._source.shape.only(pattern_dim_names(tracer))
     assert batch(pattern_dims).is_empty, f"Batch dimensions may not be sliced in linear operations but got pattern for {batch(pattern_dims)}"
-    out_shape, src_shape, typed_src_shape, missing_dims, sliced_src_shape = matrix_dims_for_tracer(tracer, sparsify_batch)
+    out_shape, src_shape, typed_src_shape, missing_dims, sliced_src_shape = matrix_dims_for_tracer(tracer, sparsify)
     original_out_names = [tracer._out_name_to_original.get(d, d) for d in out_shape.names]
     batch_val = merge_shapes(*tracer.val.values()).without(out_shape)
     if non_batch(out_shape).is_empty:
@@ -808,10 +806,10 @@ def tracer_to_coo(tracer: Tensor, sparsify_batch: bool, separate_independent: bo
     return matrix, tracer._bias
 
 
-def matrix_dims_for_tracer(tracer: Union[ShiftLinTracer, SparseLinTracer], sparsify_batch: bool):
+def matrix_dims_for_tracer(tracer: Union[ShiftLinTracer, SparseLinTracer], sparsify: Shape):
     renamed_src_names = [o for n, o in tracer._out_name_to_original.items() if n != o]
     removed_dims = tracer._source.shape.without(tracer.shape).without(renamed_src_names)  # these were sliced off
-    ignored_dims = tracer._source.shape.without(dependent_out_dims(tracer) if sparsify_batch else pattern_dim_names(tracer)).without(removed_dims).without(renamed_src_names)  # these will be parallelized and not added to the matrix
+    ignored_dims = tracer._source.shape - dependent_out_dims(tracer, sparsify) - removed_dims - renamed_src_names  # these will be parallelized and not added to the matrix
     out_shape = tracer.shape.without(ignored_dims)
     typed_src_shape = tracer._source.shape.without(ignored_dims)
     src_shape = typed_src_shape.as_dual()
@@ -844,7 +842,7 @@ def dependent_src_dims(tracer: Tensor) -> Shape:
         return tracer._source.shape.only(sparse_dims(tracer._matrix).names)
 
 
-def dependent_out_dims(tracer: Tensor) -> Shape:
+def dependent_out_dims(tracer: Tensor, sparsify=None) -> Shape:
     """
     Current dimensions relevant to the linear operation.
     This includes `pattern_dims` as well as dimensions along which only the values vary.
@@ -854,8 +852,13 @@ def dependent_out_dims(tracer: Tensor) -> Shape:
     They are not included unless also relevant to the matrix.
     """
     if isinstance(tracer, ShiftLinTracer):
-        bias_dims = set(variable_shape(tracer._bias).names)
-        names = pattern_dim_names(tracer) | set(sum([t.shape.names for t in tracer.val.values()], ())) | bias_dims
+        bias_names = set(variable_shape(tracer._bias).names)
+        pattern_names = pattern_dim_names(tracer)
+        if sparsify is None:
+            value_names = set(sum([t.shape.names for t in tracer.val.values()], ()))
+        else:
+            value_names = set([n for t in tracer.val.values() for n in t.shape.names if n in sparsify])
+        names = bias_names | pattern_names | value_names
         result = tracer.shape.only(names)
         assert len(result) == len(names), f"Tracer was modified along {names} but the dimensions {names - set(result.names)} are not present anymore, probably due to slicing. Make sure the linear function output retains all dimensions relevant to the linear operation."
         return result
@@ -887,7 +890,7 @@ def to_sparse_tracer(tracer: Tensor, ref: Optional[Tensor]) -> SparseLinTracer:
     if isinstance(tracer, SparseLinTracer):
         return tracer
     if isinstance(tracer, ShiftLinTracer):
-        matrix, bias = tracer_to_coo(tracer, sparsify_batch=False, separate_independent=False)
+        matrix, bias = tracer_to_coo(tracer, sparsify=dependent_out_dims(ref), separate_independent=False)
         src_dims = dual(matrix) - set(tracer._renamed)
         matrix = rename_dims(matrix, src_dims, [f'~{n}_src' for n in src_dims.as_batch().names])
         return SparseLinTracer(tracer._source, matrix, bias, tracer.shape)