From b301785dc8244679873f633cae981dc0c91b15d9 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Thu, 9 Jan 2025 22:56:02 +0000
Subject: [PATCH] Normalize the T5 model identifiers so that a FLUX T5 or an SD3 T5 model can be used interchangeably.

---
 invokeai/app/invocations/flux_model_loader.py |  8 ++++--
 invokeai/app/invocations/flux_text_encoder.py |  4 +--
 invokeai/app/invocations/sd3_model_loader.py  | 16 +++++-------
 invokeai/app/util/t5_model_identifier.py      | 26 +++++++++++++++++++
 invokeai/backend/flux/modules/conditioner.py  | 10 +++++--
 5 files changed, 48 insertions(+), 16 deletions(-)
 create mode 100644 invokeai/app/util/t5_model_identifier.py

diff --git a/invokeai/app/invocations/flux_model_loader.py b/invokeai/app/invocations/flux_model_loader.py
index ab2d69aa02b..884b01a9805 100644
--- a/invokeai/app/invocations/flux_model_loader.py
+++ b/invokeai/app/invocations/flux_model_loader.py
@@ -10,6 +10,10 @@
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
 from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.app.util.t5_model_identifier import (
+    preprocess_t5_encoder_model_identifier,
+    preprocess_t5_tokenizer_model_identifier,
+)
 from invokeai.backend.flux.util import max_seq_lengths
 from invokeai.backend.model_manager.config import (
     CheckpointConfigBase,
@@ -74,8 +78,8 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
         tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
         clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})

-        tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
-        t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
+        tokenizer2 = preprocess_t5_tokenizer_model_identifier(self.t5_encoder_model)
+        t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model)

         transformer_config = context.models.get_config(transformer)
         assert isinstance(transformer_config, CheckpointConfigBase)
diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py
index 3c49b6287b1..74c293d0c09 100644
--- a/invokeai/app/invocations/flux_text_encoder.py
+++ b/invokeai/app/invocations/flux_text_encoder.py
@@ -2,7 +2,7 @@
 from typing import Iterator, Literal, Optional, Tuple

 import torch
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast

 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
 from invokeai.app.invocations.fields import (
@@ -76,7 +76,7 @@ def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
             context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
         ):
             assert isinstance(t5_text_encoder, T5EncoderModel)
-            assert isinstance(t5_tokenizer, T5Tokenizer)
+            assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))

             t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)

diff --git a/invokeai/app/invocations/sd3_model_loader.py b/invokeai/app/invocations/sd3_model_loader.py
index 6b2d03ef3d9..b7e23b5750f 100644
--- a/invokeai/app/invocations/sd3_model_loader.py
+++ b/invokeai/app/invocations/sd3_model_loader.py
@@ -10,6 +10,10 @@
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
 from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.app.util.t5_model_identifier import (
+    preprocess_t5_encoder_model_identifier,
+    preprocess_t5_tokenizer_model_identifier,
+)
 from invokeai.backend.model_manager.config import SubModelType


@@ -88,16 +92,8 @@ def invoke(self, context: InvocationContext) -> Sd3ModelLoaderOutput:
             if self.clip_g_model
             else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
         )
-        tokenizer_t5 = (
-            self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
-            if self.t5_encoder_model
-            else self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
-        )
-        t5_encoder = (
-            self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
-            if self.t5_encoder_model
-            else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
-        )
+        tokenizer_t5 = preprocess_t5_tokenizer_model_identifier(self.t5_encoder_model or self.model)
+        t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model or self.model)

         return Sd3ModelLoaderOutput(
             transformer=TransformerField(transformer=transformer, loras=[]),
diff --git a/invokeai/app/util/t5_model_identifier.py b/invokeai/app/util/t5_model_identifier.py
new file mode 100644
index 00000000000..eb3a84aee52
--- /dev/null
+++ b/invokeai/app/util/t5_model_identifier.py
@@ -0,0 +1,26 @@
+from invokeai.app.invocations.model import ModelIdentifierField
+from invokeai.backend.model_manager.config import BaseModelType, SubModelType
+
+
+def preprocess_t5_encoder_model_identifier(model_identifier: ModelIdentifierField) -> ModelIdentifierField:
+    """A helper function to normalize a T5 encoder model identifier so that T5 models associated with FLUX
+    or SD3 models can be used interchangeably.
+    """
+    if model_identifier.base == BaseModelType.Any:
+        return model_identifier.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
+    elif model_identifier.base == BaseModelType.StableDiffusion3:
+        return model_identifier.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
+    else:
+        raise ValueError(f"Unsupported model base: {model_identifier.base}")
+
+
+def preprocess_t5_tokenizer_model_identifier(model_identifier: ModelIdentifierField) -> ModelIdentifierField:
+    """A helper function to normalize a T5 tokenizer model identifier so that T5 models associated with FLUX
+    or SD3 models can be used interchangeably.
+    """
+    if model_identifier.base == BaseModelType.Any:
+        return model_identifier.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
+    elif model_identifier.base == BaseModelType.StableDiffusion3:
+        return model_identifier.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
+    else:
+        raise ValueError(f"Unsupported model base: {model_identifier.base}")
diff --git a/invokeai/backend/flux/modules/conditioner.py b/invokeai/backend/flux/modules/conditioner.py
index c03e877e2db..ffbbbf20dd7 100644
--- a/invokeai/backend/flux/modules/conditioner.py
+++ b/invokeai/backend/flux/modules/conditioner.py
@@ -1,13 +1,19 @@
 # Initially pulled from https://github.com/black-forest-labs/flux

 from torch import Tensor, nn
-from transformers import PreTrainedModel, PreTrainedTokenizer
+from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast

 from invokeai.backend.util.devices import TorchDevice


 class HFEncoder(nn.Module):
-    def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
+    def __init__(
+        self,
+        encoder: PreTrainedModel,
+        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
+        is_clip: bool,
+        max_length: int,
+    ):
         super().__init__()
         self.max_length = max_length
         self.is_clip = is_clip
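
The core of the change is the pair of normalization helpers: a standalone T5 encoder (the kind installed for use with FLUX) is registered under BaseModelType.Any and maps to the *2 submodel slots, while a T5 bundled inside an SD3 checkpoint maps to the *3 slots. Below is a minimal usage sketch, illustration only and not part of the applied diff; it assumes ModelIdentifierField is a Pydantic model whose required fields are key, hash, name, base, and type, and that ModelType.T5Encoder is the registry type for standalone T5 models. The placeholder values are hypothetical.

    from invokeai.app.invocations.model import ModelIdentifierField
    from invokeai.app.util.t5_model_identifier import (
        preprocess_t5_encoder_model_identifier,
        preprocess_t5_tokenizer_model_identifier,
    )
    from invokeai.backend.model_manager.config import BaseModelType, ModelType, SubModelType

    # Hypothetical identifier for a standalone T5 encoder (the kind the FLUX
    # model loader selects). Standalone T5 models carry BaseModelType.Any;
    # the key/hash/name values here are placeholders.
    flux_t5 = ModelIdentifierField(
        key="placeholder-key",
        hash="placeholder-hash",
        name="t5_base_encoder",
        base=BaseModelType.Any,
        type=ModelType.T5Encoder,  # assumed registry type for standalone T5 models
    )

    # A FLUX-style identifier is normalized to the *2 submodel slots...
    assert preprocess_t5_encoder_model_identifier(flux_t5).submodel_type == SubModelType.TextEncoder2
    assert preprocess_t5_tokenizer_model_identifier(flux_t5).submodel_type == SubModelType.Tokenizer2

    # ...while a T5 drawn from an SD3 checkpoint is normalized to the *3 slots,
    # so either source flows through the same downstream loading path.
    sd3_t5 = flux_t5.model_copy(update={"base": BaseModelType.StableDiffusion3})
    assert preprocess_t5_encoder_model_identifier(sd3_t5).submodel_type == SubModelType.TextEncoder3
    assert preprocess_t5_tokenizer_model_identifier(sd3_t5).submodel_type == SubModelType.Tokenizer3

Any other base raises ValueError, surfacing a misconfigured model selection at graph-build time rather than failing later inside submodel loading.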
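The transformers-side edits widen the accepted tokenizer types because, as the relaxed assertions suggest, loading a T5 tokenizer can yield either the slow T5Tokenizer or the fast T5TokenizerFast depending on the model source. Both derive from PreTrainedTokenizerBase and expose the same __call__ interface, which is all HFEncoder relies on, so the wider type is behavior-preserving. A brief sketch of that shared surface follows; the keyword arguments mirror typical tokenizer usage rather than the exact HFEncoder.forward body:

    from transformers import BatchEncoding, PreTrainedTokenizer, PreTrainedTokenizerFast

    def tokenize_prompt(
        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
        prompt: str,
        max_length: int,
    ) -> BatchEncoding:
        # Slow and fast tokenizers accept the same call signature, so callers
        # never need to know which implementation was loaded.
        return tokenizer(
            prompt,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt",
        )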