From df30015683c6fc358b172341aa2ca662ce3fede7 Mon Sep 17 00:00:00 2001 From: holl- Date: Thu, 30 Nov 2023 12:49:20 +0100 Subject: [PATCH] Support trees in all_available() --- phiml/math/_ops.py | 11 ++++++----- phiml/math/_optimize.py | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/phiml/math/_ops.py b/phiml/math/_ops.py index cad19d18..7d3cd072 100644 --- a/phiml/math/_ops.py +++ b/phiml/math/_ops.py @@ -102,10 +102,10 @@ def _to_device(value: Tensor or Any, device: ComputeDevice or str, convert_to_ba return device.backend.allocate_on_device(value, device) -def all_available(*values: Tensor) -> bool: +def all_available(*values) -> bool: """ - Tests if the values of all given tensors are known and can be read at this point. - Tracing placeholders are considered not available, even when they hold example values. + Tests if all tensors contained in the given `values` are currently known and can be read. + Placeholder tensors used to trace functions for just-in-time compilation or matrix construction are considered not available, even when they hold example values like with PyTorch's JIT. Tensors are not available during `jit_compile()`, `jit_compile_linear()` or while using TensorFlow's legacy graph mode. @@ -118,12 +118,13 @@ def all_available(*values: Tensor) -> bool: * Jax: `isinstance(x, jax.core.Tracer)` Args: - values: Tensors to check. + values: Tensors to check. Returns: `True` if no value is a placeholder or being traced, `False` otherwise. """ - return all([v.available for v in values]) + _, tensors = disassemble_tree(values) + return all([t.available for t in tensors]) def seed(seed: int): diff --git a/phiml/math/_optimize.py b/phiml/math/_optimize.py index 70ccf71e..ab4330f5 100644 --- a/phiml/math/_optimize.py +++ b/phiml/math/_optimize.py @@ -649,10 +649,10 @@ def _linear_solve_forward(y: Tensor, else: max_iter = reshaped_numpy(solve.max_iterations, [shape(solve.max_iterations).without(batch_dims), batch_dims]) method = solve.method - if not callable(native_lin_op) and is_sparse(native_lin_op) and y.default_backend.name == 'torch' and preconditioner and not y.available: + if not callable(native_lin_op) and is_sparse(native_lin_op) and y.default_backend.name == 'torch' and preconditioner and not all_available(y): warnings.warn(f"Preconditioners are not supported for sparse {method} in {y.default_backend} JIT mode. Disabling preconditioner. Use Jax or TensorFlow to enable preconditioners in JIT mode.", RuntimeWarning) preconditioner = None - if not callable(native_lin_op) and is_sparse(native_lin_op) and not y.available and not method.startswith('scipy-') and isinstance(preconditioner, IncompleteLU): + if not callable(native_lin_op) and is_sparse(native_lin_op) and not all_available(y) and not method.startswith('scipy-') and isinstance(preconditioner, IncompleteLU): warnings.warn(f"Preconditioners are not supported for sparse {method} in {y.default_backend} JIT mode. Using preconditioned scipy-{method} solve instead. If you want to use {y.default_backend}, please disable the preconditioner.", RuntimeWarning) method = 'scipy-' + method t = time.perf_counter()