Skip to content

Commit

Permalink
Linear tracing improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Nov 11, 2023
1 parent 3ba1508 commit dd13a11
Showing 1 changed file with 26 additions and 24 deletions.
50 changes: 26 additions & 24 deletions phiml/math/_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,9 @@ def _op2(self, other: Tensor,
operator: Callable,
native_function: Callable,
op_name: str = 'unknown',
op_symbol: str = '?') -> 'ShiftLinTracer':
op_symbol: str = '?') -> Tensor:
if isinstance(other, SparseLinTracer):
sparse_self = self.to_sparse_tracer()
return sparse_self._op2(other, operator, native_function, op_name, op_symbol)
return to_sparse_tracer(self, other)._op2(other, operator, native_function, op_name, op_symbol)
assert op_symbol in '+-*/', f"Unsupported operation encountered while tracing linear function: {native_function}"
zeros_for_missing_self = op_name not in ['add', 'radd', 'rsub'] # perform `operator` where `self == 0`
zeros_for_missing_other = op_name not in ['add', 'radd', 'sub'] # perform `operator` where `other == 0`
Expand Down Expand Up @@ -246,7 +245,7 @@ def _gather(self, indices: Tensor):
return GatherLinTracer(self._source, diag, bias, shape, indices, renamed)

def _scatter(self, base: Tensor, indices: Tensor) -> Tensor:
return to_sparse_tracer(self)._scatter(base, indices)
return to_sparse_tracer(self, None)._scatter(base, indices)

def __neg__(self):
return GatherLinTracer(self._source, -self._diag, -self._bias, self._shape, self._selection, self._renamed)
Expand All @@ -270,7 +269,7 @@ def _op2(self, other: Tensor,
if isinstance(other, GatherLinTracer):
assert op_symbol in '+-', f"Non-linear operation '{op_symbol}' cannot be converted to matrix"
if not math.always_close(self._selection, other._selection):
return to_sparse_tracer(self)._op2(other, operator, native_function, op_name, op_symbol)
return to_sparse_tracer(self, other)._op2(other, operator, native_function, op_name, op_symbol)
diag = operator(self._diag, other._diag)
bias = operator(self._bias, other._bias)
return GatherLinTracer(self._source, diag, bias, self._shape, self._selection, self._renamed)
Expand Down Expand Up @@ -339,9 +338,7 @@ def __repr__(self):
return f"{self.__class__.__name__} {self._shape}"

def _get_matrix(self, sparsify_batch: bool):
if sparsify_batch:
raise NotImplementedError
return self._matrix
return self._matrix # batch dims are currently always included in the matrix

def _matmul(self, matrix: Tensor, mdims: Shape, ddims: Shape):
raise NotImplementedError
Expand Down Expand Up @@ -386,7 +383,7 @@ def _op2(self, other,
other = self._tensor(other)
assert op_symbol in '+-*/', f"Unsupported operation encountered while tracing linear function: {native_function}"
if other._is_tracer and not isinstance(other, SparseLinTracer):
other = to_sparse_tracer(other)
other = to_sparse_tracer(other, self)
if isinstance(other, SparseLinTracer):
assert op_symbol in '+-', f"Non-linear operation '{op_symbol}' cannot be converted to matrix"
bias = operator(self._bias, other._bias)
Expand Down Expand Up @@ -434,7 +431,7 @@ def _natives(self) -> tuple:

def concat_tracers(tracers: Sequence[Tensor], dim: str):
if any(isinstance(t, SparseLinTracer) for t in tracers):
tracers = [to_sparse_tracer(t) if t._is_tracer else t for t in tracers]
# tracers = [to_sparse_tracer(t, any_tracer) if t._is_tracer else t for t in tracers]
raise NotImplementedError
if any(isinstance(t, GatherLinTracer) for t in tracers):
tracers = [to_gather_tracer(t) if t._is_tracer else t for t in tracers]
Expand Down Expand Up @@ -476,6 +473,7 @@ def matrix_from_function(f: Callable,
auto_compress=False,
sparsify_batch=None,
separate_independent=False, # not fully implemented, requires auto_compress=False
_return_raw_output=False,
**kwargs) -> Tuple[Tensor, Tensor]:
"""
Trace a linear function and construct a matrix.
Expand Down Expand Up @@ -512,7 +510,7 @@ def matrix_from_function(f: Callable,
tracer = ShiftLinTracer(src, {EMPTY_SHAPE: math.ones()}, tensors[0].shape, bias=math.zeros(dtype=tensors[0].dtype), renamed={d: d for d in tensors[0].shape.names})
x_kwargs = assemble_tree(tree, [tracer])
result = f(**x_kwargs, **aux_args)
_, result_tensors = disassemble_tree(result)
out_tree, result_tensors = disassemble_tree(result)
assert len(result_tensors) == 1, f"Linear function output must be or contain a single Tensor but got {result}"
tracer = result_tensors[0]._simplify()
assert tracer._is_tracer, f"Tracing linear function '{f_name(f)}' failed. Make sure only linear operations are used. Output: {tracer.shape}"
Expand All @@ -527,14 +525,11 @@ def matrix_from_function(f: Callable,
else:
matrix, bias = tracer_to_coo(tracer, sparsify_batch, separate_independent)
# --- Compress ---
if not auto_compress:
return matrix, bias
if matrix.default_backend.supports(Backend.mul_csr_dense) and target_backend.supports(Backend.mul_csr_dense) and isinstance(matrix, SparseCoordinateTensor):
return matrix.compress_rows(), bias
if auto_compress and matrix.default_backend.supports(Backend.mul_csr_dense) and target_backend.supports(Backend.mul_csr_dense) and isinstance(matrix, SparseCoordinateTensor):
matrix = matrix.compress_rows()
# elif backend.supports(Backend.mul_csc_dense):
# return matrix.compress_cols(), tracer._bias
else:
return matrix, bias
return (matrix, bias, (out_tree, result_tensors)) if _return_raw_output else (matrix, bias)


def tracer_to_coo(tracer: Tensor, sparsify_batch: bool, separate_independent: bool): # ToDo this may return compressed if function uses
Expand Down Expand Up @@ -643,7 +638,7 @@ def dependent_out_dims(tracer: Tensor) -> Shape:
elif isinstance(tracer, GatherLinTracer):
return tracer._selection.shape.non_channel
elif isinstance(tracer, SparseLinTracer):
raise NotImplementedError
return tracer._matrix.sparse_dims.only(tracer.shape)


def pattern_dim_names(tracer) -> Set[str]:
Expand All @@ -660,17 +655,24 @@ def pattern_dim_names(tracer) -> Set[str]:
# return set(dependent_src_dims(tracer).names)


def to_sparse_tracer(tracer: Tensor) -> SparseLinTracer:
def to_sparse_tracer(tracer: Tensor, ref: Optional[Tensor]) -> SparseLinTracer:
assert tracer._is_tracer
if isinstance(tracer, SparseLinTracer):
return tracer
if isinstance(tracer, ShiftLinTracer):
tracer = tracer._to_gather_tracer()
tracer = to_gather_tracer(tracer)
assert isinstance(tracer, GatherLinTracer)
in_dims = dependent_src_dims(tracer).as_dual()
out_dims = dependent_out_dims(tracer)
cols = rename_dims(tracer._selection, channel, channel(vector=in_dims))
gather_dims = tracer._selection.shape.non_channel
if tracer._selection is None:
assert ref is not None, f"Insufficient tracer information"
in_dims = dependent_src_dims(ref).as_dual()
out_dims = dependent_out_dims(ref)
cols = math.meshgrid(in_dims.as_instance())
else:
in_dims = dependent_src_dims(tracer).as_dual()
out_dims = dependent_out_dims(tracer)
cols = tracer._selection
cols = rename_dims(cols, channel, channel(vector=in_dims))
gather_dims = cols.shape.non_channel
rows = math.meshgrid(gather_dims, stack_dim=channel(vector=out_dims))
indices = concat_tensor([rows, cols], 'vector')
dense_shape = in_dims & out_dims
Expand Down

0 comments on commit dd13a11

Please sign in to comment.