Skip to content

Commit

Permalink
Normalize the T5 model identifiers so that a FLUX T5 or an SD3 T5 mod…
Browse files Browse the repository at this point in the history
…el can be used interchangeably.
  • Loading branch information
RyanJDick authored and psychedelicious committed Jan 15, 2025
1 parent edcdff4 commit b301785
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 16 deletions.
8 changes: 6 additions & 2 deletions invokeai/app/invocations/flux_model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.t5_model_identifier import (
preprocess_t5_encoder_model_identifier,
preprocess_t5_tokenizer_model_identifier,
)
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
Expand Down Expand Up @@ -74,8 +78,8 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})

tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
tokenizer2 = preprocess_t5_tokenizer_model_identifier(self.t5_encoder_model)
t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model)

transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
Expand Down
4 changes: 2 additions & 2 deletions invokeai/app/invocations/flux_text_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Iterator, Literal, Optional, Tuple

import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
Expand Down Expand Up @@ -76,7 +76,7 @@ def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, T5Tokenizer)
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))

t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)

Expand Down
16 changes: 6 additions & 10 deletions invokeai/app/invocations/sd3_model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.t5_model_identifier import (
preprocess_t5_encoder_model_identifier,
preprocess_t5_tokenizer_model_identifier,
)
from invokeai.backend.model_manager.config import SubModelType


Expand Down Expand Up @@ -88,16 +92,8 @@ def invoke(self, context: InvocationContext) -> Sd3ModelLoaderOutput:
if self.clip_g_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
)
tokenizer_t5 = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
)
t5_encoder = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
)
tokenizer_t5 = preprocess_t5_tokenizer_model_identifier(self.t5_encoder_model or self.model)
t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model or self.model)

return Sd3ModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
Expand Down
26 changes: 26 additions & 0 deletions invokeai/app/util/t5_model_identifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.backend.model_manager.config import BaseModelType, SubModelType


def preprocess_t5_encoder_model_identifier(model_identifier: ModelIdentifierField) -> ModelIdentifierField:
"""A helper function to normalize a T5 encoder model identifier so that T5 models associated with FLUX
or SD3 models can be used interchangeably.
"""
if model_identifier.base == BaseModelType.Any:
return model_identifier.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
elif model_identifier.base == BaseModelType.StableDiffusion3:
return model_identifier.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
else:
raise ValueError(f"Unsupported model base: {model_identifier.base}")


def preprocess_t5_tokenizer_model_identifier(model_identifier: ModelIdentifierField) -> ModelIdentifierField:
"""A helper function to normalize a T5 tokenizer model identifier so that T5 models associated with FLUX
or SD3 models can be used interchangeably.
"""
if model_identifier.base == BaseModelType.Any:
return model_identifier.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
elif model_identifier.base == BaseModelType.StableDiffusion3:
return model_identifier.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
else:
raise ValueError(f"Unsupported model base: {model_identifier.base}")
10 changes: 8 additions & 2 deletions invokeai/backend/flux/modules/conditioner.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
# Initially pulled from https://github.com/black-forest-labs/flux

from torch import Tensor, nn
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast

from invokeai.backend.util.devices import TorchDevice


class HFEncoder(nn.Module):
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
def __init__(
self,
encoder: PreTrainedModel,
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
is_clip: bool,
max_length: int,
):
super().__init__()
self.max_length = max_length
self.is_clip = is_clip
Expand Down

0 comments on commit b301785

Please sign in to comment.