Add SciPy lsqr linear solver
* Allow matrix_offset for COO matrices when using lsqr. This adds a new row to the matrix (see the sketch below).
holl- committed Jan 4, 2025
1 parent 4946252 commit 40f09cf
Showing 2 changed files with 58 additions and 16 deletions.
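
The sketch referenced in the commit message: instead of adding a constant offset to every matrix entry, the offset becomes one additional matrix row with a zero right-hand-side entry, so the system turns rectangular and must be solved in the least-squares sense. Below is a minimal plain-SciPy illustration of that idea, not the backend-agnostic phiml code in the diff; the function name and the use of `scipy.sparse.vstack` are illustrative choices.

    import numpy as np
    import scipy.sparse
    import scipy.sparse.linalg

    def solve_with_offset(A: scipy.sparse.coo_matrix, y: np.ndarray, offset: float) -> np.ndarray:
        """Append one row of constant value `offset` to A and a zero entry to y,
        then solve the rectangular system with lsqr."""
        rows, cols = A.shape
        offset_row = scipy.sparse.coo_matrix(np.full((1, cols), offset))
        A_aug = scipy.sparse.vstack([A, offset_row], format='coo')  # shape (rows + 1, cols)
        y_aug = np.concatenate([y, [0.0]])
        x, istop, *rest = scipy.sparse.linalg.lsqr(A_aug, y_aug)
        return x

The extra row penalizes `offset * sum(x)`, which is why the diff below switches to lsqr and warns about regularization; e.g. for a singular Laplace-type matrix it pins down the otherwise free constant component of the solution.
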
72 changes: 57 additions & 15 deletions phiml/backend/_linalg.py
@@ -276,8 +276,28 @@ def loop_body(continue_, x, residual, iterations, function_evaluations, _converg

def scipy_sparse_solve(b: Backend, method: Union[str, Callable], lin, y, x0, rtol, atol, max_iter, pre: Optional[Preconditioner], matrix_offset) -> SolveResult:
assert max_iter.shape[0] == 1, f"Trajectory recording not supported for scipy_spsolve"
was_row_added = False
if matrix_offset is not None:
raise NotImplementedError(f"matrix offset not yet supported by sparse scipy solvers")
if not callable(lin):
sparse_format = b.get_sparse_format(lin)
if sparse_format == 'coo':
_, (indices, values) = b.disassemble(lin)
rows, cols = b.shape(lin)
new_row = b.zeros((cols,), b.dtype(indices)) + rows
all_cols = b.range(0, cols, dtype=b.dtype(indices))
new_indices = b.stack([new_row, all_cols], -1)
indices = b.concat([indices, new_indices], 0)
values = b.concat([values, b.zeros((cols,), b.dtype(values)) + matrix_offset], 0)
lin = b.sparse_coo_tensor(indices, values, (rows+1, cols))
y = b.pad(y, [(0, 0), (0, 1)])
was_row_added = True
if method != 'lsqr':
warnings.warn(f"Using scipy.sparse.linalg.lsqr instead of '{method}' to account for regularization")
method = 'lsqr'
else:
raise NotImplementedError(f"matrix offset not yet supported by sparse scipy solvers for '{sparse_format}' matrices")
else:
raise NotImplementedError(f"matrix offset not yet supported by matrix-free sparse scipy solvers")
if method == 'direct' and pre:
warnings.warn(f"Preconditioner {pre} was computed but is not used by SciPy direct solve.", RuntimeWarning)
scipy_solvers = {
@@ -289,6 +309,7 @@ def scipy_sparse_solve(b: Backend, method: Union[str, Callable], lin, y, x0, rto
'lGMres': scipy.sparse.linalg.lgmres,
'QMR': scipy.sparse.linalg.qmr,
'GCrotMK': scipy.sparse.linalg.gcrotmk,
'lsqr': scipy.sparse.linalg.lsqr,
# 'minres': scipy.sparse.linalg.minres, # this does not work like the others
}
function = scipy_solvers[method] if isinstance(method, str) and method != 'direct' else method
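
To make the index arithmetic in the COO branch above concrete, here is the same construction written with plain NumPy/SciPy calls instead of backend calls; the example matrix, the offset value, and the (nnz, 2) [row, column] index layout mirror what the diff appears to build and are illustrative assumptions.

    import numpy as np
    import scipy.sparse

    A = scipy.sparse.random(100, 100, density=0.05, format='coo', random_state=0)
    offset = 1e-3

    indices = np.stack([A.row, A.col], -1)                  # (nnz, 2) row/column pairs
    values = A.data
    rows, cols = A.shape
    new_row = np.full((cols,), rows, dtype=indices.dtype)   # every new entry lives in the appended row
    all_cols = np.arange(cols, dtype=indices.dtype)         # one entry per column
    indices = np.concatenate([indices, np.stack([new_row, all_cols], -1)], 0)
    values = np.concatenate([values, np.full((cols,), offset)], 0)
    A_aug = scipy.sparse.coo_matrix((values, (indices[:, 0], indices[:, 1])), shape=(rows + 1, cols))
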
@@ -318,6 +339,8 @@ def scipy_solve(np_y, np_x0, np_rtol, np_atol, *np_pre_tensors):
i = DType(int, 32)
bo = DType(bool)
x, residual, iterations, function_evaluations, converged, diverged = b.numpy_call(scipy_solve, (x0.shape, x0.shape, x0.shape[:1], x0.shape[:1], x0.shape[:1], x0.shape[:1]), (fp, fp, i, i, bo, bo), y, x0, rtol, atol, *lin_tensors, *pre_tensors)
if was_row_added:
residual = residual[:, :-1]
return SolveResult(method_name, x, residual, iterations, function_evaluations, converged, diverged, [""] * batch_size)


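Note on the `residual = residual[:, :-1]` trim added above: when the regularization row was appended, the raw residual has one extra entry per batch vector, belonging to the artificial equation `offset * sum(x) ≈ 0`, and that entry is dropped before the result is reported. A tiny NumPy illustration; the shapes and names are made up for clarity.

    import numpy as np

    batch, rows = 2, 100
    residual_aug = np.zeros((batch, rows + 1))  # residual of the augmented (rows + 1)-equation system
    residual = residual_aug[:, :-1]             # keep only the residual of the original equations
    assert residual.shape == (batch, rows)
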
@@ -357,28 +380,47 @@ def scipy_iterative_sparse_solve(b: Backend, lin, y, x0, rtol, atol, max_iter, p
# A = LinearOperator(dtype=y.dtype, shape=(self.staticshape(y)[-1], self.staticshape(x0)[-1]), matvec=A)

def count_callback(x_n): # called after each step, not with x0
iterations[b] += 1
iterations[bi] += 1

xs = []
iterations = [0] * batch_size
converged = []
diverged = []
residual = []
messages = []
for b in range(batch_size):
lin_b = lin[min(b, len(lin)-1)] if isinstance(lin, (tuple, list)) or (isinstance(lin, np.ndarray) and len(lin.shape) > 2) else lin
for bi in range(batch_size):
lin_b = lin[min(bi, len(lin) - 1)] if isinstance(lin, (tuple, list)) or (isinstance(lin, np.ndarray) and len(lin.shape) > 2) else lin
pre_op = LinearOperator(shape=lin_b.shape, matvec=pre.apply, rmatvec=pre.apply_transposed) if isinstance(pre, Preconditioner) else None
lin_b = LinearOperator(shape=y[b].shape + x0[b].shape, matvec=lin_b) if callable(lin_b) else lin_b
try:
x, ret_val = scipy_function(lin_b, y[b], x0=x0[b], rtol=rtol[b], atol=atol[b], maxiter=max_iter[-1, b], M=pre_op, callback=count_callback)
except TypeError:
x, ret_val = scipy_function(lin_b, y[b], x0=x0[b], tol=rtol[b], atol=atol[b], maxiter=max_iter[-1, b], M=pre_op, callback=count_callback) # for old SciPy versions
# ret_val: 0=success, >0=not converged, <0=error
messages.append(f"code {ret_val} (SciPy {scipy_function.__name__})")
xs.append(x)
converged.append(ret_val == 0)
diverged.append(ret_val < 0 or np.any(~np.isfinite(x)))
residual.append(lin_b @ x - y[b])
lin_b = LinearOperator(shape=y[bi].shape + x0[bi].shape, matvec=lin_b) if callable(lin_b) else lin_b
if scipy_function == scipy.sparse.linalg.lsqr:
assert b.all(atol == 0), f"scipy sparse lsqr does not support absolute tolerance. Please set it to 0."
if pre_op is not None:
warnings.warn(f"scipy sparse lsqr does not support preconditioners. Ignoring {pre}")
x, ret_val, n_iter, l1norm, l2norm, anorm, acond, arnorm, xnorm, _ = scipy_function(lin_b, y[bi], x0=x0[bi], atol=rtol[bi], btol=rtol[bi], iter_lim=max_iter[-1, bi])
if ret_val == 0:
messages.append("trivial solution")
elif ret_val == 1:
messages.append("x is an approximate solution to the linear system")
elif ret_val == 2:
messages.append("x approximately solves the least-squares problem")
else:
messages.append(f"Least squares problem not solved (code {ret_val})")
xs.append(x)
converged.append(ret_val in {0, 1, 2})
diverged.append(ret_val not in {0, 1, 2})
iterations[bi] = n_iter
residual.append(lin_b @ x - y[bi])
else:
try:
x, ret_val = scipy_function(lin_b, y[bi], x0=x0[bi], rtol=rtol[bi], atol=atol[bi], maxiter=max_iter[-1, bi], M=pre_op, callback=count_callback)
except TypeError:
x, ret_val = scipy_function(lin_b, y[bi], x0=x0[bi], tol=rtol[bi], atol=atol[bi], maxiter=max_iter[-1, bi], M=pre_op, callback=count_callback) # for old SciPy versions
# ret_val: 0=success, >0=not converged, <0=error
messages.append(f"code {ret_val} (SciPy {scipy_function.__name__})")
xs.append(x)
converged.append(ret_val == 0)
diverged.append(ret_val < 0 or np.any(~np.isfinite(x)))
residual.append(lin_b @ x - y[bi])
x = np.stack(xs).astype(to_numpy_dtype(dtype))
residual = np.stack(residual).astype(to_numpy_dtype(dtype))
iterations = np.asarray(iterations, np.int32)
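
The lsqr branch above unpacks SciPy's ten-element return tuple and maps its `istop` codes to convergence flags. For reference, a direct call looks roughly like this; the matrix, right-hand side and tolerances are placeholders.

    import numpy as np
    import scipy.sparse
    import scipy.sparse.linalg

    A = scipy.sparse.random(50, 40, density=0.1, format='coo', random_state=0)
    y = np.ones(50)
    x, istop, itn, r1norm, r2norm, anorm, acond, arnorm, xnorm, var = scipy.sparse.linalg.lsqr(
        A, y, atol=1e-5, btol=1e-5, iter_lim=1000)
    # istop == 0: the trivial solution x = 0 (or the given x0) already solves the system
    # istop == 1: x is an approximate solution of A x = y
    # istop == 2: x approximately solves the least-squares problem min ||A x - y||
    converged = istop in (0, 1, 2)

Unlike `cg` and friends, lsqr's `atol`/`btol` are scaled by the matrix and right-hand-side norms, i.e. they act as relative tolerances, which is presumably why the code above insists on a zero absolute tolerance and passes `rtol` for both.
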
2 changes: 1 addition & 1 deletion phiml/math/_optimize.py
@@ -528,7 +528,7 @@ def solve_linear(f: Union[Callable[[X], Y], Tensor],
* `'biCG-stab'` or `'biCG-stab(1)'`: Biconjugate gradient stabilized, first order
* `'biCG-stab(2)'`, `'biCG-stab(4)'`, ...: Biconjugate gradient stabilized, second or higher order
* `'scipy-direct'`: SciPy direct solve, always run on the CPU using `scipy.sparse.linalg.spsolve`.
* `'scipy-CG'`, `'scipy-GMres'`, `'scipy-biCG'`, `'scipy-biCG-stab'`, `'scipy-CGS'`, `'scipy-QMR'`, `'scipy-GCrotMK'`: SciPy iterative solvers, always run on the CPU, both in eager execution and JIT mode.
* `'scipy-CG'`, `'scipy-GMres'`, `'scipy-biCG'`, `'scipy-biCG-stab'`, `'scipy-CGS'`, `'scipy-QMR'`, `'scipy-GCrotMK'`, `'scipy-lsqr'`: SciPy iterative solvers, always run on the CPU, both in eager execution and JIT mode.
For maximum performance, compile `f` using `jit_compile_linear()` beforehand.
Then, an optimized representation of `f` (such as a sparse matrix) will be used to solve the linear system.
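
With the docstring entry above, the new method string can be passed like any other SciPy solver name. A hypothetical usage sketch; the helper function below and the exact `Solve` arguments are assumptions for illustration and are not taken from this commit.

    from phiml import math

    def solve_lsqr(A, y, x0):
        """Solve A x = y with SciPy's lsqr via phiml (assumed Solve signature)."""
        solve = math.Solve('scipy-lsqr', rel_tol=1e-5, abs_tol=0, x0=x0)  # abs_tol=0: lsqr only takes relative tolerances
        return math.solve_linear(A, y, solve)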
