Skip to content

Commit

Permalink
Support trees in all_available()
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Nov 30, 2023
1 parent 0c04be9 commit df30015
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
11 changes: 6 additions & 5 deletions phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,10 @@ def _to_device(value: Tensor or Any, device: ComputeDevice or str, convert_to_ba
return device.backend.allocate_on_device(value, device)


def all_available(*values: Tensor) -> bool:
def all_available(*values) -> bool:
"""
Tests if the values of all given tensors are known and can be read at this point.
Tracing placeholders are considered not available, even when they hold example values.
Tests if all tensors contained in the given `values` are currently known and can be read.
Placeholder tensors used to trace functions for just-in-time compilation or matrix construction are considered not available, even when they hold example values like with PyTorch's JIT.
Tensors are not available during `jit_compile()`, `jit_compile_linear()` or while using TensorFlow's legacy graph mode.
Expand All @@ -118,12 +118,13 @@ def all_available(*values: Tensor) -> bool:
* Jax: `isinstance(x, jax.core.Tracer)`
Args:
values: Tensors to check.
values: Tensors to check.
Returns:
`True` if no value is a placeholder or being traced, `False` otherwise.
"""
return all([v.available for v in values])
_, tensors = disassemble_tree(values)
return all([t.available for t in tensors])


def seed(seed: int):
Expand Down
4 changes: 2 additions & 2 deletions phiml/math/_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,10 +649,10 @@ def _linear_solve_forward(y: Tensor,
else:
max_iter = reshaped_numpy(solve.max_iterations, [shape(solve.max_iterations).without(batch_dims), batch_dims])
method = solve.method
if not callable(native_lin_op) and is_sparse(native_lin_op) and y.default_backend.name == 'torch' and preconditioner and not y.available:
if not callable(native_lin_op) and is_sparse(native_lin_op) and y.default_backend.name == 'torch' and preconditioner and not all_available(y):
warnings.warn(f"Preconditioners are not supported for sparse {method} in {y.default_backend} JIT mode. Disabling preconditioner. Use Jax or TensorFlow to enable preconditioners in JIT mode.", RuntimeWarning)
preconditioner = None
if not callable(native_lin_op) and is_sparse(native_lin_op) and not y.available and not method.startswith('scipy-') and isinstance(preconditioner, IncompleteLU):
if not callable(native_lin_op) and is_sparse(native_lin_op) and not all_available(y) and not method.startswith('scipy-') and isinstance(preconditioner, IncompleteLU):
warnings.warn(f"Preconditioners are not supported for sparse {method} in {y.default_backend} JIT mode. Using preconditioned scipy-{method} solve instead. If you want to use {y.default_backend}, please disable the preconditioner.", RuntimeWarning)
method = 'scipy-' + method
t = time.perf_counter()
Expand Down

0 comments on commit df30015

Please sign in to comment.