Maintain fp32 for optimizer state when offloading is enabled #1223
Address issues with optimizer state offloading and data type conversion.
We identified two issues with the fp32 to fp16 conversion of the optimizer state when optimizer state offloading is enabled:

1. The comparison between configurations without and with optimizer state offloading was unfair because the data sizes differed: the former used fp32 while the latter used fp16.
2. Maintaining two modules with `jit_train_step` (separate fp32 and fp16 versions) created inconsistencies.
This commit removes the fp32 to fp16 conversion, ensuring that the optimizer state retains its original data type.
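For illustration, a minimal sketch (using optax; names and shapes are illustrative, not the actual change) of what keeping the original dtype looks like, and of the kind of blanket downcast this commit removes:

```python
import jax
import jax.numpy as jnp
import optax

# fp32 parameters; the optimizer state inherits their dtype on init.
params = {"w": jnp.zeros((8, 8), dtype=jnp.float32)}
opt = optax.adam(1e-3)
state = opt.init(params)  # Adam moments are fp32, matching params

# Conceptually, what this change removes is a blanket downcast such as:
#   state = jax.tree_util.tree_map(lambda x: x.astype(jnp.float16), state)
# which made memory comparisons against the non-offloaded fp32 baseline unfair.
```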
Previously, we observed no memory savings from offloading when switching the optimizer state from fp16 back to fp32. The root cause is that the GPU memory scheduler did not distinguish between CPU memory and GPU memory. This XLA PR modifies the scheduler to exclude CPU memory; now that it has been merged, we can re-enable the CL (#1184).
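For context, a minimal JAX sketch of placing fp32 optimizer state in host memory via memory kinds, which is the kind of offloading the scheduler fix makes accurately accounted for (illustrative only; the actual offloading path in this repository may differ, and `pinned_host` requires a recent jaxlib on a TPU/GPU backend):

```python
import jax
import jax.numpy as jnp

dev = jax.devices()[0]
device_sharding = jax.sharding.SingleDeviceSharding(dev)
host_sharding = device_sharding.with_memory_kind("pinned_host")

# fp32 optimizer state kept in host memory instead of device HBM;
# with the scheduler fix, this no longer counts toward GPU memory.
state = jnp.zeros((4096, 4096), dtype=jnp.float32)
offloaded = jax.device_put(state, host_sharding)
print(offloaded.sharding.memory_kind)  # "pinned_host"
```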