Validation Loss Enhancements #1900
base: sd3
Conversation
I was going to work on the validation implementation for the training scripts, since it requires implementing a process_batch abstraction like the one in train_network. We can work together on implementing those. There may also be motivation to abstract the training process itself, because each script currently implements it separately. For timesteps, consider the work done in #1165, which added a timestep validation process. You can see the progress I made in bbf6bbd, which I removed to get that PR merged. Supporting both random timesteps and a fixed set of timesteps would require a larger abstraction. This timestep work would probably be best done in its own PR, on a smaller subset like LoRA training, and then we can reuse it in the training scripts. But because the process_batch aspect would need to be translated first, that may need to come before the training scripts. Also, for merging purposes it might be better to work in smaller pieces so the changes are easier to review. Each line item you presented could be proposed and discussed before you do the work.
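For illustration, here is a minimal sketch of the kind of shared process_batch abstraction being discussed; every name and signature below is a hypothetical stand-in, not the existing sd-scripts API:

```python
# Hypothetical sketch only -- not the actual train_network/process_batch code.
import torch
import torch.nn.functional as F

def process_batch(batch, model, noise_scheduler, is_train=True, fixed_timesteps=None):
    """Run one batch and return the loss, shared by training and validation.

    Passing fixed_timesteps (validation) replays the same timesteps every
    run, removing one source of variance from the validation loss curve.
    """
    latents = batch["latents"]
    noise = torch.randn_like(latents)
    if fixed_timesteps is None:
        # Training path: sample timesteps uniformly at random, as usual.
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps,
            (latents.shape[0],), device=latents.device,
        )
    else:
        timesteps = fixed_timesteps

    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    with torch.set_grad_enabled(is_train):
        pred = model(noisy_latents, timesteps, batch["conditioning"])
        loss = F.mse_loss(pred, noise)
    return loss
```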
That's fair. I have a lot I want to add to this to get it where I would like. Right now, I think some of the simpler items like the bug fixes, progress bar formatting, and calculating validation at start would be good to get in sooner. As far as training goes, I had already created a form of abstraction for the Flux finetuning loss calculation, which you can see in #1898. I was going to migrate that over to this PR. For the stable timesteps, I plan to implement a …
Yeah, I think what you have right now in this PR could be good on its own, and these are good improvements. The loss calculation could probably stand on its own, and Kohya appreciated it. Timesteps would be good as another PR. Then we can get them merged faster, and I can give feedback and help test.
I agree with that approach and with the breakdown you suggest. I'd like to add one more change to this PR to have a fixed number of validation samples. After that we can probably pause there and work on some of the other items in additional PRs.
Looking through the code, that actually seems like a non-trivial workload to implement. It should probably be its own thing as well. I think it's important to be able to set a fixed validation sample count, because there is a non-trivial performance hit from running many validation steps when the dataset is large. For now, you can just calculate the right split and set it accordingly. I think the order I'd like to tackle the other items is:
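On the fixed validation sample count mentioned above: computing a split ratio from a fixed cap is straightforward. A minimal sketch, with a hypothetical function name:

```python
def split_ratio_for_fixed_count(dataset_size: int, max_validation_samples: int) -> float:
    """Return a validation split ratio that yields at most
    max_validation_samples items, regardless of dataset size."""
    if dataset_size <= 0:
        return 0.0
    return min(1.0, max_validation_samples / dataset_size)

# e.g. 10,000 images capped at 64 validation samples -> ratio 0.0064,
# so a large dataset no longer pays a large per-validation cost
ratio = split_ratio_for_fixed_count(10_000, 64)
```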
Thank you for the various suggestions! I believe in being careful about abstracting the training process. Over-abstraction (as in some other trainers) can lead to poor readability. I prefer to keep the main steps in the main scripts themselves. Regarding the stable timesteps, I have the following thoughts based on existing discussions:
Note that it doesn't cost much to compute the targets during training; caching them is more resource intensive. The fixed number of validation samples is an interesting feature, but adding new functionality is not desirable if it can be replaced by a small validation split. I apologize for our slow response time on pull requests. For better collaboration, it would also be desirable for you to start with a smaller pull request, as you wrote, so we can discuss the implementation approach first. This would help avoid situations where PRs remain open for a long time or require major changes due to differences in architectural decisions.
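One way to act on that trade-off: rather than caching noise and target tensors, the same values can be regenerated deterministically from a per-sample seed, which costs almost nothing at run time. A sketch under that assumption (the helper and its defaults are illustrative, not existing code):

```python
import torch

def stable_noise_and_timesteps(sample_index, latent_shape, num_train_timesteps,
                               n_timesteps=4, base_seed=42, device="cpu"):
    """Recreate identical noise and timesteps for a validation sample on
    every run by seeding a generator, instead of caching the tensors."""
    gen = torch.Generator(device=device).manual_seed(base_seed + sample_index)
    noise = torch.randn(latent_shape, generator=gen, device=device)
    timesteps = torch.randint(0, num_train_timesteps, (n_timesteps,),
                              generator=gen, device=device)
    return noise, timesteps
```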
…ion loss calculation
Okay, take a look at what I implemented here. It's in the same spirit as what I implemented originally, and it accomplishes the goal of snapshotting the noise state and timesteps for each sample. The work to sample random timesteps is trivial and works very well for creating a stable loss curve. If you take 4 timesteps per sample, you get a reasonable scattering. Keeping the state tracker in memory doesn't seem to add any significant load. With this, the raw loss curve is already much more consistent, with no smoothing or moving averages applied:

[loss curve screenshot]

I do have a lot of cleanup to do, but this is the basic proof of concept.
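In that spirit, here is a minimal sketch of the snapshotting idea as a proof of concept (not the exact code from this PR): a small in-memory tracker that fixes noise and a handful of timesteps per validation sample on first sight and replays them on every later pass.

```python
import torch

class ValidationStateTracker:
    """Snapshots (noise, timesteps) per validation sample so every
    validation pass scores the exact same noising, yielding a stable
    raw loss curve without smoothing or moving averages."""

    def __init__(self, num_train_timesteps, timesteps_per_sample=4):
        self.num_train_timesteps = num_train_timesteps
        self.timesteps_per_sample = timesteps_per_sample
        self._state = {}  # sample key -> (noise, timesteps)

    def get(self, key, latents):
        if key not in self._state:
            # First time we see this sample: snapshot its noise/timesteps.
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, self.num_train_timesteps, (self.timesteps_per_sample,),
                device=latents.device,
            )
            self._state[key] = (noise, timesteps)
        return self._state[key]
```

With four timesteps per sample you cover a reasonable spread of the noise schedule while the tracker holds only one noise tensor per validation sample.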
Thanks for the update, but I'm planning a comprehensive update to validation loss in the next few days, so you may want to wait until then before updating further. Please note that major changes may be required to this PR.
Sure! That sounds good. What I put together here is pretty straightforward, so it shouldn't be too hard to reimplement whenever. I'll look out for the update.
I'm starting to implement it here; any suggestions are welcome: #1903.
… validation loss calculation" This reverts commit 3d5c644.
I reverted the bulk change that included snapshotting. The remaining commits have some good cleanup items and utilities that shouldn't conflict substantially with the changes you're making in the other PR.
I am creating this PR to implement a number of additional enhancements to the base validation loss implementation in #1899. These include:
Low Effort

- `--validation_at_start` argument (see the sketch at the end of this description)

Bug Fixes
Additional items for future PRs:
High Effort
Medium Effort
Low Effort
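As a reference for the `--validation_at_start` item above, a hedged sketch of how such a flag could be wired up (the flag name is from this PR; the loop and helpers are hypothetical placeholders, not existing sd-scripts code):

```python
import argparse

def run_validation(step: int) -> None:
    # Placeholder for the real validation pass.
    print(f"validation at step {step}")

def train_step(step: int) -> None:
    # Placeholder for the real training step.
    pass

parser = argparse.ArgumentParser()
parser.add_argument(
    "--validation_at_start", action="store_true",
    help="run a validation pass before the first training step so the "
         "loss curve starts with an untrained baseline point",
)
args = parser.parse_args()

max_steps = 10  # illustrative
if args.validation_at_start:
    run_validation(step=0)
for step in range(1, max_steps + 1):
    train_step(step)
```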