Optimize to use variables of x, values of y
Philipp Holl authored and holl- committed Oct 25, 2024
1 parent 7641b1d commit 22929ac
Showing 1 changed file with 10 additions and 10 deletions.
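Background for the change below: phiml tree nodes can declare both `__value_attrs__` (all tensor-valued data attributes) and `__variable_attrs__` (the subset that acts as degrees of freedom in optimizations and solves). This commit switches the disassembly of `solve.x0` from `value_attributes` to `variable_attributes`, while the right-hand side y is still disassembled by value. A minimal sketch of the distinction, using a hypothetical `Particles` class that is not part of this commit (the two helpers are the internal functions imported in the diff):

```python
from phiml.math import wrap, spatial
from phiml.math._magic_ops import value_attributes, variable_attributes  # internal helpers used in the diff

class Particles:
    """Hypothetical tree node: `pos` varies during a solve, `mass` is constant data."""
    def __init__(self, pos, mass):
        self.pos = pos
        self.mass = mass

    def __value_attrs__(self):
        return 'pos', 'mass'  # all tensor-valued data attributes

    def __variable_attrs__(self):
        return 'pos',  # only `pos` is a degree of freedom

p = Particles(wrap([0., 1.], spatial('x')), wrap([1., 1.], spatial('x')))
print(value_attributes(p))     # expected: ('pos', 'mass')
print(variable_attributes(p))  # expected: ('pos',)
```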
20 changes: 10 additions & 10 deletions phiml/math/_optimize.py
@@ -11,7 +11,7 @@
from ..backend._backend import SolveResult, ML_LOGGER, default_backend, convert, Preconditioner, choose_backend
from ..backend._linalg import IncompleteLU, incomplete_lu_dense, incomplete_lu_coo, coarse_explicit_preconditioner_coo
from ._shape import EMPTY_SHAPE, Shape, merge_shapes, batch, non_batch, shape, dual, channel, non_dual, instance, spatial
-from ._magic_ops import stack, copy_with, rename_dims, unpack_dim, unstack, expand, value_attributes
+from ._magic_ops import stack, copy_with, rename_dims, unpack_dim, unstack, expand, value_attributes, variable_attributes
from ._sparse import native_matrix, SparseCoordinateTensor, CompressedSparseMatrix, stored_values, is_sparse, matrix_rank, _stored_matrix_rank
from ._tensors import Tensor, disassemble_tree, assemble_tree, wrap, cached, NativeTensor, layout, reshaped_numpy, reshaped_native, reshaped_tensor, NATIVE_TENSOR
from . import _ops as math
@@ -370,7 +370,7 @@ def minimize(f: Callable[[X], Y], solve: Solve[X, Y]) -> X:
solve = solve.with_defaults('optimization')
assert (solve.rel_tol == 0).all, f"rel_tol must be zero for minimize() but got {solve.rel_tol}"
assert solve.preprocess_y is None, "minimize() does not allow preprocess_y"
-x0_nest, x0_tensors = disassemble_tree(solve.x0, cache=True, attr_type=value_attributes)
+x0_nest, x0_tensors = disassemble_tree(solve.x0, cache=True, attr_type=variable_attributes)
x0_tensors = [to_float(t) for t in x0_tensors]
backend = choose_backend_t(*x0_tensors, prefer_default=True)
batch_dims = merge_shapes(*[batch(t) for t in x0_tensors])
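With this hunk, `minimize` flattens only the variable attributes of the initial guess, so constant value attributes of x0 are no longer treated as optimization unknowns. A hypothetical usage sketch of the public API (not part of this diff); note the assert above requires `rel_tol` to be zero:

```python
from phiml import math
from phiml.math import Solve, wrap, spatial

def loss(x):
    return math.l2_loss(x - 2)  # minimum at x == 2

x0 = wrap([0., 0., 0.], spatial('x'))
# rel_tol must be 0 for minimize(); abs_tol drives the stopping criterion.
x_opt = math.minimize(loss, Solve('L-BFGS-B', rel_tol=0, abs_tol=1e-5, x0=x0))
```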
@@ -405,7 +405,7 @@ def unflatten_assemble(x_flat, additional_dims: Shape = EMPTY_SHAPE, convert=Tru
stack_dims = t.shape.shape.without('dims')
x_tensors.append(stack(partial_tensors[:stack_dims.volume], stack_dims))
partial_tensors = partial_tensors[stack_dims.volume:]
-x = assemble_tree(x0_nest, x_tensors, attr_type=value_attributes)
+x = assemble_tree(x0_nest, x_tensors, attr_type=variable_attributes)
return x

def native_function(x_flat):
@@ -572,7 +572,7 @@ def solve_linear(f: Union[Callable[[X], Y], Tensor],
f_args = f_args[0] if len(f_args) == 1 and isinstance(f_args[0], tuple) else f_args
# --- Get input and output tensors ---
y_tree, y_tensors = disassemble_tree(y, cache=False, attr_type=value_attributes)
-x0_tree, x0_tensors = disassemble_tree(solve.x0, cache=False, attr_type=value_attributes)
+x0_tree, x0_tensors = disassemble_tree(solve.x0, cache=False, attr_type=variable_attributes)
assert len(x0_tensors) == len(y_tensors) == 1, "Only single-tensor linear solves are currently supported"
if isinstance(y_tree, str) and y_tree == NATIVE_TENSOR and isinstance(x0_tree, str) and x0_tree == NATIVE_TENSOR:
if callable(f): # assume batch + 1 dim
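The same swap in `solve_linear`: y keeps `value_attributes` while x0 now uses `variable_attributes`, and the assert above still restricts both sides to a single tensor. A hypothetical usage sketch of the public entry point, assuming the standard `jit_compile_linear` + conjugate-gradient path (names `f`, `y`, `x0` are illustrative):

```python
from phiml import math
from phiml.math import Solve, wrap, spatial

@math.jit_compile_linear  # lets phiml treat f as a linear function of x
def f(x):
    return 2 * x

y = wrap([2., 4., 6.], spatial('x'))
x0 = math.zeros(spatial(x=3))
x = math.solve_linear(f, y, Solve('CG', rel_tol=1e-5, abs_tol=1e-5, x0=x0))
# expected: x ≈ (1, 2, 3) along the spatial dim 'x'
```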
@@ -651,14 +651,14 @@ def _matrix_solve_forward(y, solve: Solve, matrix: Tensor, is_backprop=False):

def _function_solve_forward(y, solve: Solve, f_args: tuple, f_kwargs: dict = None, is_backprop=False):
y_nest, (y_tensor,) = disassemble_tree(y, cache=False, attr_type=value_attributes)
-x0_nest, (x0_tensor,) = disassemble_tree(solve.x0, cache=False, attr_type=value_attributes)
+x0_nest, (x0_tensor,) = disassemble_tree(solve.x0, cache=False, attr_type=variable_attributes)
# active_dims = (y_tensor.shape & x0_tensor.shape).non_batch # assumes batch dimensions are not active
batches = (y_tensor.shape & x0_tensor.shape).batch

def native_lin_f(native_x, batch_index=None):
if batch_index is not None and batches.volume > 1:
native_x = backend.tile(backend.expand_dims(native_x), [batches.volume, 1])
-x = assemble_tree(x0_nest, [reshaped_tensor(native_x, [batches, non_batch(x0_tensor)] if backend.ndims(native_x) >= 2 else [non_batch(x0_tensor)], convert=False)], attr_type=value_attributes)
+x = assemble_tree(x0_nest, [reshaped_tensor(native_x, [batches, non_batch(x0_tensor)] if backend.ndims(native_x) >= 2 else [non_batch(x0_tensor)], convert=False)], attr_type=variable_attributes)
y_ = f(x, *f_args, **f_kwargs)
_, (y_tensor_,) = disassemble_tree(y_, cache=False, attr_type=value_attributes)
assert set(non_batch(y_tensor_)) == set(non_batch(y_tensor)), f"Function returned dimensions {y_tensor_.shape} but right-hand-side has shape {y_tensor.shape}"
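In this matrix-free path, the backend solver hands `native_lin_f` a flat native vector, which is reassembled into the user's tree (now via variable attributes) before calling f. A small round-trip sketch of the flatten/reassemble helpers that the diff imports from `._tensors`, with hypothetical shapes:

```python
from phiml.math import wrap, batch, spatial, non_batch
from phiml.math._tensors import reshaped_native, reshaped_tensor  # internal helpers imported in the diff

t = wrap([[1., 2., 3.], [4., 5., 6.]], batch('b'), spatial('x'))
flat = reshaped_native(t, [batch(t), non_batch(t)])   # native array of shape (2, 3)
t2 = reshaped_tensor(flat, [batch(t), non_batch(t)])  # back to a phiml Tensor
assert t2.shape == t.shape
```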
@@ -687,7 +687,7 @@ def _linear_solve_forward(y: Tensor,
if solve.preprocess_y is not None:
y = solve.preprocess_y(y, *solve.preprocess_y_args)
y_nest, (y_tensor,) = disassemble_tree(y, cache=False, attr_type=value_attributes)
-x0_nest, (x0_tensor,) = disassemble_tree(solve.x0, cache=False, attr_type=value_attributes)
+x0_nest, (x0_tensor,) = disassemble_tree(solve.x0, cache=False, attr_type=variable_attributes)
pattern_dims_in = x0_tensor.shape.only(pattern_dims_in, reorder=True)
if pattern_dims_out not in y_tensor.shape:
warnings.warn(f"right-hand-side has shape {y_tensor.shape} but output dimensions are {pattern_dims_out}. This may result in unexpected behavior", RuntimeWarning, stacklevel=3)
@@ -726,8 +726,8 @@ def _linear_solve_forward(y: Tensor,
assert isinstance(ret, SolveResult)
converged = reshaped_tensor(ret.converged, [*trj_dims, batch_dims])
diverged = reshaped_tensor(ret.diverged, [*trj_dims, batch_dims])
-x = assemble_tree(x0_nest, [reshaped_tensor(ret.x, [*trj_dims, batch_dims, pattern_dims_out])], attr_type=value_attributes)
-final_x = x if not trj_dims else assemble_tree(x0_nest, [reshaped_tensor(ret.x[-1, ...], [batch_dims, pattern_dims_out])], attr_type=value_attributes)
+x = assemble_tree(x0_nest, [reshaped_tensor(ret.x, [*trj_dims, batch_dims, pattern_dims_out])], attr_type=variable_attributes)
+final_x = x if not trj_dims else assemble_tree(x0_nest, [reshaped_tensor(ret.x[-1, ...], [batch_dims, pattern_dims_out])], attr_type=variable_attributes)
iterations = reshaped_tensor(ret.iterations, [*trj_dims, batch_dims])
function_evaluations = reshaped_tensor(ret.function_evaluations, [*trj_dims, batch_dims])
if ret.residual is not None:
@@ -765,7 +765,7 @@ def implicit_gradient_solve(fwd_args: dict, x, dx):
col = matrix.dual_indices(to_primal=True)
row = matrix.primal_indices()
_, dy_tensors = disassemble_tree(dy, cache=False, attr_type=value_attributes)
-_, x_tensors = disassemble_tree(x, cache=False, attr_type=value_attributes)
+_, x_tensors = disassemble_tree(x, cache=False, attr_type=variable_attributes)
dm_values = dy_tensors[0][col] * x_tensors[0][row]
dm = matrix._with_values(dm_values)
dm = -dm
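For the matrix gradient above: with x = A⁻¹y and an adjoint obtained from the transposed solve, each stored entry of A receives the negated outer product of adjoint and solution; this hunk only changes how x is disassembled. A dense-pattern numpy sketch of that rule with illustrative index naming (the commit's code reads the actual pattern via `matrix.dual_indices(to_primal=True)` and `matrix.primal_indices()`, whose primal/dual pairing follows phiml's sparse-matrix layout):

```python
import numpy as np

# Stored sparsity pattern of A: entry k sits at (row_idx[k], col_idx[k]).
row_idx = np.array([0, 1, 1])
col_idx = np.array([0, 0, 1])
adj = np.array([0.1, -0.2])   # adjoint from the transposed solve, one value per row
x = np.array([1.0, 2.0])      # forward solution, one value per column
dm_values = -(adj[row_idx] * x[col_idx])  # per-entry gradient, cf. dm_values / dm = -dm above
```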