From 2c4cbdc90a50d9fc72cff327db8472abc9bed50c Mon Sep 17 00:00:00 2001
From: Philipp Holl
Date: Fri, 22 Dec 2023 20:54:37 +0100
Subject: [PATCH] Fix matrix gradient in linear_solve()

---
 phiml/math/_optimize.py | 17 ++++++-----------
 1 file changed, 6 insertions(+), 11 deletions(-)

diff --git a/phiml/math/_optimize.py b/phiml/math/_optimize.py
index ab4330f5..edd2eb0c 100644
--- a/phiml/math/_optimize.py
+++ b/phiml/math/_optimize.py
@@ -663,6 +663,7 @@ def _linear_solve_forward(y: Tensor,
         converged = reshaped_tensor(ret.converged, [*trj_dims, batch_dims])
         diverged = reshaped_tensor(ret.diverged, [*trj_dims, batch_dims])
         x = assemble_tree(x0_nest, [reshaped_tensor(ret.x, [*trj_dims, batch_dims, pattern_dims_out])])
+        final_x = x if not trj_dims else assemble_tree(x0_nest, [reshaped_tensor(ret.x[-1, ...], [batch_dims, pattern_dims_out])])
         iterations = reshaped_tensor(ret.iterations, [*trj_dims, batch_dims])
         function_evaluations = reshaped_tensor(ret.function_evaluations, [*trj_dims, batch_dims])
         if ret.residual is not None:
@@ -674,19 +675,10 @@ def _linear_solve_forward(y: Tensor,
             residual = None
         msg = unpack_dim(layout(ret.message, batch('_all')), '_all', batch_dims)
         result = SolveInfo(solve, x, residual, iterations, function_evaluations, converged, diverged, ret.method, msg, t)
-    # else:  # trajectory
-    #     converged = reshaped_tensor(ret[-1].converged, [batch_dims])
-    #     diverged = reshaped_tensor(ret[-1].diverged, [batch_dims])
-    #     x = assemble_tree(x0_nest, [reshaped_tensor(ret[-1].x, [batch_dims, pattern_dims_in])])
-    #     x_ = assemble_tree(x0_nest, [stack([reshaped_tensor(r.x, [batch_dims, pattern_dims_in]) for r in ret], )])
-    #     residual = assemble_tree(y_nest, [stack([reshaped_tensor(r.residual, [batch_dims, pattern_dims_out]) for r in ret], batch('trajectory'))])
-    #     iterations = reshaped_tensor(ret[-1].iterations, [batch_dims])
-    #     function_evaluations = stack([reshaped_tensor(r.function_evaluations, [batch_dims]) for r in ret], batch('trajectory'))
-    #     result = SolveInfo(solve, x_, residual, iterations, function_evaluations, converged, diverged, ret[-1].method, ret[-1].message, t)
     for tape in _SOLVE_TAPES:
         tape._add(solve, trj, result)
     result.convergence_check(is_backprop and 'TensorFlow' in backend.name)  # raises ConvergenceException
-    return x[{'trajectory': -1}] if isinstance(x, Tensor) else x
+    return final_x


 def attach_gradient_solve(forward_solve: Callable, auxiliary_args: str, matrix_adjoint: bool):
@@ -708,8 +700,11 @@ def implicit_gradient_solve(fwd_args: dict, x, dx):
         if isinstance(matrix, SparseCoordinateTensor):
             col = matrix.dual_indices(to_primal=True)
             row = matrix.primal_indices()
-            dm_values = dy[col] * x[row]
+            _, dy_tensors = disassemble_tree(dy)
+            _, x_tensors = disassemble_tree(x)
+            dm_values = dy_tensors[0][col] * x_tensors[0][row]
             dm = matrix._with_values(dm_values)
+            dm = -dm
         elif isinstance(matrix, NativeTensor):
             dy_dual = rename_dims(dy, shape(dy), shape(dy).as_dual())
             dm = dy_dual * x  # outer product
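
Background on the sign flip, not part of the patch itself: `dm = -dm` matches the standard adjoint of a linear solve. For A x = b, a first-order perturbation gives A dx = db - dA x, so with upstream cotangent g = dL/dx and adjoint solution lam = A^{-T} g, the chain rule yields dL/db = lam and dL/dA = -lam x^T. The sparse branch above computes this outer product restricted to the nonzero pattern of the matrix (`dy[col] * x[row]`) and then negates it. The following minimal NumPy sketch (all names such as A, b, g, lam are illustrative, nothing here is phiml API) checks that sign against a finite difference:

    import numpy as np

    rng = np.random.default_rng(0)
    n = 4
    A = rng.normal(size=(n, n)) + n * np.eye(n)   # well-conditioned system matrix
    b = rng.normal(size=n)
    g = rng.normal(size=n)                        # upstream cotangent dL/dx

    x = np.linalg.solve(A, b)                     # forward solve: A x = b
    lam = np.linalg.solve(A.T, g)                 # adjoint solve: A^T lam = g
    dA = -np.outer(lam, x)                        # matrix adjoint, with the minus sign

    # Finite-difference check of one matrix entry, for the scalar loss L = g . x
    eps = 1e-6
    i, j = 1, 2
    A_pert = A.copy()
    A_pert[i, j] += eps
    fd = (g @ np.linalg.solve(A_pert, b) - g @ x) / eps
    assert np.isclose(fd, dA[i, j], rtol=1e-3)

Without the negation, the finite-difference check above fails with the opposite sign, which is the bug the patch title refers to.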