Skip to content

Commit

Permalink
Don't compress PyTorch matrices when requires_grad
Browse files Browse the repository at this point in the history
This is due to PyTorch not supporting grad of bincount
  • Loading branch information
holl- committed Jan 15, 2025
1 parent 14274b5 commit 5bf9d58
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions phiml/math/_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,8 @@ def matrix_from_function(f: Callable,
else:
matrix, bias = tracer_to_coo(tracer, sparsify, separate_independent)
# --- Compress ---
if matrix.backend.name == 'torch' and matrix._values._native.requires_grad:
auto_compress = False # PyTorch doesn't support gradient of bincount (used in compression)
if auto_compress and matrix.backend.supports(Backend.mul_csr_dense) and target_backend.supports(Backend.mul_csr_dense) and isinstance(matrix, SparseCoordinateTensor):
matrix = matrix.compress_rows()
# elif backend.supports(Backend.mul_csc_dense):
Expand Down

0 comments on commit 5bf9d58

Please sign in to comment.