Add cache to disassemble_tree()
holl- committed Jan 14, 2024
1 parent bcca6bd commit 699571e
Showing 6 changed files with 37 additions and 52 deletions.
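At a glance, the commit threads an explicit `cache` flag through every `disassemble_tree()` call. A minimal sketch of the two call styles, not taken verbatim from any one file below (the wrapper functions here are hypothetical; only `disassemble_tree` and its new parameter come from this commit):

    from phiml.math._tensors import disassemble_tree

    def disassemble_for_tracing(kwargs):
        # Tracing, jit and autodiff paths pass cache=True (or a caller-supplied flag)
        # so tensors are cached before their native tensors are extracted.
        return disassemble_tree(kwargs, cache=True)

    def disassemble_for_inspection(values):
        # Read-only helpers (all_available, assert_close, stop_gradient, ...) pass
        # cache=False, since nothing needs to be materialized for a simple check.
        return disassemble_tree(values, cache=False)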
2 changes: 1 addition & 1 deletion phiml/_troubleshoot.py
@@ -141,7 +141,7 @@ def plot_solves():
for i, result in enumerate(solves):
assert isinstance(result, math.SolveInfo)
from .math._tensors import disassemble_tree
- _, (residual,) = disassemble_tree(result.residual)
+ _, (residual,) = disassemble_tree(result.residual, cache=False)
residual_mse = math.mean(math.sqrt(math.sum(residual ** 2)), residual.shape.without('trajectory'))
residual_mse_max = math.max(math.sqrt(math.sum(residual ** 2)), residual.shape.without('trajectory'))
# residual_mean = math.mean(math.abs(residual), residual.shape.without('trajectory'))
33 changes: 9 additions & 24 deletions phiml/math/_functional.py
@@ -106,7 +106,7 @@ def key_from_args(args: tuple, kwargs: Dict[str, Any], parameters: Tuple[str, ...
if param in kwargs:
aux_kwargs[param] = kwargs[param]
del kwargs[param]
- tree, tensors = disassemble_tree(kwargs)
+ tree, tensors = disassemble_tree(kwargs, cache=cache)
tracing = not math.all_available(*tensors)
backend = math.choose_backend_t(*tensors)
natives, shapes, specs = disassemble_tensors(tensors, expand=cache)
@@ -173,7 +173,7 @@ def jit_f_native(*natives):
in_tensors = assemble_tensors(natives, in_key.specs)
kwargs = assemble_tree(in_key.tree, in_tensors)
f_output = self.f(**kwargs, **in_key.auxiliary_kwargs) # Tensor or tuple/list of Tensors
- tree, out_tensors = disassemble_tree((f_output, self._extract_tensors))
+ tree, out_tensors = disassemble_tree((f_output, self._extract_tensors), cache=True)
result_natives, result_shapes, specs = disassemble_tensors(out_tensors, expand=True)
self.recorded_mappings[in_key] = SignatureKey(jit_f_native, tree, result_shapes, specs, in_key.backend, in_key.tracing)
finally:
@@ -289,21 +289,6 @@ def __init__(self, f, auxiliary_args: Set[str], forget_traces: bool):
self.matrices_and_biases: Dict[SignatureKey, Tuple[SparseCoordinateTensor, Tensor, Tuple]] = {}
self.nl_jit = JitFunction(f, self.auxiliary_args, forget_traces) # for backends that do not support sparse matrices

- # def _trace(self, in_key: SignatureKey, prefer_numpy: bool) -> 'ShiftLinTracer':
- #     assert in_key.shapes[0].is_uniform, f"math.jit_compile_linear() only supports uniform tensors for function input and output but input shape was {in_key.shapes[0]}"
- #     with NUMPY if prefer_numpy else in_key.backend:
- #         x = math.ones(in_key.shapes[0])
- #         tracer = ShiftLinTracer(x, {EMPTY_SHAPE: math.ones()}, x.shape, math.zeros(x.shape))
- #     _TRACING_JIT.append(self)
- #     x_kwargs = assemble_tree(in_key.tree, [tracer])
- #     result = self.f(**x_kwargs, **in_key.auxiliary_kwargs)
- #     _, result_tensors = disassemble_tree(result)
- #     assert len(result_tensors) == 1, f"Linear function must return a single Tensor or tensor-like but got {result}"
- #     result_tensor = result_tensors[0]
- #     assert isinstance(result_tensor, ShiftLinTracer), f"Tracing linear function '{f_name(self.f)}' failed. Make sure only linear operations are used."
- #     assert _TRACING_JIT.pop(-1) is self
- #     return result_tensor

def _get_or_trace(self, key: SignatureKey, args: tuple, f_kwargs: dict):
if not key.tracing and key in self.matrices_and_biases:
return self.matrices_and_biases[key]
@@ -502,7 +487,7 @@ def f_native(*natives):
loss_shape = in_key.backend.staticshape(loss_native)
assert len(
loss_shape) == 0, f"Only scalar losses are allowed when returning a native tensor but {f_name(self.f)} returned {type(loss_native).__name__} of shape {loss_shape}. For higher-dimensional values, use Φ-ML tensors instead."
- nest, out_tensors = disassemble_tree(result)
+ nest, out_tensors = disassemble_tree(result, cache=True)
result_natives, result_shapes, specs = disassemble_tensors(out_tensors, expand=True)
self.recorded_mappings[in_key] = SignatureKey(f_native, nest, result_shapes, specs, in_key.backend, in_key.tracing)
return loss_native, result_natives
@@ -521,7 +506,7 @@ def __call__(self, *args, **kwargs):
else:
raise AssertionError(f"jacobian() not supported by {key.backend}.")
wrt_tensors = self._track_wrt(kwargs)
- wrt_natives = self._track_wrt_natives(wrt_tensors, disassemble_tree(kwargs)[1])
+ wrt_natives = self._track_wrt_natives(wrt_tensors, disassemble_tree(kwargs, cache=True)[1])
if key not in self.traces:
self.traces[key] = self._trace_grad(key, wrt_natives)
native_result = self.traces[key](*natives)
@@ -547,7 +532,7 @@ def __name__(self):
def _track_wrt(self, kwargs: dict):
wrt_tensors = []
for name, arg in kwargs.items():
- _, tensors = disassemble_tree(arg)
+ _, tensors = disassemble_tree(arg, cache=True)
wrt_tensors.extend([name] * len(tensors))
return [t_i for t_i, name in enumerate(wrt_tensors) if name in self._wrt_tuple]

@@ -666,7 +651,7 @@ def __init__(self, f: Callable, f_params, wrt: tuple, get_output: bool, get_grad
# kwargs = assemble_tree(in_key.tree, in_tensors)
# with functional_derivative_evaluation(order=2):
# result = self.f(**kwargs)
- # nest, out_tensors = disassemble_tree(result)
+ # nest, out_tensors = disassemble_tree(result, cache=True)
# result_natives, result_shapes, specs = disassemble_tensors(out_tensors, expand=True)
# self.recorded_mappings[in_key] = SignatureKey(f_native, nest, result_shapes, specs, in_key.backend, in_key.tracing)
# return result_natives
@@ -805,7 +790,7 @@ def forward_native(*natives):
kwargs = assemble_tree(in_key.tree, in_tensors)
ML_LOGGER.debug(f"Running forward pass of custom op {forward_native.__name__} given args {tuple(kwargs.keys())} containing {len(natives)} native tensors")
result = self.f(**kwargs, **in_key.auxiliary_kwargs) # Tensor or tuple/list of Tensors
- nest, out_tensors = disassemble_tree(result)
+ nest, out_tensors = disassemble_tree(result, cache=True)
result_natives, result_shapes, specs = disassemble_tensors(out_tensors, expand=True)
self.recorded_mappings[in_key] = SignatureKey(forward_native, nest, result_shapes, specs, in_key.backend, in_key.tracing)
return result_natives
@@ -1020,7 +1005,7 @@ def map_types(f: Callable, dims: Union[Shape, tuple, list, str, Callable], dim_t
"""

def forward_retype(obj, input_types: Shape):
- tree, tensors = disassemble_tree(obj)
+ tree, tensors = disassemble_tree(obj, cache=False)
retyped = []
for t in tensors:
for dim in t.shape.only(dims):
@@ -1030,7 +1015,7 @@ def forward_retype(obj, input_types: Shape):
return assemble_tree(tree, retyped), input_types

def reverse_retype(obj, input_types: Shape):
- tree, tensors = disassemble_tree(obj)
+ tree, tensors = disassemble_tree(obj, cache=False)
retyped = []
for t in tensors:
for dim in t.shape.only(input_types.names):
10 changes: 5 additions & 5 deletions phiml/math/_ops.py
@@ -123,7 +123,7 @@ def all_available(*values) -> bool:
Returns:
`True` if no value is a placeholder or being traced, `False` otherwise.
"""
- _, tensors = disassemble_tree(values)
+ _, tensors = disassemble_tree(values, cache=False)
return all([t.available for t in tensors])


@@ -545,7 +545,7 @@ def zeros(*shape: Shape, dtype: Union[DType, tuple, type] = None) -> Tensor:

def zeros_like(obj: Union[Tensor, PhiTreeNode]) -> Union[Tensor, PhiTreeNode]:
""" Create a `Tensor` containing only `0.0` / `0` / `False` with the same shape and dtype as `obj`. """
- nest, values = disassemble_tree(obj)
+ nest, values = disassemble_tree(obj, cache=False)
zeros_ = []
for val in values:
val = wrap(val)
@@ -2876,9 +2876,9 @@ def assert_close(*values,
for other in values[1:]:
_assert_close(values[0], other, rel_tolerance, abs_tolerance, msg, verbose)
elif all(isinstance(v, PhiTreeNode) for v in values):
- tree0, tensors0 = disassemble_tree(values[0])
+ tree0, tensors0 = disassemble_tree(values[0], cache=False)
for value in values[1:]:
- tree, tensors_ = disassemble_tree(value)
+ tree, tensors_ = disassemble_tree(value, cache=False)
assert tree0 == tree, f"Tree structures do not match: {tree0} and {tree}"
for t0, t in zip(tensors0, tensors_):
_assert_close(t0, t, rel_tolerance, abs_tolerance, msg, verbose)
@@ -2974,7 +2974,7 @@ def stop_gradient(x):
if isinstance(x, Tensor):
return x._op1(lambda native: choose_backend(native).stop_gradient(native))
elif isinstance(x, PhiTreeNode):
- nest, values = disassemble_tree(x)
+ nest, values = disassemble_tree(x, cache=False)
new_values = [stop_gradient(v) for v in values]
return assemble_tree(nest, new_values)
else:
24 changes: 12 additions & 12 deletions phiml/math/_optimize.py
@@ -204,7 +204,7 @@ def _default_solve_info_msg(msg: str, converged: bool, diverged: bool, iteration
if diverged:
return f"Solve diverged within {iterations if iterations is not None else '?'} iterations using {method}."
elif not converged:
- max_res = [f"{math.max_(t.trajectory[-1]):no-color:no-dtype}" for t in disassemble_tree(residual)[1]]
+ max_res = [f"{math.max_(t.trajectory[-1]):no-color:no-dtype}" for t in disassemble_tree(residual, cache=False)[1]]
return f"{method} did not converge to rel_tol={float(solve.rel_tol):.0e}, abs_tol={float(solve.abs_tol):.0e} within {int(solve.max_iterations)} iterations. Max residual: {', '.join(max_res)}"
else:
return f"Converged within {iterations if iterations is not None else '?'} iterations."
@@ -364,7 +364,7 @@ def minimize(f: Callable[[X], Y], solve: Solve[X, Y]) -> X:
solve = solve.with_defaults('optimization')
assert (solve.rel_tol == 0).all, f"rel_tol must be zero for minimize() but got {solve.rel_tol}"
assert solve.preprocess_y is None, "minimize() does not allow preprocess_y"
- x0_nest, x0_tensors = disassemble_tree(solve.x0)
+ x0_nest, x0_tensors = disassemble_tree(solve.x0, cache=True)
x0_tensors = [to_float(t) for t in x0_tensors]
backend = choose_backend_t(*x0_tensors, prefer_default=True)
batch_dims = merge_shapes(*[t.shape for t in x0_tensors]).batch
@@ -408,7 +408,7 @@ def native_function(x_flat):
y = f(*x)
else:
y = f(x)
- _, y_tensors = disassemble_tree(y)
+ _, y_tensors = disassemble_tree(y, cache=False)
assert not non_batch(y_tensors[0]), f"Failed to minimize '{f.__name__}' because it returned a non-scalar output {shape(y_tensors[0])}. Reduce all non-batch dimensions, e.g. using math.l2_loss()"
if y_tensors[0].shape.without(batch_dims): # output added more batch dims. We should expand the initial guess
raise NewBatchDims(y_tensors[0].shape, y_tensors[0].shape.without(batch_dims))
@@ -559,8 +559,8 @@ def solve_linear(f: Union[Callable[[X], Y], Tensor],
f_kwargs.update(f_kwargs_)
f_args = f_args[0] if len(f_args) == 1 and isinstance(f_args[0], tuple) else f_args
# --- Get input and output tensors ---
- y_tree, y_tensors = disassemble_tree(y)
- x0_tree, x0_tensors = disassemble_tree(solve.x0)
+ y_tree, y_tensors = disassemble_tree(y, cache=False)
+ x0_tree, x0_tensors = disassemble_tree(solve.x0, cache=False)
assert solve.x0 is not None, "Please specify the initial guess as Solve(..., x0=initial_guess)"
assert len(x0_tensors) == len(y_tensors) == 1, "Only single-tensor linear solves are currently supported"
if y_tree == 'native' and x0_tree == 'native':
@@ -626,8 +626,8 @@ def _matrix_solve_forward(y, solve: Solve, matrix: Tensor, is_backprop=False):
assert solve.preconditioner is None, f"Preconditioners not currently supported for matrix-free solves. Decorate '{f_name(f)}' with @math.jit_compile_linear to perform a matrix solve."

def _function_solve_forward(y, solve: Solve, f_args: tuple, f_kwargs: dict = None, is_backprop=False):
- y_nest, (y_tensor,) = disassemble_tree(y)
- x0_nest, (x0_tensor,) = disassemble_tree(solve.x0)
+ y_nest, (y_tensor,) = disassemble_tree(y, cache=False)
+ x0_nest, (x0_tensor,) = disassemble_tree(solve.x0, cache=False)
# active_dims = (y_tensor.shape & x0_tensor.shape).non_batch # assumes batch dimensions are not active
batches = (y_tensor.shape & x0_tensor.shape).batch

@@ -636,7 +636,7 @@ def native_lin_f(native_x, batch_index=None):
native_x = backend.tile(backend.expand_dims(native_x), [batches.volume, 1])
x = assemble_tree(x0_nest, [reshaped_tensor(native_x, [batches, non_batch(x0_tensor)] if backend.ndims(native_x) >= 2 else [non_batch(x0_tensor)], convert=False)])
y = f(x, *f_args, **f_kwargs)
- _, (y_tensor,) = disassemble_tree(y)
+ _, (y_tensor,) = disassemble_tree(y, cache=False)
y_native = reshaped_native(y_tensor, [batches, non_batch(y_tensor)] if backend.ndims(native_x) >= 2 else [non_batch(y_tensor)])
if batch_index is not None and batches.volume > 1:
y_native = y_native[batch_index]
@@ -661,8 +661,8 @@ def _linear_solve_forward(y: Tensor,
ML_LOGGER.debug(f"Performing linear solve {solve} with backend {backend}")
if solve.preprocess_y is not None:
y = solve.preprocess_y(y, *solve.preprocess_y_args)
- y_nest, (y_tensor,) = disassemble_tree(y)
- x0_nest, (x0_tensor,) = disassemble_tree(solve.x0)
+ y_nest, (y_tensor,) = disassemble_tree(y, cache=False)
+ x0_nest, (x0_tensor,) = disassemble_tree(solve.x0, cache=False)
pattern_dims_in = x0_tensor.shape.only(pattern_dims_in, reorder=True)
if pattern_dims_out not in y_tensor.shape:
warnings.warn(f"right-hand-side has shape {y_tensor.shape} but output dimensions are {pattern_dims_out}. This may result in unexpected behavior", RuntimeWarning, stacklevel=3)
@@ -729,8 +729,8 @@ def implicit_gradient_solve(fwd_args: dict, x, dx):
if isinstance(matrix, SparseCoordinateTensor):
col = matrix.dual_indices(to_primal=True)
row = matrix.primal_indices()
- _, dy_tensors = disassemble_tree(dy)
- _, x_tensors = disassemble_tree(x)
+ _, dy_tensors = disassemble_tree(dy, cache=False)
+ _, x_tensors = disassemble_tree(x, cache=False)
dm_values = dy_tensors[0][col] * x_tensors[0][row]
dm = matrix._with_values(dm_values)
dm = -dm
16 changes: 8 additions & 8 deletions phiml/math/_tensors.py
@@ -1862,7 +1862,7 @@ def assemble_tensors(natives: Union[tuple, list], specs: Union[Tuple[dict, ...],
NATIVE_TENSOR = 'native'


- def disassemble_tree(obj: PhiTreeNodeType) -> Tuple[PhiTreeNodeType, List[Tensor]]:
+ def disassemble_tree(obj: PhiTreeNodeType, cache: bool) -> Tuple[PhiTreeNodeType, List[Tensor]]:
"""
Splits a nested structure of Tensors into the structure without the tensors and an ordered list of tensors.
Native tensors will be wrapped in phiml.math.Tensors with default dimension names and dimension types `None`.
@@ -1873,6 +1873,7 @@ def disassemble_tree(obj: PhiTreeNodeType) -> Tuple[PhiTreeNodeType, List[Tensor
Args:
obj: Nested structure of `Tensor` objects.
Nested structures include: `tuple`, `list`, `dict`, `phiml.math.magic.PhiTreeNode`.
+ cache: Whether to return cached versions of the tensors. This may reduce the number of native tensors required.
Returns:
empty structure: Same structure as `obj` but with the tensors replaced by `None`.
@@ -1881,20 +1882,20 @@ def disassemble_tree(obj: PhiTreeNodeType) -> Tuple[PhiTreeNodeType, List[Tensor
if obj is None:
return MISSING_TENSOR, []
elif isinstance(obj, Tensor):
- return None, [obj]
+ return None, [cached(obj) if cache else obj]
elif isinstance(obj, (tuple, list)):
keys = []
values = []
for item in obj:
- key, value = disassemble_tree(item)
+ key, value = disassemble_tree(item, cache)
keys.append(key)
values.extend(value)
return (tuple(keys) if isinstance(obj, tuple) else keys), values
elif isinstance(obj, dict):
keys = {}
values = []
for name, item in obj.items():
- key, value = disassemble_tree(item)
+ key, value = disassemble_tree(item, cache)
keys[name] = key
values.extend(value)
return keys, values
@@ -1903,7 +1904,7 @@ def disassemble_tree(obj: PhiTreeNodeType) -> Tuple[PhiTreeNodeType, List[Tensor
keys = {}
values = []
for attr in attributes:
- key, value = disassemble_tree(getattr(obj, attr))
+ key, value = disassemble_tree(getattr(obj, attr), cache)
keys[attr] = key
values.extend(value)
return copy_with(obj, **keys), values
@@ -1967,9 +1968,8 @@ def cached(t: TensorOrTree) -> TensorOrTree:
elif isinstance(t, Layout):
return t
elif isinstance(t, PhiTreeNode):
- tree, tensors = disassemble_tree(t)
- tensors_ = [cached(t_) for t_ in tensors]
- return assemble_tree(tree, tensors_)
+ tree, tensors = disassemble_tree(t, cache=True)
+ return assemble_tree(tree, tensors)
else:
raise AssertionError(f"Cannot cache {type(t)} {t}")

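For reference, a small usage sketch of the updated helper, based on the signature and docstring shown above. This is illustrative only: it assumes `disassemble_tree` and `assemble_tree` are importable from `phiml.math._tensors` as in this commit, and the `demo` value is a made-up example input.

    from phiml import math
    from phiml.math._tensors import disassemble_tree, assemble_tree

    demo = {'x': math.zeros(math.spatial(points=4)), 'y': math.ones(math.batch(b=2))}

    # cache=True returns cached tensors, which is what the simplified cached()
    # above now relies on and what the tracing/jit/gradient call sites in this
    # commit request before extracting native tensors.
    tree, tensors = disassemble_tree(demo, cache=True)

    # cache=False leaves the tensors untouched; enough for read-only checks
    # such as all_available() or assert_close().
    _, raw = disassemble_tree(demo, cache=False)

    # The structure round-trips: `tree` holds the container with tensors removed,
    # and assemble_tree() puts the (possibly cached) tensors back in order.
    restored = assemble_tree(tree, tensors)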
4 changes: 2 additions & 2 deletions phiml/math/_trace.py
@@ -593,15 +593,15 @@ def matrix_from_function(f: Callable,
all_args = {**kwargs, **{f_params[i]: v for i, v in enumerate(args)}}
aux_args = {k: v for k, v in all_args.items() if k in aux}
trace_args = {k: v for k, v in all_args.items() if k not in aux}
- tree, tensors = disassemble_tree(trace_args)
+ tree, tensors = disassemble_tree(trace_args, cache=False)
target_backend = choose_backend_t(*tensors)
# --- Trace function ---
with NUMPY:
src = TracerSource(tensors[0].shape, tensors[0].dtype, tuple(trace_args.keys())[0], 0)
tracer = ShiftLinTracer(src, {EMPTY_SHAPE: math.ones()}, tensors[0].shape, bias=math.zeros(dtype=tensors[0].dtype), renamed={d: d for d in tensors[0].shape.names})
x_kwargs = assemble_tree(tree, [tracer])
result = f(**x_kwargs, **aux_args)
- out_tree, result_tensors = disassemble_tree(result)
+ out_tree, result_tensors = disassemble_tree(result, cache=False)
assert len(result_tensors) == 1, f"Linear function output must be or contain a single Tensor but got {result}"
tracer = result_tensors[0]._simplify()
assert tracer._is_tracer, f"Tracing linear function '{f_name(f)}' failed. Make sure only linear operations are used. Output: {tracer.shape}"
