From 40f09cfc73547196f16554cb875755ecd842690d Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Sat, 4 Jan 2025 23:00:41 +0100 Subject: [PATCH] Add SciPy lsqr linear solver * Allow matrix_offset for COO matrices when using lsqr. This adds a new row to the matrix. --- phiml/backend/_linalg.py | 72 +++++++++++++++++++++++++++++++--------- phiml/math/_optimize.py | 2 +- 2 files changed, 58 insertions(+), 16 deletions(-) diff --git a/phiml/backend/_linalg.py b/phiml/backend/_linalg.py index 4b0b765..394a534 100644 --- a/phiml/backend/_linalg.py +++ b/phiml/backend/_linalg.py @@ -276,8 +276,28 @@ def loop_body(continue_, x, residual, iterations, function_evaluations, _converg def scipy_sparse_solve(b: Backend, method: Union[str, Callable], lin, y, x0, rtol, atol, max_iter, pre: Optional[Preconditioner], matrix_offset) -> SolveResult: assert max_iter.shape[0] == 1, f"Trajectory recording not supported for scipy_spsolve" + was_row_added = False if matrix_offset is not None: - raise NotImplementedError(f"matrix offset not yet supported by sparse scipy solvers") + if not callable(lin): + sparse_format = b.get_sparse_format(lin) + if sparse_format == 'coo': + _, (indices, values) = b.disassemble(lin) + rows, cols = b.shape(lin) + new_row = b.zeros((cols,), b.dtype(indices)) + rows + all_cols = b.range(0, cols, dtype=b.dtype(indices)) + new_indices = b.stack([new_row, all_cols], -1) + indices = b.concat([indices, new_indices], 0) + values = b.concat([values, b.zeros((cols,), b.dtype(values)) + matrix_offset], 0) + lin = b.sparse_coo_tensor(indices, values, (rows+1, cols)) + y = b.pad(y, [(0, 0), (0, 1)]) + was_row_added = True + if method != 'lsqr': + warnings.warn(f"Using scipy.sparse.linalg.lsqr instead of '{method}' to account for regularization") + method = 'lsqr' + else: + raise NotImplementedError(f"matrix offset not yet supported by sparse scipy solvers for '{sparse_format}' matrices") + else: + raise NotImplementedError(f"matrix offset not yet supported by matrix-free 
sparse scipy solvers") if method == 'direct' and pre: warnings.warn(f"Preconditioner {pre} was computed but is not used by SciPy direct solve.", RuntimeWarning) scipy_solvers = { @@ -289,6 +309,7 @@ def scipy_sparse_solve(b: Backend, method: Union[str, Callable], lin, y, x0, rto 'lGMres': scipy.sparse.linalg.lgmres, 'QMR': scipy.sparse.linalg.qmr, 'GCrotMK': scipy.sparse.linalg.gcrotmk, + 'lsqr': scipy.sparse.linalg.lsqr, # 'minres': scipy.sparse.linalg.minres, # this does not work like the others } function = scipy_solvers[method] if isinstance(method, str) and method != 'direct' else method @@ -318,6 +339,8 @@ def scipy_solve(np_y, np_x0, np_rtol, np_atol, *np_pre_tensors): i = DType(int, 32) bo = DType(bool) x, residual, iterations, function_evaluations, converged, diverged = b.numpy_call(scipy_solve, (x0.shape, x0.shape, x0.shape[:1], x0.shape[:1], x0.shape[:1], x0.shape[:1]), (fp, fp, i, i, bo, bo), y, x0, rtol, atol, *lin_tensors, *pre_tensors) + if was_row_added: + residual = residual[:, :-1] return SolveResult(method_name, x, residual, iterations, function_evaluations, converged, diverged, [""] * batch_size) @@ -357,7 +380,7 @@ def scipy_iterative_sparse_solve(b: Backend, lin, y, x0, rtol, atol, max_iter, p # A = LinearOperator(dtype=y.dtype, shape=(self.staticshape(y)[-1], self.staticshape(x0)[-1]), matvec=A) def count_callback(x_n): # called after each step, not with x0 - iterations[b] += 1 + iterations[bi] += 1 xs = [] iterations = [0] * batch_size @@ -365,20 +388,39 @@ def count_callback(x_n): # called after each step, not with x0 diverged = [] residual = [] messages = [] - for b in range(batch_size): - lin_b = lin[min(b, len(lin)-1)] if isinstance(lin, (tuple, list)) or (isinstance(lin, np.ndarray) and len(lin.shape) > 2) else lin + for bi in range(batch_size): + lin_b = lin[min(bi, len(lin) - 1)] if isinstance(lin, (tuple, list)) or (isinstance(lin, np.ndarray) and len(lin.shape) > 2) else lin pre_op = LinearOperator(shape=lin_b.shape, 
matvec=pre.apply, rmatvec=pre.apply_transposed) if isinstance(pre, Preconditioner) else None - lin_b = LinearOperator(shape=y[b].shape + x0[b].shape, matvec=lin_b) if callable(lin_b) else lin_b - try: - x, ret_val = scipy_function(lin_b, y[b], x0=x0[b], rtol=rtol[b], atol=atol[b], maxiter=max_iter[-1, b], M=pre_op, callback=count_callback) - except TypeError: - x, ret_val = scipy_function(lin_b, y[b], x0=x0[b], tol=rtol[b], atol=atol[b], maxiter=max_iter[-1, b], M=pre_op, callback=count_callback) # for old SciPy versions - # ret_val: 0=success, >0=not converged, <0=error - messages.append(f"code {ret_val} (SciPy {scipy_function.__name__})") - xs.append(x) - converged.append(ret_val == 0) - diverged.append(ret_val < 0 or np.any(~np.isfinite(x))) - residual.append(lin_b @ x - y[b]) + lin_b = LinearOperator(shape=y[bi].shape + x0[bi].shape, matvec=lin_b) if callable(lin_b) else lin_b + if scipy_function == scipy.sparse.linalg.lsqr: + assert b.all(atol == 0), f"scipy sparse lsqr does not support absolute tolerance. Please set it to 0." + if pre_op is not None: + warnings.warn(f"scipy sparse lsqr does not support preconditioners. 
Ignoring {pre}") + x, ret_val, n_iter, l1norm, l2norm, anorm, acond, arnorm, xnorm, _ = scipy_function(lin_b, y[bi], x0=x0[bi], atol=rtol[bi], btol=rtol[bi], iter_lim=max_iter[-1, bi]) + if ret_val == 0: + messages.append("trivial solution") + elif ret_val == 1: + messages.append("x is an approximate solution to the linear system") + elif ret_val == 2: + messages.append("x approximately solves the least-squares problem") + else: + messages.append(f"Least squares problem not solved (code {ret_val})") + xs.append(x) + converged.append(ret_val in {0, 1, 2}) + diverged.append(ret_val not in {0, 1, 2}) + iterations[bi] = n_iter + residual.append(lin_b @ x - y[bi]) + else: + try: + x, ret_val = scipy_function(lin_b, y[bi], x0=x0[bi], rtol=rtol[bi], atol=atol[bi], maxiter=max_iter[-1, bi], M=pre_op, callback=count_callback) + except TypeError: + x, ret_val = scipy_function(lin_b, y[bi], x0=x0[bi], tol=rtol[bi], atol=atol[bi], maxiter=max_iter[-1, bi], M=pre_op, callback=count_callback) # for old SciPy versions + # ret_val: 0=success, >0=not converged, <0=error + messages.append(f"code {ret_val} (SciPy {scipy_function.__name__})") + xs.append(x) + converged.append(ret_val == 0) + diverged.append(ret_val < 0 or np.any(~np.isfinite(x))) + residual.append(lin_b @ x - y[bi]) x = np.stack(xs).astype(to_numpy_dtype(dtype)) residual = np.stack(residual).astype(to_numpy_dtype(dtype)) iterations = np.asarray(iterations, np.int32) diff --git a/phiml/math/_optimize.py b/phiml/math/_optimize.py index 7749d1d..ff3b6e3 100644 --- a/phiml/math/_optimize.py +++ b/phiml/math/_optimize.py @@ -528,7 +528,7 @@ def solve_linear(f: Union[Callable[[X], Y], Tensor], * `'biCG-stab'` or `'biCG-stab(1)'`: Biconjugate gradient stabilized, first order * `'biCG-stab(2)'`, `'biCG-stab(4)'`, ...: Biconjugate gradient stabilized, second or higher order * `'scipy-direct'`: SciPy direct solve always run oh the CPU using `scipy.sparse.linalg.spsolve`. 
- * `'scipy-CG'`, `'scipy-GMres'`, `'scipy-biCG'`, `'scipy-biCG-stab'`, `'scipy-CGS'`, `'scipy-QMR'`, `'scipy-GCrotMK'`: SciPy iterative solvers always run oh the CPU, both in eager execution and JIT mode. + * `'scipy-CG'`, `'scipy-GMres'`, `'scipy-biCG'`, `'scipy-biCG-stab'`, `'scipy-CGS'`, `'scipy-QMR'`, `'scipy-GCrotMK'`, `'scipy-lsqr'`: SciPy iterative solvers always run on the CPU, both in eager execution and JIT mode. For maximum performance, compile `f` using `jit_compile_linear()` beforehand. Then, an optimized representation of `f` (such as a sparse matrix) will be used to solve the linear system.