-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Normalize the T5 model identifiers so that a FLUX T5 or an SD3 T5 mod…
…el can be used interchangeably.
- Loading branch information
1 parent
edcdff4
commit b301785
Showing
5 changed files
with
48 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from invokeai.app.invocations.model import ModelIdentifierField | ||
from invokeai.backend.model_manager.config import BaseModelType, SubModelType | ||
|
||
|
||
def preprocess_t5_encoder_model_identifier(model_identifier: ModelIdentifierField) -> ModelIdentifierField: | ||
"""A helper function to normalize a T5 encoder model identifier so that T5 models associated with FLUX | ||
or SD3 models can be used interchangeably. | ||
""" | ||
if model_identifier.base == BaseModelType.Any: | ||
return model_identifier.model_copy(update={"submodel_type": SubModelType.TextEncoder2}) | ||
elif model_identifier.base == BaseModelType.StableDiffusion3: | ||
return model_identifier.model_copy(update={"submodel_type": SubModelType.TextEncoder3}) | ||
else: | ||
raise ValueError(f"Unsupported model base: {model_identifier.base}") | ||
|
||
|
||
def preprocess_t5_tokenizer_model_identifier(model_identifier: ModelIdentifierField) -> ModelIdentifierField: | ||
"""A helper function to normalize a T5 tokenizer model identifier so that T5 models associated with FLUX | ||
or SD3 models can be used interchangeably. | ||
""" | ||
if model_identifier.base == BaseModelType.Any: | ||
return model_identifier.model_copy(update={"submodel_type": SubModelType.Tokenizer2}) | ||
elif model_identifier.base == BaseModelType.StableDiffusion3: | ||
return model_identifier.model_copy(update={"submodel_type": SubModelType.Tokenizer3}) | ||
else: | ||
raise ValueError(f"Unsupported model base: {model_identifier.base}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters