diff --git a/phiml/_troubleshoot.py b/phiml/_troubleshoot.py
index 1625ff16..96916baf 100644
--- a/phiml/_troubleshoot.py
+++ b/phiml/_troubleshoot.py
@@ -141,7 +141,7 @@ def plot_solves():
             for i, result in enumerate(solves):
                 assert isinstance(result, math.SolveInfo)
                 from .math._tensors import disassemble_tree
-                _, (residual,) = disassemble_tree(result.residual)
+                _, (residual,) = disassemble_tree(result.residual, cache=False)
                 residual_mse = math.mean(math.sqrt(math.sum(residual ** 2)), residual.shape.without('trajectory'))
                 residual_mse_max = math.max(math.sqrt(math.sum(residual ** 2)), residual.shape.without('trajectory'))
                 # residual_mean = math.mean(math.abs(residual), residual.shape.without('trajectory'))
diff --git a/phiml/math/_functional.py b/phiml/math/_functional.py
index 8394c75a..df44fb5d 100644
--- a/phiml/math/_functional.py
+++ b/phiml/math/_functional.py
@@ -106,7 +106,7 @@ def key_from_args(args: tuple, kwargs: Dict[str, Any], parameters: Tuple[str, ..
         if param in kwargs:
             aux_kwargs[param] = kwargs[param]
             del kwargs[param]
-    tree, tensors = disassemble_tree(kwargs)
+    tree, tensors = disassemble_tree(kwargs, cache=cache)
     tracing = not math.all_available(*tensors)
     backend = math.choose_backend_t(*tensors)
     natives, shapes, specs = disassemble_tensors(tensors, expand=cache)
@@ -173,7 +173,7 @@ def jit_f_native(*natives):
                 in_tensors = assemble_tensors(natives, in_key.specs)
                 kwargs = assemble_tree(in_key.tree, in_tensors)
                 f_output = self.f(**kwargs, **in_key.auxiliary_kwargs)  # Tensor or tuple/list of Tensors
-                tree, out_tensors = disassemble_tree((f_output, self._extract_tensors))
+                tree, out_tensors = disassemble_tree((f_output, self._extract_tensors), cache=True)
                 result_natives, result_shapes, specs = disassemble_tensors(out_tensors, expand=True)
                 self.recorded_mappings[in_key] = SignatureKey(jit_f_native, tree, result_shapes, specs, in_key.backend, in_key.tracing)
             finally:
@@ -289,21 +289,6 @@ def __init__(self, f, auxiliary_args: Set[str], forget_traces: bool):
         self.matrices_and_biases: Dict[SignatureKey, Tuple[SparseCoordinateTensor, Tensor, Tuple]] = {}
         self.nl_jit = JitFunction(f, self.auxiliary_args, forget_traces)  # for backends that do not support sparse matrices
 
-    # def _trace(self, in_key: SignatureKey, prefer_numpy: bool) -> 'ShiftLinTracer':
-    #     assert in_key.shapes[0].is_uniform, f"math.jit_compile_linear() only supports uniform tensors for function input and output but input shape was {in_key.shapes[0]}"
-    #     with NUMPY if prefer_numpy else in_key.backend:
-    #         x = math.ones(in_key.shapes[0])
-    #         tracer = ShiftLinTracer(x, {EMPTY_SHAPE: math.ones()}, x.shape, math.zeros(x.shape))
-    #         _TRACING_JIT.append(self)
-    #         x_kwargs = assemble_tree(in_key.tree, [tracer])
-    #         result = self.f(**x_kwargs, **in_key.auxiliary_kwargs)
-    #         _, result_tensors = disassemble_tree(result)
-    #         assert len(result_tensors) == 1, f"Linear function must return a single Tensor or tensor-like but got {result}"
-    #         result_tensor = result_tensors[0]
-    #         assert isinstance(result_tensor, ShiftLinTracer), f"Tracing linear function '{f_name(self.f)}' failed. Make sure only linear operations are used."
-    #         assert _TRACING_JIT.pop(-1) is self
-    #         return result_tensor
-
     def _get_or_trace(self, key: SignatureKey, args: tuple, f_kwargs: dict):
         if not key.tracing and key in self.matrices_and_biases:
             return self.matrices_and_biases[key]
@@ -502,7 +487,7 @@ def f_native(*natives):
                 loss_shape = in_key.backend.staticshape(loss_native)
                 assert len(
                     loss_shape) == 0, f"Only scalar losses are allowed when returning a native tensor but {f_name(self.f)} returned {type(loss_native).__name__} of shape {loss_shape}. For higher-dimensional values, use Φ-ML tensors instead."
-            nest, out_tensors = disassemble_tree(result)
+            nest, out_tensors = disassemble_tree(result, cache=True)
             result_natives, result_shapes, specs = disassemble_tensors(out_tensors, expand=True)
             self.recorded_mappings[in_key] = SignatureKey(f_native, nest, result_shapes, specs, in_key.backend, in_key.tracing)
             return loss_native, result_natives
@@ -521,7 +506,7 @@ def __call__(self, *args, **kwargs):
         else:
             raise AssertionError(f"jacobian() not supported by {key.backend}.")
         wrt_tensors = self._track_wrt(kwargs)
-        wrt_natives = self._track_wrt_natives(wrt_tensors, disassemble_tree(kwargs)[1])
+        wrt_natives = self._track_wrt_natives(wrt_tensors, disassemble_tree(kwargs, cache=True)[1])
         if key not in self.traces:
             self.traces[key] = self._trace_grad(key, wrt_natives)
         native_result = self.traces[key](*natives)
@@ -547,7 +532,7 @@ def __name__(self):
     def _track_wrt(self, kwargs: dict):
         wrt_tensors = []
         for name, arg in kwargs.items():
-            _, tensors = disassemble_tree(arg)
+            _, tensors = disassemble_tree(arg, cache=True)
            wrt_tensors.extend([name] * len(tensors))
         return [t_i for t_i, name in enumerate(wrt_tensors) if name in self._wrt_tuple]
 
@@ -666,7 +651,7 @@ def __init__(self, f: Callable, f_params, wrt: tuple, get_output: bool, get_grad
 #             kwargs = assemble_tree(in_key.tree, in_tensors)
 #             with functional_derivative_evaluation(order=2):
 #                 result = self.f(**kwargs)
-#             nest, out_tensors = disassemble_tree(result)
+#             nest, out_tensors = disassemble_tree(result, cache=True)
 #             result_natives, result_shapes, specs = disassemble_tensors(out_tensors, expand=True)
 #             self.recorded_mappings[in_key] = SignatureKey(f_native, nest, result_shapes, specs, in_key.backend, in_key.tracing)
 #             return result_natives
@@ -805,7 +790,7 @@ def forward_native(*natives):
             kwargs = assemble_tree(in_key.tree, in_tensors)
             ML_LOGGER.debug(f"Running forward pass of custom op {forward_native.__name__} given args {tuple(kwargs.keys())} containing {len(natives)} native tensors")
             result = self.f(**kwargs, **in_key.auxiliary_kwargs)  # Tensor or tuple/list of Tensors
-            nest, out_tensors = disassemble_tree(result)
+            nest, out_tensors = disassemble_tree(result, cache=True)
             result_natives, result_shapes, specs = disassemble_tensors(out_tensors, expand=True)
             self.recorded_mappings[in_key] = SignatureKey(forward_native, nest, result_shapes, specs, in_key.backend, in_key.tracing)
             return result_natives
@@ -1020,7 +1005,7 @@ def map_types(f: Callable, dims: Union[Shape, tuple, list, str, Callable], dim_t
     """
 
     def forward_retype(obj, input_types: Shape):
-        tree, tensors = disassemble_tree(obj)
+        tree, tensors = disassemble_tree(obj, cache=False)
         retyped = []
         for t in tensors:
             for dim in t.shape.only(dims):
@@ -1030,7 +1015,7 @@ def forward_retype(obj, input_types: Shape):
         return assemble_tree(tree, retyped), input_types
 
     def reverse_retype(obj, input_types: Shape):
-        tree, tensors = disassemble_tree(obj)
+        tree, tensors = disassemble_tree(obj, cache=False)
         retyped = []
         for t in tensors:
             for dim in t.shape.only(input_types.names):
diff --git a/phiml/math/_ops.py b/phiml/math/_ops.py
index 91fc1560..ab1e152d 100644
--- a/phiml/math/_ops.py
+++ b/phiml/math/_ops.py
@@ -123,7 +123,7 @@ def all_available(*values) -> bool:
     Returns:
         `True` if no value is a placeholder or being traced, `False` otherwise.
     """
-    _, tensors = disassemble_tree(values)
+    _, tensors = disassemble_tree(values, cache=False)
     return all([t.available for t in tensors])
 
 
@@ -545,7 +545,7 @@ def zeros(*shape: Shape, dtype: Union[DType, tuple, type] = None) -> Tensor:
 
 def zeros_like(obj: Union[Tensor, PhiTreeNode]) -> Union[Tensor, PhiTreeNode]:
     """ Create a `Tensor` containing only `0.0` / `0` / `False` with the same shape and dtype as `obj`. """
-    nest, values = disassemble_tree(obj)
+    nest, values = disassemble_tree(obj, cache=False)
     zeros_ = []
     for val in values:
         val = wrap(val)
@@ -2876,9 +2876,9 @@ def assert_close(*values,
         for other in values[1:]:
             _assert_close(values[0], other, rel_tolerance, abs_tolerance, msg, verbose)
     elif all(isinstance(v, PhiTreeNode) for v in values):
-        tree0, tensors0 = disassemble_tree(values[0])
+        tree0, tensors0 = disassemble_tree(values[0], cache=False)
         for value in values[1:]:
-            tree, tensors_ = disassemble_tree(value)
+            tree, tensors_ = disassemble_tree(value, cache=False)
             assert tree0 == tree, f"Tree structures do not match: {tree0} and {tree}"
             for t0, t in zip(tensors0, tensors_):
                 _assert_close(t0, t, rel_tolerance, abs_tolerance, msg, verbose)
@@ -2974,7 +2974,7 @@ def stop_gradient(x):
     if isinstance(x, Tensor):
         return x._op1(lambda native: choose_backend(native).stop_gradient(native))
     elif isinstance(x, PhiTreeNode):
-        nest, values = disassemble_tree(x)
+        nest, values = disassemble_tree(x, cache=False)
         new_values = [stop_gradient(v) for v in values]
         return assemble_tree(nest, new_values)
     else:
diff --git a/phiml/math/_optimize.py b/phiml/math/_optimize.py
index 5d4315b3..99fd2491 100644
--- a/phiml/math/_optimize.py
+++ b/phiml/math/_optimize.py
@@ -204,7 +204,7 @@ def _default_solve_info_msg(msg: str, converged: bool, diverged: bool, iteration
     if diverged:
         return f"Solve diverged within {iterations if iterations is not None else '?'} iterations using {method}."
     elif not converged:
-        max_res = [f"{math.max_(t.trajectory[-1]):no-color:no-dtype}" for t in disassemble_tree(residual)[1]]
+        max_res = [f"{math.max_(t.trajectory[-1]):no-color:no-dtype}" for t in disassemble_tree(residual, cache=False)[1]]
         return f"{method} did not converge to rel_tol={float(solve.rel_tol):.0e}, abs_tol={float(solve.abs_tol):.0e} within {int(solve.max_iterations)} iterations. Max residual: {', '.join(max_res)}"
     else:
         return f"Converged within {iterations if iterations is not None else '?'} iterations."
@@ -364,7 +364,7 @@ def minimize(f: Callable[[X], Y], solve: Solve[X, Y]) -> X:
     solve = solve.with_defaults('optimization')
     assert (solve.rel_tol == 0).all, f"rel_tol must be zero for minimize() but got {solve.rel_tol}"
     assert solve.preprocess_y is None, "minimize() does not allow preprocess_y"
-    x0_nest, x0_tensors = disassemble_tree(solve.x0)
+    x0_nest, x0_tensors = disassemble_tree(solve.x0, cache=True)
     x0_tensors = [to_float(t) for t in x0_tensors]
     backend = choose_backend_t(*x0_tensors, prefer_default=True)
     batch_dims = merge_shapes(*[t.shape for t in x0_tensors]).batch
@@ -408,7 +408,7 @@ def native_function(x_flat):
             y = f(*x)
         else:
             y = f(x)
-        _, y_tensors = disassemble_tree(y)
+        _, y_tensors = disassemble_tree(y, cache=False)
         assert not non_batch(y_tensors[0]), f"Failed to minimize '{f.__name__}' because it returned a non-scalar output {shape(y_tensors[0])}. Reduce all non-batch dimensions, e.g. using math.l2_loss()"
         if y_tensors[0].shape.without(batch_dims):  # output added more batch dims. We should expand the initial guess
             raise NewBatchDims(y_tensors[0].shape, y_tensors[0].shape.without(batch_dims))
@@ -559,8 +559,8 @@ def solve_linear(f: Union[Callable[[X], Y], Tensor],
         f_kwargs.update(f_kwargs_)
         f_args = f_args[0] if len(f_args) == 1 and isinstance(f_args[0], tuple) else f_args
     # --- Get input and output tensors ---
-    y_tree, y_tensors = disassemble_tree(y)
-    x0_tree, x0_tensors = disassemble_tree(solve.x0)
+    y_tree, y_tensors = disassemble_tree(y, cache=False)
+    x0_tree, x0_tensors = disassemble_tree(solve.x0, cache=False)
     assert solve.x0 is not None, "Please specify the initial guess as Solve(..., x0=initial_guess)"
     assert len(x0_tensors) == len(y_tensors) == 1, "Only single-tensor linear solves are currently supported"
     if y_tree == 'native' and x0_tree == 'native':
@@ -626,8 +626,8 @@ def _matrix_solve_forward(y, solve: Solve, matrix: Tensor, is_backprop=False):
     assert solve.preconditioner is None, f"Preconditioners not currently supported for matrix-free solves. Decorate '{f_name(f)}' with @math.jit_compile_linear to perform a matrix solve."
 
 def _function_solve_forward(y, solve: Solve, f_args: tuple, f_kwargs: dict = None, is_backprop=False):
-    y_nest, (y_tensor,) = disassemble_tree(y)
-    x0_nest, (x0_tensor,) = disassemble_tree(solve.x0)
+    y_nest, (y_tensor,) = disassemble_tree(y, cache=False)
+    x0_nest, (x0_tensor,) = disassemble_tree(solve.x0, cache=False)
     # active_dims = (y_tensor.shape & x0_tensor.shape).non_batch  # assumes batch dimensions are not active
     batches = (y_tensor.shape & x0_tensor.shape).batch
 
@@ -636,7 +636,7 @@ def native_lin_f(native_x, batch_index=None):
             native_x = backend.tile(backend.expand_dims(native_x), [batches.volume, 1])
         x = assemble_tree(x0_nest, [reshaped_tensor(native_x, [batches, non_batch(x0_tensor)] if backend.ndims(native_x) >= 2 else [non_batch(x0_tensor)], convert=False)])
         y = f(x, *f_args, **f_kwargs)
-        _, (y_tensor,) = disassemble_tree(y)
+        _, (y_tensor,) = disassemble_tree(y, cache=False)
         y_native = reshaped_native(y_tensor, [batches, non_batch(y_tensor)] if backend.ndims(native_x) >= 2 else [non_batch(y_tensor)])
         if batch_index is not None and batches.volume > 1:
             y_native = y_native[batch_index]
@@ -661,8 +661,8 @@ def _linear_solve_forward(y: Tensor,
     ML_LOGGER.debug(f"Performing linear solve {solve} with backend {backend}")
     if solve.preprocess_y is not None:
         y = solve.preprocess_y(y, *solve.preprocess_y_args)
-    y_nest, (y_tensor,) = disassemble_tree(y)
-    x0_nest, (x0_tensor,) = disassemble_tree(solve.x0)
+    y_nest, (y_tensor,) = disassemble_tree(y, cache=False)
+    x0_nest, (x0_tensor,) = disassemble_tree(solve.x0, cache=False)
     pattern_dims_in = x0_tensor.shape.only(pattern_dims_in, reorder=True)
     if pattern_dims_out not in y_tensor.shape:
         warnings.warn(f"right-hand-side has shape {y_tensor.shape} but output dimensions are {pattern_dims_out}. This may result in unexpected behavior", RuntimeWarning, stacklevel=3)
@@ -729,8 +729,8 @@ def implicit_gradient_solve(fwd_args: dict, x, dx):
         if isinstance(matrix, SparseCoordinateTensor):
             col = matrix.dual_indices(to_primal=True)
             row = matrix.primal_indices()
-            _, dy_tensors = disassemble_tree(dy)
-            _, x_tensors = disassemble_tree(x)
+            _, dy_tensors = disassemble_tree(dy, cache=False)
+            _, x_tensors = disassemble_tree(x, cache=False)
             dm_values = dy_tensors[0][col] * x_tensors[0][row]
             dm = matrix._with_values(dm_values)
             dm = -dm
diff --git a/phiml/math/_tensors.py b/phiml/math/_tensors.py
index d259b922..be481c2e 100644
--- a/phiml/math/_tensors.py
+++ b/phiml/math/_tensors.py
@@ -1862,7 +1862,7 @@ def assemble_tensors(natives: Union[tuple, list], specs: Union[Tuple[dict, ...],
 NATIVE_TENSOR = 'native'
 
 
-def disassemble_tree(obj: PhiTreeNodeType) -> Tuple[PhiTreeNodeType, List[Tensor]]:
+def disassemble_tree(obj: PhiTreeNodeType, cache: bool) -> Tuple[PhiTreeNodeType, List[Tensor]]:
     """
     Splits a nested structure of Tensors into the structure without the tensors and an ordered list of tensors.
     Native tensors will be wrapped in phiml.math.Tensors with default dimension names and dimension types `None`.
@@ -1873,6 +1873,7 @@ def disassemble_tree(obj: PhiTreeNodeType) -> Tuple[PhiTreeNodeType, List[Tensor
     Args:
         obj: Nested structure of `Tensor` objects.
             Nested structures include: `tuple`, `list`, `dict`, `phiml.math.magic.PhiTreeNode`.
+        cache: Whether to return cached versions of the tensors. This may reduce the number of native tensors required.
 
     Returns:
         empty structure: Same structure as `obj` but with the tensors replaced by `None`.
@@ -1881,12 +1882,12 @@ def disassemble_tree(obj: PhiTreeNodeType) -> Tuple[PhiTreeNodeType, List[Tensor
     if obj is None:
         return MISSING_TENSOR, []
     elif isinstance(obj, Tensor):
-        return None, [obj]
+        return None, [cached(obj) if cache else obj]
     elif isinstance(obj, (tuple, list)):
         keys = []
         values = []
         for item in obj:
-            key, value = disassemble_tree(item)
+            key, value = disassemble_tree(item, cache)
             keys.append(key)
             values.extend(value)
         return (tuple(keys) if isinstance(obj, tuple) else keys), values
@@ -1894,7 +1895,7 @@ def disassemble_tree(obj: PhiTreeNodeType) -> Tuple[PhiTreeNodeType, List[Tensor
         keys = {}
         values = []
         for name, item in obj.items():
-            key, value = disassemble_tree(item)
+            key, value = disassemble_tree(item, cache)
             keys[name] = key
             values.extend(value)
         return keys, values
@@ -1903,7 +1904,7 @@ def disassemble_tree(obj: PhiTreeNodeType) -> Tuple[PhiTreeNodeType, List[Tensor
     keys = {}
         values = []
         for attr in attributes:
-            key, value = disassemble_tree(getattr(obj, attr))
+            key, value = disassemble_tree(getattr(obj, attr), cache)
             keys[attr] = key
             values.extend(value)
         return copy_with(obj, **keys), values
@@ -1967,9 +1968,8 @@ def cached(t: TensorOrTree) -> TensorOrTree:
     elif isinstance(t, Layout):
         return t
     elif isinstance(t, PhiTreeNode):
-        tree, tensors = disassemble_tree(t)
-        tensors_ = [cached(t_) for t_ in tensors]
-        return assemble_tree(tree, tensors_)
+        tree, tensors = disassemble_tree(t, cache=True)
+        return assemble_tree(tree, tensors)
     else:
         raise AssertionError(f"Cannot cache {type(t)} {t}")
 
diff --git a/phiml/math/_trace.py b/phiml/math/_trace.py
index 1a274462..090579ce 100644
--- a/phiml/math/_trace.py
+++ b/phiml/math/_trace.py
@@ -593,7 +593,7 @@ def matrix_from_function(f: Callable,
     all_args = {**kwargs, **{f_params[i]: v for i, v in enumerate(args)}}
     aux_args = {k: v for k, v in all_args.items() if k in aux}
     trace_args = {k: v for k, v in all_args.items() if k not in aux}
-    tree, tensors = disassemble_tree(trace_args)
+    tree, tensors = disassemble_tree(trace_args, cache=False)
     target_backend = choose_backend_t(*tensors)
     # --- Trace function ---
     with NUMPY:
@@ -601,7 +601,7 @@ def matrix_from_function(f: Callable,
         tracer = ShiftLinTracer(src, {EMPTY_SHAPE: math.ones()}, tensors[0].shape, bias=math.zeros(dtype=tensors[0].dtype), renamed={d: d for d in tensors[0].shape.names})
         x_kwargs = assemble_tree(tree, [tracer])
         result = f(**x_kwargs, **aux_args)
-        out_tree, result_tensors = disassemble_tree(result)
+        out_tree, result_tensors = disassemble_tree(result, cache=False)
         assert len(result_tensors) == 1, f"Linear function output must be or contain a single Tensor but got {result}"
         tracer = result_tensors[0]._simplify()
         assert tracer._is_tracer, f"Tracing linear function '{f_name(f)}' failed. Make sure only linear operations are used. Output: {tracer.shape}"