diff --git a/src/tike/opt.py b/src/tike/opt.py
index ac634366..f9094394 100644
--- a/src/tike/opt.py
+++ b/src/tike/opt.py
@@ -13,13 +13,64 @@
 logger = logging.getLogger(__name__)
 
 
-def line_search(f, x, d, step_length=1, step_shrink=0.5):
-    """Return a new `step_length` using a backtracking line search.
+def line_search_sqr(f, p0, p1, p2, step_length=1, step_shrink=0.5):
+    """Perform an optimized line search for squared absolute value functions.
+
+    Start with the following identity, which converts a squared absolute
+    value expression into the sum of two quadratics.
+
+    ```
+    a, b = np.random.rand(2) + 1j * np.random.rand(2)
+    c = np.random.rand()
+    abs(a + c * b)**2 == (a.real + c * b.real)**2 + (a.imag + c * b.imag)**2
+    ```
+
+    Then, assuming that G is a linear operator,
+
+    sum_j |G(x_j + step * d_j)|^2 = step^2 * p2 + step * p1 + p0
+
+    p2 = sum_j |G(d_j)|^2
+    p1 = 2 * sum_j( G(x_j).real * G(d_j).real + G(x_j).imag * G(d_j).imag )
+    p0 = sum_j |G(x_j)|^2
 
     Parameters
     ----------
     f : function(x)
         The function being optimized.
+    p0, p1, p2 : vectors
+        Temporary vectors that avoid recomputing the forward operator.
+    """
+    assert step_shrink > 0 and step_shrink < 1
+    m = 0  # Some tuning parameter for termination
+    # Cache function calls instead of computing them many times
+    fx = f(p0)
+    # Decrease the step length while the step increases the cost function
+    while True:
+        fxsd = f(p0 + step_length * p1 + step_length**2 * p2)
+        if fxsd <= fx + step_shrink * m:
+            break
+        step_length *= step_shrink
+        if step_length < 1e-32:
+            warnings.warn("Line search failed for conjugate gradient.")
+            return 0, fx
+    return step_length, fxsd
+
+
+def line_search(f, x, d, step_length=1, step_shrink=0.5, linear=None):
+    """Perform a backtracking line search for a partially-linear cost function.
+
+    For cost functions composed of a non-linear part, f, and a linear part, l,
+    such that the cost = f(l(x)), the backtracking line search computation can
+    be reduced in exchange for memory because l(x + γ * d) = l(x) + γ * l(d).
+    For completely non-linear functions, the linear part is just the identity
+    function.
+
+    Parameters
+    ----------
+    f : function(linear(x))
+        The non-linear part of the function being optimized.
+    linear : function(x), optional
+        The linear part of the function being optimized.
     x : vector
         The current position.
     d : vector
@@ -42,11 +93,15 @@ def line_search(f, x, d, step_length=1, step_shrink=0.5):
 
     """
     assert step_shrink > 0 and step_shrink < 1
+    linear = (lambda x: x) if linear is None else linear
     m = 0  # Some tuning parameter for termination
-    fx = f(x)  # Save the result of f(x) instead of computing it many times
+    # Cache function calls instead of computing them many times
+    lx = linear(x)
+    ld = linear(d)
+    fx = f(lx)
     # Decrease the step length while the step increases the cost function
     while True:
-        fxsd = f(x + step_length * d)
+        fxsd = f(lx + step_length * ld)
         if fxsd <= fx + step_shrink * m:
             break
         step_length *= step_shrink
@@ -69,19 +124,17 @@ def direction_dy(xp, grad0, grad1, dir):
         The previous search direction.
 
     """
-    return (
-        - grad1
-        + dir * xp.linalg.norm(grad1.ravel())**2
-        / (xp.sum(dir.conj() * (grad1 - grad0)) + 1e-32)
-    )
+    return (-grad1 + dir * xp.linalg.norm(grad1.ravel())**2 /
+            (xp.sum(dir.conj() * (grad1 - grad0)) + 1e-32))
 
 
 def conjugate_gradient(
-    array_module,
-    x,
-    cost_function,
-    grad,
-    num_iter=1,
+        array_module,
+        x,
+        cost_function,
+        grad,
+        num_iter=1,
+        linear_function=None,
 ):
     """Use conjugate gradient to estimate `x`.
@@ -108,9 +161,10 @@ def conjugate_gradient(
         grad0 = grad1
         gamma, cost = line_search(
             f=cost_function,
+            linear=linear_function,
             x=x,
             d=dir,
         )
        x = x + gamma * dir
-        logger.debug("%4d, %.3e, %.7e", (i + 1), gamma, cost)
+        logger.debug("step %d; length = %.3e; cost = %.6e", i, gamma, cost)
     return x, cost
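
As a sanity check on the quadratic expansion that `line_search_sqr` relies on (not part of the patch), the identity from the docstring can be verified numerically. The dense matrix standing in for the linear operator G, the array shapes, and the random seed below are illustrative only.

```
# Numerically confirm that sum |G(x + step*d)|^2 == step^2*p2 + step*p1 + p0
# when G is linear. Any linear operator works; a small dense matrix is used
# here only for illustration.
import numpy as np

rng = np.random.default_rng(42)
G = rng.standard_normal((5, 3)) + 1j * rng.standard_normal((5, 3))
x = rng.standard_normal(3) + 1j * rng.standard_normal(3)
d = rng.standard_normal(3) + 1j * rng.standard_normal(3)
step = 0.37

Gx, Gd = G @ x, G @ d
p0 = np.sum(np.abs(Gx)**2)
p1 = 2 * np.sum(Gx.real * Gd.real + Gx.imag * Gd.imag)
p2 = np.sum(np.abs(Gd)**2)

assert np.allclose(
    np.sum(np.abs(G @ (x + step * d))**2),
    step**2 * p2 + step * p1 + p0,
)
```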
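The new `linear_function` keyword threads through to `line_search`, so the linear part of the cost is applied only twice per iteration (to `x` and to the search direction) rather than once per backtracking step. A minimal usage sketch against the patched signature follows; the operator A, the least-squares cost, and its gradient are made up for illustration.

```
# Sketch: call the patched conjugate_gradient on cost(x) = |A @ x - b|^2,
# passing the linear part A separately from the non-linear part.
import numpy as np
from tike.opt import conjugate_gradient

rng = np.random.default_rng(0)
A = rng.standard_normal((6, 4))
b = rng.standard_normal(6)

def cost(Ax):
    # Non-linear part; receives the already-applied linear part A @ x.
    return np.sum(np.abs(Ax - b)**2)

def grad(x):
    # Gradient of cost(A @ x) with respect to x.
    return 2 * A.T @ (A @ x - b)

x, final_cost = conjugate_gradient(
    array_module=np,
    x=np.zeros(4),
    cost_function=cost,
    grad=grad,
    num_iter=16,
    linear_function=lambda v: A @ v,
)
```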