# Reduce peak memory during FLUX model load (#7564)
## Summary

Prior to this change, there were several cases where we initialized the
weights of a FLUX model before loading its state dict (and, to make
things worse, in some cases the weights were in float32). This PR fixes
a handful of these cases. (I think I found all instances for the FLUX
family of models.)
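
For context, the pattern applied throughout the diff below is: construct the module under `accelerate.init_empty_weights()`, so its parameters are created on the meta device and no memory is allocated, then materialize the real weights straight from the checkpoint with `load_state_dict(..., assign=True)`. A minimal sketch of that pattern, with a hypothetical stand-in module and checkpoint path:

```python
import accelerate
import torch
from safetensors.torch import load_file


class TinyModel(torch.nn.Module):  # stand-in for AutoEncoder / Flux / T5EncoderModel
    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(4096, 4096)


# Parameters are created on the meta device, so no host memory is reserved here.
with accelerate.init_empty_weights():
    model = TinyModel()

# Load the real weights and assign them in place of the meta tensors.
# assign=True is required: the default copy-based load cannot write into meta tensors.
sd = load_file("tiny_model.safetensors")  # hypothetical checkpoint path
model.load_state_dict(sd, assign=True)
```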

## Related Issues / Discussions

- Helps with #7563

## QA Instructions

I tested that model loading still works and that there is no virtual memory reservation on model initialization for the following models (a quick self-check sketch follows the list):
- [x] FLUX VAE
- [x] Full T5 Encoder
- [x] Full FLUX checkpoint
- [x] GGUF FLUX checkpoint
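
One way to spot-check the "no memory reserved at initialization" behavior (not necessarily the procedure used for the QA above) is to confirm that a module constructed under `accelerate.init_empty_weights()` has only meta-device parameters:

```python
import accelerate
import torch

# Stand-in module; a real check would construct the FLUX component in question.
with accelerate.init_empty_weights():
    model = torch.nn.Sequential(torch.nn.Linear(8192, 8192), torch.nn.GELU())

# Meta tensors carry only shape/dtype metadata; no host memory backs them.
assert all(p.device.type == "meta" for p in model.parameters())
```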

## Merge Plan

No special instructions.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
RyanJDick authored Jan 16, 2025
2 parents b57aa06 + b2bb359 · commit 0abb5ea

Showing 1 changed file with 29 additions and 26 deletions: `invokeai/backend/model_manager/load/model_loaders/flux.py`
```diff
@@ -80,19 +80,19 @@ def _load_model(
             raise ValueError("Only VAECheckpointConfig models are currently supported here.")
         model_path = Path(config.path)
 
-        with SilenceWarnings():
+        with accelerate.init_empty_weights():
             model = AutoEncoder(ae_params[config.config_path])
-            sd = load_file(model_path)
-            model.load_state_dict(sd, assign=True)
-            # VAE is broken in float16, which mps defaults to
-            if self._torch_dtype == torch.float16:
-                try:
-                    vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
-                except TypeError:
-                    vae_dtype = torch.float32
-            else:
-                vae_dtype = self._torch_dtype
-            model.to(vae_dtype)
+        sd = load_file(model_path)
+        model.load_state_dict(sd, assign=True)
+        # VAE is broken in float16, which mps defaults to
+        if self._torch_dtype == torch.float16:
+            try:
+                vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
+            except TypeError:
+                vae_dtype = torch.float32
+        else:
+            vae_dtype = self._torch_dtype
+        model.to(vae_dtype)
 
         return model
 
```
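An aside on the dtype logic that stays the same in this hunk: building a one-element bfloat16 tensor on the target device acts as a capability probe, because float16 is known to break this VAE and some backends (historically, MPS) raise a `TypeError` for bfloat16. A sketch of the same probe pulled into a hypothetical standalone helper:

```python
import torch


def pick_vae_dtype(requested: torch.dtype, device: torch.device) -> torch.dtype:
    """Hypothetical helper mirroring the dtype selection in the hunk above."""
    if requested != torch.float16:
        return requested
    try:
        # Probe bfloat16 support on this device; fall back to float32 if unsupported.
        return torch.tensor([1.0], dtype=torch.bfloat16, device=device).dtype
    except TypeError:
        return torch.float32
```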
```diff
@@ -183,7 +183,9 @@ def _load_model(
             case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
                 return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
             case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
-                return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2", torch_dtype="auto")
+                return T5EncoderModel.from_pretrained(
+                    Path(config.path) / "text_encoder_2", torch_dtype="auto", low_cpu_mem_usage=True
+                )
 
         raise ValueError(
             f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
```
```diff
@@ -217,17 +219,18 @@ def _load_from_singlefile(
         assert isinstance(config, MainCheckpointConfig)
         model_path = Path(config.path)
 
-        with SilenceWarnings():
+        with accelerate.init_empty_weights():
             model = Flux(params[config.config_path])
-            sd = load_file(model_path)
-            if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
-                sd = convert_bundle_to_flux_transformer_checkpoint(sd)
-            new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
-            self._ram_cache.make_room(new_sd_size)
-            for k in sd.keys():
-                # We need to cast to bfloat16 due to it being the only currently supported dtype for inference
-                sd[k] = sd[k].to(torch.bfloat16)
-            model.load_state_dict(sd, assign=True)
+
+        sd = load_file(model_path)
+        if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
+            sd = convert_bundle_to_flux_transformer_checkpoint(sd)
+        new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
+        self._ram_cache.make_room(new_sd_size)
+        for k in sd.keys():
+            # We need to cast to bfloat16 due to it being the only currently supported dtype for inference
+            sd[k] = sd[k].to(torch.bfloat16)
+        model.load_state_dict(sd, assign=True)
         return model
 
 
```
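Worth noting in the hunk above: the size handed to `self._ram_cache.make_room(...)` is what the state dict will occupy after the bfloat16 cast (element count times `torch.bfloat16.itemsize`), not its current in-memory or on-disk size, so the cache can free enough room before the casted tensors are materialized. A self-contained illustration of that arithmetic with hypothetical tensors:

```python
import torch

# Hypothetical float32 state dict standing in for a FLUX checkpoint.
sd = {"a": torch.zeros(1024, 1024), "b": torch.zeros(512)}

# Footprint after casting every tensor to bfloat16 (2 bytes per element).
new_sd_size = sum(t.nelement() * torch.bfloat16.itemsize for t in sd.values())
print(new_sd_size)  # 1024 * 1024 * 2 + 512 * 2 = 2_098_176 bytes
```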
```diff
@@ -258,11 +261,11 @@ def _load_from_singlefile(
         assert isinstance(config, MainGGUFCheckpointConfig)
         model_path = Path(config.path)
 
-        with SilenceWarnings():
+        with accelerate.init_empty_weights():
             model = Flux(params[config.config_path])
 
-            # HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
-            sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
+        # HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
+        sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
 
         # HACK(ryand): There are some broken GGUF models in circulation that have the wrong shape for img_in.weight.
         # We override the shape here to fix the issue.
```
