NEW: Use properties of linear operators to speed up linesearch #61

84 changes: 69 additions & 15 deletions src/tike/opt.py
@@ -13,13 +13,64 @@
logger = logging.getLogger(__name__)


def line_search(f, x, d, step_length=1, step_shrink=0.5):
"""Return a new `step_length` using a backtracking line search.
def line_search_sqr(f, p0, p1, p2, step_length=1, step_shrink=0.5):
"""Perform an optimized line search for squared absolute value functions.

Start with the following identity, which converts a squared absolute value
into the sum of two quadratics.

```
a, b = np.random.rand(2) + 1j * np.random.rand(2)
c = np.random.rand()
abs(a + c * b)**2 == (a.real + c * b.real)**2 + (a.imag + c * b.imag)**2
```

Then, assuming G is a linear operator, the squared magnitude along the search direction is a quadratic in the step length:

sum_j |G(x_j + step * d_j)|^2 = step^2 * p2 + step * p1 + p0

p2 = sum_j |G(d_j)|^2
p1 = 2 * sum_j( G(x_j).real * G(d_j).real + G(x_j).imag * G(d_j).imag )
p0 = sum_j |G(x_j)|^2

Parameters
----------
f : function(x)
The non-linear part of the function being optimized; it is evaluated on the
quadratic combination p0 + step * p1 + step**2 * p2.
p0, p1, p2 : vectors
Precomputed values that avoid applying the forward operator inside the search loop.
"""
assert step_shrink > 0 and step_shrink < 1
m = 0 # Some tuning parameter for termination
# Cache the function value instead of recomputing it many times
fx = f(p0)
# Decrease the step length while the step increases the cost function
while True:
fxsd = f(p0 + step_length * p1 + step_length**2 * p2)
if fxsd <= fx + step_shrink * m:
break
step_length *= step_shrink
if step_length < 1e-32:
warnings.warn("Line search failed for conjugate gradient.")
return 0, fx
return step_length, fxsd
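
As an illustration of how these coefficients might be assembled, here is a minimal sketch: a 2-D FFT stands in for the linear operator `G` and a least-squares intensity fit for `f`. Both are hypothetical choices for illustration, not part of this change.

```
import numpy as np
from tike.opt import line_search_sqr  # assumes this branch is installed

# Hypothetical linear operator and measured intensities (not from this PR).
G = np.fft.fft2
data = np.random.rand(8, 8)

def f(intensity):
    """A stand-in non-linear cost that compares intensities to the data."""
    return np.linalg.norm(intensity - data)**2

x = np.random.rand(8, 8) + 1j * np.random.rand(8, 8)  # current estimate
d = np.random.rand(8, 8) + 1j * np.random.rand(8, 8)  # search direction

# Apply the forward operator once per variable; the identity above gives
# |G(x + step * d)|^2 = p0 + step * p1 + step**2 * p2 for any step.
Gx, Gd = G(x), G(d)
p0 = np.abs(Gx)**2
p1 = 2 * (Gx.real * Gd.real + Gx.imag * Gd.imag)
p2 = np.abs(Gd)**2

# The search loop never touches G again.
step, cost = line_search_sqr(f, p0, p1, p2)
```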


def line_search(f, x, d, step_length=1, step_shrink=0.5, linear=None):
"""Perform a backtracking line search for a partially-linear cost-function.

For cost functions composed of a non-linear part, f, and a linear part, l,
such that the cost = f(l(x)), a backtracking line search computations may
be reduced in exchange for memory because l(x + γ * d) = l(x) + γ * l(d).
For completely non-linear functions, the linear part is just the identity
function.

Parameters
----------
f : function(linear(x))
The non-linear part of the function being optimized.
linear : function(x), optional
The linear part of the function being optimized.
x : vector
The current position.
d : vector
@@ -42,11 +93,15 @@ def line_search(f, x, d, step_length=1, step_shrink=0.5):

"""
assert step_shrink > 0 and step_shrink < 1
linear = (lambda x: x) if linear is None else linear
m = 0 # Some tuning parameter for termination
fx = f(x) # Save the result of f(x) instead of computing it many times
# Cache function calls instead of computing them many times
lx = linear(x)
ld = linear(d)
fx = f(lx)
# Decrease the step length while the step increases the cost function
while True:
fxsd = f(x + step_length * d)
fxsd = f(lx + step_length * ld)
if fxsd <= fx + step_shrink * m:
break
step_length *= step_shrink
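
A brief usage sketch of the new `linear` argument: the linear part is applied to `x` and `d` once before the loop, so each trial step only costs a vector update. The FFT operator and the intensity cost below are illustrative assumptions, not part of this change.

```
import numpy as np
from tike.opt import line_search  # assumes this branch is installed

A = np.fft.fft2                     # hypothetical linear part of the cost
data = np.random.rand(8, 8)

def f(Ax):
    """Stand-in non-linear part; it receives the already-transformed variable."""
    return np.linalg.norm(np.abs(Ax)**2 - data)**2

x = np.random.rand(8, 8) + 1j * np.random.rand(8, 8)
d = -x  # a search direction (illustrative only)

# A(x) and A(d) are computed once; the backtracking loop only evaluates
# f(lx + step * ld), so no forward operators are applied per trial step.
step, cost = line_search(f, x, d, linear=A)
```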
@@ -69,19 +124,17 @@ def direction_dy(xp, grad0, grad1, dir):
The previous search direction.

"""
return (
- grad1
+ dir * xp.linalg.norm(grad1.ravel())**2
/ (xp.sum(dir.conj() * (grad1 - grad0)) + 1e-32)
)
return (-grad1 + dir * xp.linalg.norm(grad1.ravel())**2 /
(xp.sum(dir.conj() * (grad1 - grad0)) + 1e-32))


def conjugate_gradient(
array_module,
x,
cost_function,
grad,
num_iter=1,
array_module,
x,
cost_function,
grad,
num_iter=1,
linear_function=None,
):
"""Use conjugate gradient to estimate `x`.

@@ -108,9 +161,10 @@
grad0 = grad1
gamma, cost = line_search(
f=cost_function,
linear=linear_function,
x=x,
d=dir,
)
x = x + gamma * dir
logger.debug("%4d, %.3e, %.7e", (i + 1), gamma, cost)
logger.debug("step %d; length = %.3e; cost = %.6e", i, gamma, cost)
return x, cost
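
To show how the new `linear_function` hook threads through `conjugate_gradient`, here is a hedged sketch. The unitary FFT, least-squares cost, and its gradient are placeholder choices for illustration; they are not the library's operators, and the sketch assumes `cost_function` is only ever evaluated on `linear_function(x)`, as the line-search call in the diff above shows.

```
import numpy as np
from tike.opt import conjugate_gradient  # assumes this branch is installed

# Placeholder forward model: a unitary FFT, so its adjoint is the inverse FFT.
A = lambda x: np.fft.fft2(x, norm='ortho')
AH = lambda y: np.fft.ifft2(y, norm='ortho')
b = np.random.rand(32, 32) + 1j * np.random.rand(32, 32)  # fake data

def cost_function(Ax):
    # Receives linear_function(x), i.e. the already-transformed variable.
    return 0.5 * np.linalg.norm(Ax - b)**2

def grad(x):
    # Gradient of the least-squares cost 0.5 * ||A(x) - b||^2 with respect to x.
    return AH(A(x) - b)

x0 = np.zeros_like(b)
x, cost = conjugate_gradient(
    np,  # array_module
    x=x0,
    cost_function=cost_function,
    grad=grad,
    num_iter=10,
    linear_function=A,
)
```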