Skip to content

Commit

Permalink
Fix solve_linear with multi-dual vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Jan 20, 2025
1 parent 5bf9d58 commit 144aae4
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion phiml/math/_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ def _linear_solve_forward(y: Tensor,
assert isinstance(ret, SolveResult)
converged = reshaped_tensor(ret.converged, [*trj_dims, batch_dims])
diverged = reshaped_tensor(ret.diverged, [*trj_dims, batch_dims])
x = assemble_tree(x0_nest, [reshaped_tensor(ret.x, [*trj_dims, batch_dims, pattern_dims_out])], attr_type=variable_attributes)
x = assemble_tree(x0_nest, [reshaped_tensor(ret.x, [*trj_dims, batch_dims, pattern_dims_in])], attr_type=variable_attributes)
final_x = x if not trj_dims else assemble_tree(x0_nest, [reshaped_tensor(ret.x[-1, ...], [batch_dims, pattern_dims_out])], attr_type=variable_attributes)
iterations = reshaped_tensor(ret.iterations, [*trj_dims, batch_dims])
function_evaluations = reshaped_tensor(ret.function_evaluations, [*trj_dims, batch_dims])
Expand Down

0 comments on commit 144aae4

Please sign in to comment.