
Validation Loss Enhancements #1900

Open · wants to merge 7 commits into base: sd3

Conversation

@stepfunction83 commented Jan 26, 2025

I am creating this PR to implement a number of additional enhancements to the base validation loss implementation in #1899. These include:

Low Effort

  • (Complete) Calc validation at start with --validation_at_start argument
  • (Complete) Clean up progress bars so the validation one sits neatly above the global one and is not recreated each time

Bug Fixes

  • (Complete) Fix a bug where validation did not work for Flux LoRA training if latents are not cached

Additional items for future PRs:

High Effort

  • Implement for finetuning (Flux)
  • Implement for finetuning (SD3)
  • Implement for finetuning (SD1.5)
  • Capture loss calculation state and attach to samples to reduce variance of calculation to model changes only
  • Add test set loss calculation in addition to validation set to properly track generalization error

Medium Effort

  • Add multiple noise/timestep iterations per sample
  • Add specific validation sample count in addition to current % based selection
  • Remove need to recalculate train latents every run if the only difference is the train/validation split

Low Effort

  • Add relative loss (loss / initial loss value)

@stepfunction83 (Author) commented Jan 26, 2025

[Screenshot: training progress bars]

The steps bar is once again at the bottom, and the validation steps bar sits right above it throughout the entire training cycle.

@rockerBOO (Contributor)
I was going to be working on the validation implementation for the training scripts, because it requires implementing an abstraction for process_batch like train_network has. We can work together on implementing those. There may also be motivation to create abstractions of the training process, since each script currently implements it separately.
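Roughly, the shape might be something like this (hypothetical names, not the current sd-scripts code):

```python
# Hypothetical sketch of a shared process_batch abstraction (illustrative
# names, not the actual sd-scripts API): each training script implements
# one method that turns a batch into a loss, and that same method then
# drives both the training step and the validation pass.
import torch


class TrainerBase:
    def process_batch(self, batch: dict, is_train: bool) -> torch.Tensor:
        """Return the loss for one batch; each script provides its own logic."""
        raise NotImplementedError

    def validate(self, val_dataloader) -> float:
        """Average process_batch over the validation set, without gradients."""
        losses = []
        with torch.no_grad():
            for batch in val_dataloader:
                losses.append(self.process_batch(batch, is_train=False).item())
        return sum(losses) / max(len(losses), 1)
```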

For timesteps, consider the work done in #1165, which added a timestep validation process. You can see the progress I made in bbf6bbd, which I removed to get that PR merged. It would require a larger abstraction to get the random timesteps or use the set timesteps. This timestep process would probably be good as its own PR, applied first to a smaller subset like LoRA training, and then we can use it in the training scripts. But because the process_batch aspect would need to be translated, that may need to come before the training scripts.

Also, for merging purposes it might be better to work in smaller pieces so they are easier to review. Each line item you presented could be proposed and discussed before you do the work.

@rockerBOO (Contributor)
Additionally, take a look at #914 and #1165 for more context on how things got to this point, so we can be on the same page there as well.

@stepfunction83 (Author) commented Jan 26, 2025

That's fair. I have a lot I want to add to this to get it where I would like. Right now, I think some of the simpler items like the bug fixes, progress bar formatting, and calculating validation at start would be good to get in sooner.

As far as training goes, I had already created a form of abstraction for the Flux finetuning loss calculation, which you can see in #1898. I was going to migrate that over to this PR.

For the stable timesteps, I plan to implement a state argument for calculate_loss() which will effectively fix the variables for noise, noisy_input_to_model, timesteps, and sigmas after they are calculated on the first run. They will not be pre-determined, but will be stored in the batch and manually set each time the function is called. I was planning to make an inner loop to perform one iteration for each noise/timestep repeat and capture the results in a list attached to the batch.
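A rough sketch of the idea (the noising step and names like loss_state are simplified stand-ins for the real scheduler code):

```python
import torch
import torch.nn.functional as F


def calculate_loss(model, batch, num_repeats=4):
    # Reuse the snapshot stored in the batch if it exists; otherwise sample
    # noise/timesteps once and store them so later runs replay the same state.
    state = batch.get("loss_state")
    if state is None:
        state = []
        for _ in range(num_repeats):
            state.append({
                "noise": torch.randn_like(batch["latents"]),
                "timesteps": torch.randint(0, 1000, (batch["latents"].shape[0],)),
            })
        batch["loss_state"] = state

    # Inner loop: one iteration per noise/timestep snapshot, then average,
    # so the validation loss varies only with the model weights.
    losses = []
    for snap in state:
        # Simple interpolation as a stand-in for the real scheduler noising.
        sigma = snap["timesteps"].float().view(-1, 1, 1, 1) / 1000.0
        noisy_input = (1 - sigma) * batch["latents"] + sigma * snap["noise"]
        pred = model(noisy_input, snap["timesteps"])
        losses.append(F.mse_loss(pred, snap["noise"]))
    return torch.stack(losses).mean()
```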

@rockerBOO (Contributor)
Yeah, I think what you have right now in this PR could stand on its own as a good set of improvements. The loss calculation could probably be on its own, and Kohya appreciated it. And the timesteps would be good as another PR. Then we can get them merged faster, and I can give feedback and help test.

@stepfunction83 (Author) commented Jan 26, 2025

I agree with that approach and with the breakdown you suggest. I'd like to add one more change to this PR to have a fixed number of validation samples (--validation_set_fixed_size) as an additional option. With a large data set like the one I'm using (500 images), it can be a bit annoying to specify small validation sets (<20 items) as a % of the total.

After that we can probably pause there and work on some of the other items in additional PRs.

@stepfunction83 (Author) commented Jan 26, 2025

Looking through the code, that actually seems like a non-trivial workload to implement. It should probably be its own thing as well. I think it's important to be able to set a fixed validation set count because there is a non-trivial performance hit from running a lot of validation steps if the dataset is large. For now, you can just calculate the right split and set it accordingly.
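For example, with illustrative numbers:

```python
# Deriving the percentage split that yields a fixed validation count
# (the numbers here are illustrative).
dataset_size = 500
desired_validation_count = 20
validation_split = desired_validation_count / dataset_size  # 0.04, i.e. a 4% split
```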

I think the order I'd like to tackle the other items is:

  1. State capture and replay in calculate_loss() for consistent loss calculation
  2. Multiple timestep iterations for additional variance reduction in validation loss
  3. Finetuning support
  4. Construction of test set loss in addition to validation set
  5. Additional loss metrics

@kohya-ss (Owner)
Thank you for the various suggestions!

I believe in being careful about abstracting the training process. Over-abstraction (like other trainers) can lead to poor readability. I prefer to put the main steps in their own main scripts.

Regarding the stable timesteps, I have the following thoughts based on existing discussions:

  • Timesteps should be fixed, not random, because the number of validation steps is relatively small and the timesteps can be highly biased depending on the choice of random seed.
  • Timesteps should not be uniformly distributed, and the scheduler shift should be taken into consideration.
  • For the above two points, I think it is better to take timesteps with fixed index values (for example, 200, 400, 600, 800) from the timesteps of the scheduler; a sketch follows after this list.
  • Also, if we have multiple images in the validation set, we might have to do inference on multiple timesteps for each one. This requires more validation steps, but it might be worth it.
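As a minimal sketch of that fixed-index selection (the linear schedule here is an illustrative stand-in for the actual shifted scheduler):

```python
import torch

# Stand-in for the scheduler's (possibly shifted) timestep schedule; the
# real code would index into the training scheduler so the shift is honored.
num_train_timesteps = 1000
scheduler_timesteps = torch.linspace(num_train_timesteps - 1, 0, num_train_timesteps)

# Fixed indices give the same validation timesteps every run, instead of a
# potentially biased random draw.
fixed_indices = torch.tensor([200, 400, 600, 800])
validation_timesteps = scheduler_timesteps[fixed_indices]
```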

Note that it doesn't cost much to compute the targets during training; caching them is more resource intensive.

The fixed number of validation samples is an interesting feature, but adding new functionality is not desirable if it can be replaced by a small validation split.

I apologize for our slow response time to pull requests. For better collaboration, I also think it is desirable if you could start with a smaller pull request as you wrote, so we can first discuss the implementation approach. This would help avoid situations where PRs remain open for a long time or require major changes due to differences in architectural decisions.

@stepfunction83 (Author) commented Jan 27, 2025

Okay, take a look at what I implemented here. It's in the same spirit as what I implemented originally and accomplishes the goals of snapshotting the noise state and timesteps for each sample. The work to sample random timesteps is trivial and works very well for creating a stable loss curve. If you take 4 timesteps per sample, you'll get a reasonable scattering.

Keeping the state tracker in memory doesn't seem to add any significant load.

With this, the raw loss curve is much more consistent already (with no smoothing or moving averages applied):

[Screenshot: raw validation loss curve]

I do have a lot of cleanup to do, but this is the basic proof of concept.

@kohya-ss (Owner)
Thanks for the update, but I'm planning a comprehensive update on validation loss in the next few days, so you might want to wait until then to update. Please note that major changes may be required to this PR.

@stepfunction83 (Author)
Sure! That sounds good. What I put together here is pretty straightforward, so it shouldn't be too hard to reimplement whenever. I'll look out for the update.

@kohya-ss (Owner)
I'm starting to implement it here, any suggestions are welcome: #1903.

Commit: Revert "… validation loss calculation" (this reverts commit 3d5c644)
@stepfunction83 (Author)
> I'm starting to implement it here, any suggestions are welcome: #1903.

I reverted the bulk change that included snapshotting. The remaining commits have some good cleanup items and utility which shouldn't conflict substantially with the changes you're making in the other PR.
