Skip to content

Commit

Permalink
Fix linear trace (internal conversion)
Browse files (browse the repository at this point in the history)
  • Loading branch information
holl- committed Dec 30, 2024
1 parent bbe7947 commit e2c2ace
Showing 1 changed file with 19 additions and 16 deletions.
35 changes: 19 additions & 16 deletions phiml/math/_trace.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -460,7 +460,7 @@ def __init__(self, source: TracerSource, matrix: SparseCoordinateTensor, bias: T
def __repr__(self):
return f"{self.__class__.__name__} {self._shape}"

def _get_matrix(self, sparsify_batch: bool):
def _get_matrix(self, sparsify: Shape):
in_dims = [d for d in self._matrix.shape if d.name.endswith('_src')]
renamed = [d.name[:-4] for d in self._matrix.shape if d.name.endswith('_src')]
return rename_dims(self._matrix, in_dims, renamed)
Expand Down Expand Up @@ -731,12 +731,13 @@ def matrix_from_function(f: Callable,
sparsify_batch = not target_backend.supports(Backend.csr_matrix_batched)
else:
sparsify_batch = not target_backend.supports(Backend.sparse_coo_tensor_batched)
sparsify = tracer.shape if sparsify_batch else EMPTY_SHAPE
if isinstance(tracer, SparseLinTracer):
matrix, bias = tracer._get_matrix(sparsify_batch), tracer._bias
matrix, bias = tracer._get_matrix(sparsify), tracer._bias
elif isinstance(tracer, GatherLinTracer):
matrix, bias = to_sparse_tracer(tracer, None)._get_matrix(sparsify_batch), tracer._bias
matrix, bias = to_sparse_tracer(tracer, None)._get_matrix(sparsify), tracer._bias
else:
matrix, bias = tracer_to_coo(tracer, sparsify_batch, separate_independent)
matrix, bias = tracer_to_coo(tracer, sparsify, separate_independent)
# --- Compress ---
if auto_compress and matrix.backend.supports(Backend.mul_csr_dense) and target_backend.supports(Backend.mul_csr_dense) and isinstance(matrix, SparseCoordinateTensor):
matrix = matrix.compress_rows()
Expand All @@ -745,12 +746,9 @@ def matrix_from_function(f: Callable,
return (matrix, bias, (out_tree, result_tensors)) if _return_raw_output else (matrix, bias)


def tracer_to_coo(tracer: Tensor, sparsify_batch: bool, separate_independent: bool): # ToDo this may return compressed if function uses
# if isinstance(tracer, CollapsedTensor):
# tracer = tracer._cached if tracer.is_cached else tracer._inner # ignore collapsed dimensions. Alternatively, we could expand the result
# return tracer_to_coo(tracer, sparsify_batch, separate_independent)
def tracer_to_coo(tracer: Tensor, sparsify: Shape, separate_independent: bool):
if isinstance(tracer, TensorStack): # This indicates separable solves
matrices, biases = zip(*[tracer_to_coo(t, sparsify_batch, separate_independent) for t in tracer._tensors])
matrices, biases = zip(*[tracer_to_coo(t, sparsify-tracer._stack_dim, separate_independent) for t in tracer._tensors])
bias = stack(biases, tracer._stack_dim)
if not separate_independent:
indices = [math.concat_tensor([m._indices, expand(i, instance(m._indices), channel(sparse_idx=tracer._stack_dim.name))], 'sparse_idx') for i, m in enumerate(matrices)]
Expand All @@ -767,7 +765,7 @@ def tracer_to_coo(tracer: Tensor, sparsify_batch: bool, separate_independent: bo
assert isinstance(tracer, ShiftLinTracer), f"Tracing linear function returned an unsupported construct: {type(tracer)}"
pattern_dims = tracer._source.shape.only(pattern_dim_names(tracer))
assert batch(pattern_dims).is_empty, f"Batch dimensions may not be sliced in linear operations but got pattern for {batch(pattern_dims)}"
out_shape, src_shape, typed_src_shape, missing_dims, sliced_src_shape = matrix_dims_for_tracer(tracer, sparsify_batch)
out_shape, src_shape, typed_src_shape, missing_dims, sliced_src_shape = matrix_dims_for_tracer(tracer, sparsify)
original_out_names = [tracer._out_name_to_original.get(d, d) for d in out_shape.names]
batch_val = merge_shapes(*tracer.val.values()).without(out_shape)
if non_batch(out_shape).is_empty:
Expand Down Expand Up @@ -808,10 +806,10 @@ def tracer_to_coo(tracer: Tensor, sparsify_batch: bool, separate_independent: bo
return matrix, tracer._bias


def matrix_dims_for_tracer(tracer: Union[ShiftLinTracer, SparseLinTracer], sparsify_batch: bool):
def matrix_dims_for_tracer(tracer: Union[ShiftLinTracer, SparseLinTracer], sparsify: Shape):
renamed_src_names = [o for n, o in tracer._out_name_to_original.items() if n != o]
removed_dims = tracer._source.shape.without(tracer.shape).without(renamed_src_names) # these were sliced off
ignored_dims = tracer._source.shape.without(dependent_out_dims(tracer) if sparsify_batch else pattern_dim_names(tracer)).without(removed_dims).without(renamed_src_names) # these will be parallelized and not added to the matrix
ignored_dims = tracer._source.shape - dependent_out_dims(tracer, sparsify) - removed_dims - renamed_src_names # these will be parallelized and not added to the matrix
out_shape = tracer.shape.without(ignored_dims)
typed_src_shape = tracer._source.shape.without(ignored_dims)
src_shape = typed_src_shape.as_dual()
Expand Down Expand Up @@ -844,7 +842,7 @@ def dependent_src_dims(tracer: Tensor) -> Shape:
return tracer._source.shape.only(sparse_dims(tracer._matrix).names)


def dependent_out_dims(tracer: Tensor) -> Shape:
def dependent_out_dims(tracer: Tensor, sparsify=None) -> Shape:
"""
Current dimensions relevant to the linear operation.
This includes `pattern_dims` as well as dimensions along which only the values vary.
Expand All @@ -854,8 +852,13 @@ def dependent_out_dims(tracer: Tensor) -> Shape:
They are not included unless also relevant to the matrix.
"""
if isinstance(tracer, ShiftLinTracer):
bias_dims = set(variable_shape(tracer._bias).names)
names = pattern_dim_names(tracer) | set(sum([t.shape.names for t in tracer.val.values()], ())) | bias_dims
bias_names = set(variable_shape(tracer._bias).names)
pattern_names = pattern_dim_names(tracer)
if sparsify is None:
value_names = set(sum([t.shape.names for t in tracer.val.values()], ()))
else:
value_names = set([n for t in tracer.val.values() for n in t.shape.names if n in sparsify])
names = bias_names | pattern_names | value_names
result = tracer.shape.only(names)
assert len(result) == len(names), f"Tracer was modified along {names} but the dimensions {names - set(result.names)} are not present anymore, probably due to slicing. Make sure the linear function output retains all dimensions relevant to the linear operation."
return result
Expand Down Expand Up @@ -887,7 +890,7 @@ def to_sparse_tracer(tracer: Tensor, ref: Optional[Tensor]) -> SparseLinTracer:
if isinstance(tracer, SparseLinTracer):
return tracer
if isinstance(tracer, ShiftLinTracer):
matrix, bias = tracer_to_coo(tracer, sparsify_batch=False, separate_independent=False)
matrix, bias = tracer_to_coo(tracer, sparsify=dependent_out_dims(ref), separate_independent=False)
src_dims = dual(matrix) - set(tracer._renamed)
matrix = rename_dims(matrix, src_dims, [f'~{n}_src' for n in src_dims.as_batch().names])
return SparseLinTracer(tracer._source, matrix, bias, tracer.shape)
Expand Down

0 comments on commit e2c2ace

Please sign in to comment.