
Convergence criteria #137

Closed
jgallowa07 opened this issue Feb 28, 2024 · 5 comments
Labels
enhancement New feature or request

Comments

@jgallowa07
Member

This issue is to note that jaxopt, the package we use for optimizing our model, is merging into optax. While this poses no problem for the software as it stands, it would certainly be desirable to make the switch once they merge in ProximalGradient.

This would solve current problems with the way we run multiple training steps just to get a convergence line, which leads to non-deterministic results (i.e., 1000 iterations in one step != 100 iterations for each of 10 steps). It would also help future development more generally.
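For illustration, here is a hedged sketch of the pattern described above (the toy data and loss are assumptions, not the package's real training code): one jaxopt `run` per step, warm-starting from the previous parameters, just to record one point of the convergence line per step.

```python
import jax.numpy as jnp
from jaxopt import ProximalGradient
from jaxopt.prox import prox_lasso

# Toy problem standing in for the real model (assumption).
X, y = jnp.ones((5, 3)), jnp.ones(5)

def loss(params, data):
    X, y = data
    return jnp.mean((X @ params - y) ** 2)

solver = ProximalGradient(fun=loss, prox=prox_lasso, maxiter=100)
params, losses = jnp.zeros(3), []
for _ in range(10):  # 10 steps of 100 iterations each
    params, state = solver.run(params, 0.1, (X, y))  # warm start
    losses.append(float(loss(params, (X, y))))       # one point per step
# Because each call to run re-initializes internal solver state, this
# need not match a single run with maxiter=1000.
```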

@jgallowa07 added the enhancement label on Feb 28, 2024
@wsdewitt
Contributor

wsdewitt commented Feb 28, 2024

The current approach doesn't deal with line search in a consistent way, because each call to `run` will re-initialize the step size. I wouldn't call this nondeterministic, but it is a deterministic function of how we partition iterations into calls to `run`. It would be preferable for our API to interface with the `update` method of `ProximalGradient`, rather than the `run` method, so we can record loss trajectories in a consistent way. We could define our own `run` command that iterates calls to `update` until `state.error <= tol` or `state.iter_num == maxiter`, and outputs loss trajectory data from each iterate.
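A minimal sketch of what that could look like, assuming jaxopt's `init_state`/`update` interface (the loop and names here are illustrative, not the project's actual API):

```python
import jax.numpy as jnp
from jaxopt import ProximalGradient
from jaxopt.prox import prox_lasso

def loss(params, data):
    X, y = data
    return jnp.mean((X @ params - y) ** 2)

def run_with_trajectory(solver, init_params, hyperparams_prox, data):
    """Iterate solver.update, recording the loss at every iterate,
    until the tolerance or the iteration budget is hit."""
    params = init_params
    state = solver.init_state(init_params, hyperparams_prox, data)
    trajectory = []
    while state.iter_num < solver.maxiter and state.error > solver.tol:
        params, state = solver.update(params, state, hyperparams_prox, data)
        trajectory.append(float(loss(params, data)))
    return params, state, trajectory

solver = ProximalGradient(fun=loss, prox=prox_lasso, maxiter=1000, tol=1e-4)
```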

@jgallowa07
Member Author

> I wouldn't call this nondeterministic, but it is a deterministic function of how we partition iterations into calls to `run`.

This is great; yes, that makes sense to me. Luckily, we're actually deterministic otherwise.

> We could define our own `run` command that iterates calls to `update` until `state.error <= tol` or `state.iter_num == maxiter`, and outputs loss trajectory data from each iterate.

This is amazing. I won't be able to work on this until after resubmission, but I would happily assist if you have the bandwidth to open a PR.

@jgallowa07
Member Author

I'll just note that this update will probably help with some issues I seem to be running into with the spike analysis.

I suspect this has to do with the tolerance, learning rate, memory, and other things that may be easier to control with a custom `update()` loop, but it seems the model fits are not so robust after many training iterations. For context, we run the spike models for 30 independent rounds of 1000 iterations each (30K total), with the default tolerance set to 1e-4.

[Screenshot from 2024-03-05 06-44-15]

However, when we run those same models for a single round of 100K iterations (everything else the same),

[Screenshot from 2024-03-05 06-57-43]

we can see some of the models have certainly over-fit to their data. I wonder how exactly the tolerance works with penalties. Could it be the case that penalties added to the total cost are affecting the model's potential to quit early?

This could also be related to #133, as a ridge penalty does again seem to stabilize things. More testing will be needed here.

@jgallowa07 changed the title from "jaxopt -> optax" to "Convergence criteria" on Mar 13, 2024
@jgallowa07
Member Author

Relevant to this issue, it should be noted that the primary difference between single-step and multi-step model optimization is the FISTA acceleration. In the multi-step models, the learning rate is reset at each step. When acceleration is turned off, these two approaches yield identical results.
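For reference, jaxopt's `ProximalGradient` exposes an `acceleration` flag; here is a hedged sketch of the comparison described above (toy data is an assumption, and the agreement between the two approaches reflects the observation in this comment):

```python
import jax.numpy as jnp
from jaxopt import ProximalGradient
from jaxopt.prox import prox_lasso

X, y = jnp.ones((5, 3)), jnp.ones(5)  # toy data (assumption)
loss = lambda p, data: jnp.mean((data[0] @ p - data[1]) ** 2)

# One run of 100 iterations, FISTA momentum disabled.
single = ProximalGradient(fun=loss, prox=prox_lasso,
                          acceleration=False, maxiter=100, tol=0.0)
p_single, _ = single.run(jnp.zeros(3), 0.1, (X, y))

# Ten warm-started runs of 10 iterations each; with acceleration off
# there is no momentum state to lose between calls to run.
multi = ProximalGradient(fun=loss, prox=prox_lasso,
                         acceleration=False, maxiter=10, tol=0.0)
p_multi = jnp.zeros(3)
for _ in range(10):
    p_multi, _ = multi.run(p_multi, 0.1, (X, y))
```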

@jgallowa07
Member Author

We still need to add some sort of check on convergence. To do this, we'll add a `state` property to the `Model` object. The state from jaxopt gives the following properties:

```python
class ProxGradState(NamedTuple):
  """Named tuple containing state information."""
  iter_num: int                    # number of iterations taken so far
  stepsize: float                  # current (line-searched) step size
  error: float                     # residual used for the stopping criterion
  aux: Optional[Any] = None        # auxiliary output of the objective, if any
  velocity: Optional[Any] = None   # FISTA momentum variable
  t: float = 1.0                   # FISTA acceleration parameter
```

I think what we want to do is simply check whether `iter_num` is less than the maximum number of iterations requested. If it is, the stopping condition was met and the model exited upon reaching the specified tolerance threshold.
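A minimal sketch of that check, assuming the `Model` object stores the final solver state and the requested `maxiter` (the names here are illustrative, not the package's actual API):

```python
class Model:
    # ... existing attributes, assumed to include self.state (the final
    # ProxGradState) and self.maxiter (the requested iteration budget) ...

    @property
    def converged(self) -> bool:
        """True if optimization exited before hitting maxiter, i.e.
        state.error dropped below the requested tolerance."""
        return self.state.iter_num < self.maxiter
```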
