Fix matrix gradient in linear_solve()
holl- committed Dec 22, 2023
1 parent 6785d90 commit 2c4cbdc
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions phiml/math/_optimize.py
@@ -663,6 +663,7 @@ def _linear_solve_forward(y: Tensor,
    converged = reshaped_tensor(ret.converged, [*trj_dims, batch_dims])
    diverged = reshaped_tensor(ret.diverged, [*trj_dims, batch_dims])
    x = assemble_tree(x0_nest, [reshaped_tensor(ret.x, [*trj_dims, batch_dims, pattern_dims_out])])
+    final_x = x if not trj_dims else assemble_tree(x0_nest, [reshaped_tensor(ret.x[-1, ...], [batch_dims, pattern_dims_out])])
    iterations = reshaped_tensor(ret.iterations, [*trj_dims, batch_dims])
    function_evaluations = reshaped_tensor(ret.function_evaluations, [*trj_dims, batch_dims])
    if ret.residual is not None:
@@ -674,19 +675,10 @@ def _linear_solve_forward(y: Tensor,
        residual = None
    msg = unpack_dim(layout(ret.message, batch('_all')), '_all', batch_dims)
    result = SolveInfo(solve, x, residual, iterations, function_evaluations, converged, diverged, ret.method, msg, t)
-    # else: # trajectory
-    # converged = reshaped_tensor(ret[-1].converged, [batch_dims])
-    # diverged = reshaped_tensor(ret[-1].diverged, [batch_dims])
-    # x = assemble_tree(x0_nest, [reshaped_tensor(ret[-1].x, [batch_dims, pattern_dims_in])])
-    # x_ = assemble_tree(x0_nest, [stack([reshaped_tensor(r.x, [batch_dims, pattern_dims_in]) for r in ret], )])
-    # residual = assemble_tree(y_nest, [stack([reshaped_tensor(r.residual, [batch_dims, pattern_dims_out]) for r in ret], batch('trajectory'))])
-    # iterations = reshaped_tensor(ret[-1].iterations, [batch_dims])
-    # function_evaluations = stack([reshaped_tensor(r.function_evaluations, [batch_dims]) for r in ret], batch('trajectory'))
-    # result = SolveInfo(solve, x_, residual, iterations, function_evaluations, converged, diverged, ret[-1].method, ret[-1].message, t)
    for tape in _SOLVE_TAPES:
        tape._add(solve, trj, result)
    result.convergence_check(is_backprop and 'TensorFlow' in backend.name)  # raises ConvergenceException
-    return x[{'trajectory': -1}] if isinstance(x, Tensor) else x
+    return final_x


def attach_gradient_solve(forward_solve: Callable, auxiliary_args: str, matrix_adjoint: bool):
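
Note on the two hunks above: when intermediate iterates are recorded, ret.x carries a leading trajectory dimension, so the final solution is now assembled once, up front, from the last slice ret.x[-1, ...] and returned as final_x. The previous return x[{'trajectory': -1}] if isinstance(x, Tensor) else x only sliced when x was a single Tensor and left the trajectory dimension in place for nested structures; the long commented-out trajectory branch was dead code and is deleted. A minimal NumPy analogue of the last-iterate selection (plain arrays for illustration, not phiml types):

import numpy as np

# Stand-in for ret.x: iterates stacked along a leading trajectory axis,
# e.g. 5 solver steps for a batch of 2 solutions of size 3.
trajectory = np.random.default_rng(0).normal(size=(5, 2, 3))
final_x = trajectory[-1, ...]  # keep only the last iterate, shape (2, 3)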
@@ -708,8 +700,11 @@ def implicit_gradient_solve(fwd_args: dict, x, dx):
        if isinstance(matrix, SparseCoordinateTensor):
            col = matrix.dual_indices(to_primal=True)
            row = matrix.primal_indices()
-            dm_values = dy[col] * x[row]
+            _, dy_tensors = disassemble_tree(dy)
+            _, x_tensors = disassemble_tree(x)
+            dm_values = dy_tensors[0][col] * x_tensors[0][row]
            dm = matrix._with_values(dm_values)
+            dm = -dm
        elif isinstance(matrix, NativeTensor):
            dy_dual = rename_dims(dy, shape(dy), shape(dy).as_dual())
            dm = dy_dual * x  # outer product
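
Note on the matrix-adjoint hunk above: for a linear solve x = A^-1 y, assuming dy holds the gradient with respect to the right-hand side (the adjoint solve A^-T dx, which the variable naming suggests), the gradient with respect to the matrix is the negative outer product dA = -dy x^T. For a SparseCoordinateTensor this only needs to be evaluated at the stored nonzero positions, hence dm_values gathered via row/col, with the new dm = -dm line supplying the previously missing sign. Disassembling the trees first is what makes the indexing work when dy and x are nested structures rather than bare tensors; the dense NativeTensor branch forms the same outer product by renaming the dims of dy to their duals. A self-contained NumPy check of the adjoint rule (dense toy problem, illustrative names only, not phiml code):

import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4)) + 4 * np.eye(4)  # well-conditioned test matrix
y = rng.normal(size=4)
dx = rng.normal(size=4)                      # upstream gradient w.r.t. x

x = np.linalg.solve(A, y)
dy = np.linalg.solve(A.T, dx)                # adjoint solve: dy = A^-T dx
dA = -np.outer(dy, x)                        # matrix gradient: dA = -dy x^T

# Finite-difference check of a single entry A[1, 2]:
eps = 1e-6
E = np.zeros_like(A)
E[1, 2] = eps
fd = dx @ (np.linalg.solve(A + E, y) - x) / eps
assert abs(dA[1, 2] - fd) < 1e-4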
