Skip to content

Commit

Permalink
Fix jitted scipy-lsqr
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Jan 10, 2025
1 parent 2440bd6 commit a408567
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion phiml/backend/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,10 @@ def scipy_solve(np_y, np_x0, np_rtol, np_atol, *np_pre_tensors):
fp = b.float_type
i = INT32
bo = BOOL
x, residual, iterations, function_evaluations, converged, diverged = b.numpy_call(scipy_solve, (x0.shape, x0.shape, x0.shape[:1], x0.shape[:1], x0.shape[:1], x0.shape[:1]), (fp, fp, i, i, bo, bo), y, x0, rtol, atol, *lin_tensors, *pre_tensors)
rsd_shape = list(x0.shape)
if was_row_added:
rsd_shape[1] += 1
x, residual, iterations, function_evaluations, converged, diverged = b.numpy_call(scipy_solve, (x0.shape, rsd_shape, x0.shape[:1], x0.shape[:1], x0.shape[:1], x0.shape[:1]), (fp, fp, i, i, bo, bo), y, x0, rtol, atol, *lin_tensors, *pre_tensors)
if was_row_added:
residual = residual[:, :-1]
return SolveResult(method_name, x, residual, iterations, function_evaluations, converged, diverged, [""] * batch_size)
Expand Down

0 comments on commit a408567

Please sign in to comment.