Add trt support for BF16 (#195)
* fix interface of `get_sample_input`

* save configuration parameters

* ae wrapper implemented

* fix import

* add AEWrapper step

* from set_model_to_dtype to prepare_model

* fix eval mode during inference

* fix clip onnx export. Now it traces only the needed outputs

* fix t5 wrapper

* reorder input name flux

* fix flux input format for text_ids and guidance

* fix Flux imports and scale of inputs to prevent nan
added `"latent": {0: "B", 1: "latent_dim"}` as additional dynamic axes

* add torch inference while tracing

* fix casting problem in onnx trace

* solve optimization problem by removing cleanup steps

* rename to notes

* prevent nan due to large inputs

* provide base implementation of `get_model`

* format

* add trt export step

* add engine class for trt build

* add `get_input_profile` and `get_minmax_dims` abstract methods

* add `build_strongly_typed` attribute

* implement `get_minmax_dims` and `get_input_profile`

* remove `static_shape` from `get_sample_input`

* remove static shape and batch flags

* add typing

* remove static shape and batch flags

* offload to cpu

* enable device offloading while tracing

* check cuda is available while building engines

* clip trt engine build

* add pinned transformer dependency

* fix nan with onnx and trt when executed on CUDA

* AE needs to be traced in TF32, not FP16

* add `get_shape_dict` abstract method and  device as a property

* AE should be traced in TF32

* AE explicitly on TF32 and reactivate full pipeline

* add input profile to flux to enable trt engine build

* format and add input_profile to t5 for TRT build

* add `TransformersModelWrapper`

* add TransformersModelWrapper support

* add `get_shape_dict` interface

* add TransformersModelWrapper support

* add shape_dict interface

* t5 in TF32 for numerical reasons

* remove unused options

* remove unused code

* add `get_shape_dict`

* remove custom optimization

* add garbage collector

* return error

* create wrapper specific to Onnx export operation

* use OnnxWrapper

* create base wrapper for trt engines

* moved to engine package

* moved to engine package

* forbid relative import of trt-builder

* remove wrapper and create BaseExporter or BaseEngine

* models not stored in builder class

* _prepare_model_configs as pure function

* _get_onnx_exporters as a private method to get onnx exporters

* remove unused dependencies

* from onnxwrapper to onnxengine

* trt engine class

* add `calculate_max_device_memory` to TRTBuilder

* `get_shape_dict` moved to trt-engine interface

* add common inference code

* autoencoder inference wrapper

* add requirements.txt

* support guidance for dev model

* add support for trt based on env variables

* format flux

* remove stream from constructor

* fix iterate over onnx-exporters

* flux is not strongly typed

* move back for numerical stability

* add logging

* fix dtype casting for bfloat16

* fix default value

* add version before merge

* hacky get it building the engines

* requirements.txt

* adding a separate _engine.py file for all the flux, t5 and clip engines

* boilerplate and plumbing: getting parameter handling into setting up the trt engines

* remove _version.py from git

* create base mixin class to share parameters

* clipmixin parameters

* remove parameters as they are part of the mixin class

* clip engine and exporter use common mixin for managing parameters

* use mixin class to build engine from exporter

* ae-mixin for shared parameters

* flux exporter and engine unified by mixin class

* formatting

* add common `get_latent_dims` method

* add `get_latent_dims` common method

* T5 based on mixin class

* build strongly typed flux

* enable load with shared device memory

* remove boilerplate code to create engines

* add tokenizer to trt engine

* use static shape to reduce memory consumption

* implement tokenizer in t5 engine

* set max_batch size to 8

* add licence

* add licence

* enable trt runtime tracking

* add static-batch and static-shape options

* add cuda stream to load method

* add inference code

* add inference code

* enable static shape

* add `static_shape` option to reduce memory and `_build_engine` as staticmethod

* add `should_be_dtype` field to handle output type conversion

* from trtbuilder to trt_manager

* from TRTBuilder to TRTManager

* AE engine interface

* `trt_to_torch_dtype_dict` as property

* clip engine inference

* implement flux trt engine inference process

* add scale_factor and shift_factor

* removed `should_be_dtype`

* removed `should_be_dtype`

* remove `should_be_dtype` from t5

* add scale and shift factor

* `max_batch` to 8

* implement `TRTManager`

* from ae to vae to match DD

* remove autocast

* `pooled_embeddings` to match DD naming for clip

* rename `flux` to `transformer` engine

* from flux to transformer mixin

* from flux to transformer exporter

* fix trtmanager naming

* fix input names and dimensions. Note that `img_ids` and `txt_ids` are without batch dim

* fix shape of inputs according to `text_maxlen` and batch_size

* reduce max_batch

* fix stage naming

* add support for DD model

* add support for DD models

* fix dtype configuration

* fix engine dtype

* transformers inference interface to match DD

* vae inference script dtype mapping

* remove dtype checks as multiple can be active

* by default tf32 always active

* fix trt engine names

* add wrapper for fluxmodel to match DD onnx configuration

* add autocast back in to match DD setup

* fix dependencies for trt support

* support trt

* add explicit kwargs

* vscode setup

* add setup instructions for trt

* `trt` dependencies not part of `all`

* from onnx_exporter to exporter

* hide onnx parameters

* from onnx-exporter to exporter

* exporter responsible for building the trt engine and onnx export

* hide onnx parameter

* remove build function from engine class

* remove unused import

* remove space

* manage t5 and vae separately

* disable autocast

* strongly typed t5

* fix input type and max image size

* max image size

* T5 not strongly typed

* testing

* fix torch synchronize problem

* don't build already present engines

* remove torch save

* removed onnx dependencies

* add trt dependencies

* remove trt dependencies from toml

* rename requirements and fix readme

* remove unused files

* fix import format

* remove comments

* add gitignore

* reset dependencies

* add hidden setup files

* solve ruff check

* fix imports with ruff

* run ruff formatter

* update gitignore

* simplify dependencies

* remove gitignore

* add cli formatting

* fix import orders

* simplify dependencies

* solve vae quality issue

* fix ruff format

* fix merge changes

* format and sort src/flux/cli

* fix merge conflicts

* add trt import

* add static shape support (not completed)

* remove fp8 support

* add static shape

* add static shape to t5

* add static shape to transformer

* remove model opt code

* enable offloading with trt engines

* add `stream` as part of `init_runtime`

* enable offloading

* `allocate_buffers` moved to call

* formatting

* add capability to compute `img_dim`

* enable dynamic or static-shape

* split base-engine and engine class

* clip as engine

* t5 as engine

* transformer as engine

* VAEDecoder as engine and VAEEngine as BaseEngine

* from vae to vae_decoder, vae_encoder and vae

* use `set_stream` and fix activate call

* fix import and remove stages in TRTManager

* from BaseEngine to BaseEngine and Engine

* fix imports

* add trt support to cli_controlnet

* add vae_encoder to support controlnet

* refactor vae engine to use load() and activate() functions

* implement vae_encoder_exporter. Not tested

* fix imports

* add static_batch and static_shape to cli.py as additional options

* update dependencies

* revert formatting

* from Self to Any to be compatible with python 3.10

* from `vae_decoder` to `vae` for compatibility with oss engines

* missing torch import

* add `scale_factor` and `shift_factor` to VAE-encoder

* add check if vae is traced

* offload while tracing

* default `text_maxlen` set to dev size instead of schnell

* remove line

* add warning when text_maxlen is not read from t5

* fix imports

---------

Co-authored-by: scavallari <[email protected]>
Co-authored-by: ahohl <[email protected]>
3 people authored Jan 31, 2025
1 parent d06f828 commit 4343061
Showing 27 changed files with 2,725 additions and 7 deletions.
66 changes: 65 additions & 1 deletion README.md
@@ -17,9 +17,22 @@ source .venv/bin/activate
pip install -e ".[all]"
```

## Local installation with TRT support

```bash
docker pull nvcr.io/nvidia/pytorch:24.10-py3
cd $HOME && git clone https://github.com/black-forest-labs/flux
cd $HOME/flux
docker run --rm -it --gpus all -v $PWD:/workspace/flux nvcr.io/nvidia/pytorch:24.10-py3 /bin/bash
# inside container
cd /workspace/flux
pip install -e ".[all]"
pip install -r trt_requirements.txt
```
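
The TRT path is also wired into the CLI (`src/flux/cli.py`, shown in the diff further down), which gains a `trt` flag plus `static_batch`/`static_shape` keyword options and reads the `TRT_ENGINE_DIR` and `ONNX_DIR` environment variables. The sketch below only illustrates that flow and is not a documented API: the names (`TRTManager`, `load_engines`, `init_runtime`, `set_stream`, `load`, `activate`, `calculate_max_device_memory`) are taken from the diff, while the model name, image size, and CPU staging are assumptions.

```python
import os

from cuda import cudart

from flux.trt.trt_manager import TRTManager
from flux.util import load_ae, load_clip, load_flow_model, load_t5

# Sketch: mirrors the TRT branch added to src/flux/cli.py in this commit.
torch_device = "cuda"
t5 = load_t5("cpu")                       # keep models on CPU while engines are built
clip = load_clip("cpu")
model = load_flow_model("flux-dev", device="cpu")
ae = load_ae("flux-dev", device="cpu")

trt_ctx_manager = TRTManager(bf16=True, device=torch_device, static_batch=True, static_shape=True)
ae.decoder.params = ae.params             # the decoder needs the AE params, as in the diff
engines = trt_ctx_manager.load_engines(
    models={"clip": clip, "transformer": model, "t5": t5, "vae": ae.decoder},
    engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
    onnx_dir=os.environ.get("ONNX_DIR", "./onnx"),
    opt_image_height=1024,
    opt_image_width=1024,
)

trt_ctx_manager.init_runtime()
for engine in engines.values():
    engine.set_stream(stream=trt_ctx_manager.stream)
    engine.load()

# Engines share one device-memory pool sized for the largest activation footprint.
_, shared_device_memory = cudart.cudaMalloc(trt_ctx_manager.calculate_max_device_memory(engines))
for engine in engines.values():
    engine.activate(device=torch_device, device_memory=shared_device_memory)

# The engines now stand in for the PyTorch modules in the sampling loop.
ae, model, clip, t5 = engines["vae"], engines["transformer"], engines["clip"], engines["t5"]
```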

### Models

We are offering an extensive suite of models. For more information about the individual models, please refer to the link under **Usage**.
We are offering an extensive suite of models. For more information about the individual models, please refer to the link under **Usage**.

| Name | Usage | HuggingFace repo | License |
| --------------------------- | ---------------------------------------------------------- | -------------------------------------------------------------- | --------------------------------------------------------------------- |
@@ -42,6 +55,57 @@ We are offering an extensive suite of models. For more information about the ind

The weights of the autoencoder are also released under [apache-2.0](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md) and can be found in the HuggingFace repos above.

We also offer a Gradio-based demo for an interactive experience. To run the Gradio demo:

```bash
python demo_gr.py --name flux-schnell --device cuda
```

Options:

- `--name`: Choose the model to use (options: "flux-schnell", "flux-dev")
- `--device`: Specify the device to use (default: "cuda" if available, otherwise "cpu")
- `--offload`: Offload model to CPU when not in use
- `--share`: Create a public link to your demo

To run the demo with the dev model and create a public link:

```bash
python demo_gr.py --name flux-dev --share
```

## Diffusers integration

`FLUX.1 [schnell]` and `FLUX.1 [dev]` are integrated with the [🧨 diffusers](https://github.com/huggingface/diffusers) library. To use it with diffusers, install it:

```shell
pip install git+https://github.com/huggingface/diffusers.git
```

Then you can use `FluxPipeline` to run the model

```python
import torch
from diffusers import FluxPipeline

model_id = "black-forest-labs/FLUX.1-schnell"  # you can also use `black-forest-labs/FLUX.1-dev`

pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power

prompt = "A cat holding a sign that says hello world"
seed = 42
image = pipe(
prompt,
output_type="pil",
num_inference_steps=4, #use a larger number if you are using [dev]
generator=torch.Generator("cpu").manual_seed(seed)
).images[0]
image.save("flux-schnell.png")
```

To learn more, check out the [diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) documentation.

## API usage

Our API offers access to our models. It is documented here:
3 changes: 0 additions & 3 deletions demo_gr.py
@@ -15,7 +15,6 @@

NSFW_THRESHOLD = 0.85


def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
t5 = load_t5(device, max_length=256 if is_schnell else 512)
clip = load_clip(device)
@@ -24,7 +23,6 @@ def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool)
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
return model, ae, t5, clip, nsfw_classifier


class FluxGenerator:
def __init__(self, model_name: str, device: str, offload: bool):
self.device = torch.device(device)
@@ -153,7 +151,6 @@ def generate_image(
exif_data[ExifTags.Base.Model] = self.model_name
if add_sampling_metadata:
exif_data[ExifTags.Base.ImageDescription] = prompt

img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)

return img, str(opts.seed), filename, None
67 changes: 65 additions & 2 deletions src/flux/cli.py
@@ -5,10 +5,12 @@
from glob import iglob

import torch
from cuda import cudart
from fire import Fire
from transformers import pipeline

from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.trt.trt_manager import TRTManager
from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image

NSFW_THRESHOLD = 0.85
@@ -25,7 +27,9 @@ class SamplingOptions:


def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
user_question = (
"Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
)
usage = (
"Usage: Either write your prompt directly, leave this field empty "
"to repeat the prompt or write a command starting with a slash:\n"
@@ -108,6 +112,8 @@ def main(
offload: bool = False,
output_dir: str = "output",
add_sampling_metadata: bool = True,
trt: bool = False,
**kwargs: dict | None,
):
"""
Sample the flux model. Either interactively (set `--loop`) or run for a
@@ -126,6 +132,8 @@ def main(
loop: start an interactive session and sample multiple times
guidance: guidance value used for guidance distillation
add_sampling_metadata: Add the prompt to the image Exif metadata
trt: use TensorRT backend for optimized inference
kwargs: additional arguments for TensorRT support
"""
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

@@ -158,6 +166,57 @@
model = load_flow_model(name, device="cpu" if offload else torch_device)
ae = load_ae(name, device="cpu" if offload else torch_device)

if trt:
# offload to CPU to save memory
ae = ae.cpu()
model = model.cpu()
clip = clip.cpu()
t5 = t5.cpu()

torch.cuda.empty_cache()

trt_ctx_manager = TRTManager(
bf16=True,
device=torch_device,
static_batch=kwargs.get("static_batch", True),
static_shape=kwargs.get("static_shape", True),
)
ae.decoder.params = ae.params
engines = trt_ctx_manager.load_engines(
models={
"clip": clip,
"transformer": model,
"t5": t5,
"vae": ae.decoder,
},
engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
onnx_dir=os.environ.get("ONNX_DIR", "./onnx"),
opt_image_height=height,
opt_image_width=width,
)

torch.cuda.synchronize()

trt_ctx_manager.init_runtime()
# TODO: refactor. stream should be part of engine constructor maybe !!
for _, engine in engines.items():
engine.set_stream(stream=trt_ctx_manager.stream)

if not offload:
for _, engine in engines.items():
engine.load()

calculate_max_device_memory = trt_ctx_manager.calculate_max_device_memory(engines)
_, shared_device_memory = cudart.cudaMalloc(calculate_max_device_memory)

for _, engine in engines.items():
engine.activate(device=torch_device, device_memory=shared_device_memory)

ae = engines["vae"]
model = engines["transformer"]
clip = engines["clip"]
t5 = engines["t5"]

rng = torch.Generator(device="cpu")
opts = SamplingOptions(
prompt=prompt,
@@ -192,7 +251,9 @@ def main(
torch.cuda.empty_cache()
t5, clip = t5.to(torch_device), clip.to(torch_device)
inp = prepare(t5, clip, x, prompt=opts.prompt)
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
timesteps = get_schedule(
opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")
)

# offload TEs to CPU, load model to gpu
if offload:
@@ -229,6 +290,8 @@ def main(
else:
opts = None

if trt:
trt_ctx_manager.stop_runtime()

def app():
Fire(main)
50 changes: 50 additions & 0 deletions src/flux/cli_control.py
@@ -5,11 +5,13 @@
from glob import iglob

import torch
from cuda import cudart
from fire import Fire
from transformers import pipeline

from flux.modules.image_embedders import CannyImageEncoder, DepthImageEncoder
from flux.sampling import denoise, get_noise, get_schedule, prepare_control, unpack
from flux.trt.trt_manager import TRTManager
from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image


@@ -174,6 +176,8 @@ def main(
add_sampling_metadata: bool = True,
img_cond_path: str = "assets/robot.webp",
lora_scale: float | None = 0.85,
trt: bool = False,
**kwargs: dict | None,
):
"""
Sample the flux model. Either interactively (set `--loop`) or run for a
@@ -192,6 +196,7 @@ def main(
guidance: guidance value used for guidance distillation
add_sampling_metadata: Add the prompt to the image Exif metadata
img_cond_path: path to conditioning image (jpeg/png/webp)
trt: use TensorRT backend for optimized inference
"""
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

@@ -234,6 +239,7 @@ def main(

# set lora scale
if "lora" in name and lora_scale is not None:
assert not trt, "TRT does not support LORA yet"
for _, module in model.named_modules():
if hasattr(module, "set_scale"):
module.set_scale(lora_scale)
@@ -245,6 +251,50 @@ def main(
else:
raise NotImplementedError()

if trt:
trt_ctx_manager = TRTManager(
bf16=True,
device=torch_device,
static_batch=kwargs.get("static_batch", True),
static_shape=kwargs.get("static_shape", True),
)
ae.decoder.params = ae.params
ae.encoder.params = ae.params
engines = trt_ctx_manager.load_engines(
models={
"clip": clip.cpu(),
"transformer": model.cpu(),
"t5": t5.cpu(),
"vae": ae.decoder.cpu(),
"vae_encoder": ae.encoder.cpu(),
},
engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
onnx_dir=os.environ.get("ONNX_DIR", "./onnx"),
opt_image_height=height,
opt_image_width=width,
)
torch.cuda.synchronize()

trt_ctx_manager.init_runtime()
# TODO: refactor. stream should be part of engine constructor maybe !!
for _, engine in engines.items():
engine.set_stream(stream=trt_ctx_manager.stream)

if not offload:
for _, engine in engines.items():
engine.load()

calculate_max_device_memory = trt_ctx_manager.calculate_max_device_memory(engines)
_, shared_device_memory = cudart.cudaMalloc(calculate_max_device_memory)

for _, engine in engines.items():
engine.activate(device=torch_device, device_memory=shared_device_memory)

ae = engines["vae"]
model = engines["transformer"]
clip = engines["clip"]
t5 = engines["t5"]

rng = torch.Generator(device="cpu")
opts = SamplingOptions(
prompt=prompt,
2 changes: 1 addition & 1 deletion src/flux/math.py
@@ -14,7 +14,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:

def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
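
The one-line change above computes the rotary frequency scale in the dtype of the incoming positions instead of hard-coded float64, presumably so a BF16 trace no longer bakes double-precision ops into the exported ONNX/TRT graph (the commit log mentions "fix dtype casting for bfloat16"). Below is a small standalone check of that behaviour; it is an illustration only, with `dim`, `theta`, and the BF16 positions chosen arbitrarily.

```python
import torch

# Standalone illustration (not repo code): with `dtype=pos.dtype`, the rotary
# frequencies follow the input dtype instead of always being float64.
pos = torch.arange(8, dtype=torch.bfloat16)[None, :]   # (batch, n) positions in BF16
dim, theta = 16, 10_000

scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)

print(out.dtype)  # torch.bfloat16 (was float64 regardless of input before the change)
```
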
6 changes: 6 additions & 0 deletions src/flux/modules/autoencoder.py
@@ -235,6 +235,9 @@ def __init__(
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

def forward(self, z: Tensor) -> Tensor:
# get dtype for proper tracing
upscale_dtype = next(self.up.parameters()).dtype

# z to block_in
h = self.conv_in(z)

@@ -243,6 +246,8 @@ def forward(self, z: Tensor) -> Tensor:
h = self.mid.attn_1(h)
h = self.mid.block_2(h)

# cast to proper dtype
h = h.to(upscale_dtype)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
@@ -277,6 +282,7 @@ def forward(self, z: Tensor) -> Tensor:
class AutoEncoder(nn.Module):
def __init__(self, params: AutoEncoderParams):
super().__init__()
self.params = params
self.encoder = Encoder(
resolution=params.resolution,
in_channels=params.in_channels,
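
For context on the decoder changes above: the new lines read the dtype of the upsampling path's parameters and cast the mid-block activations to it, which keeps the exported graph consistent when the autoencoder is traced with mixed precision (the commit log keeps the AE in TF32 while other stages run in BF16). The snippet below is a toy illustration of that pattern; the two `nn.Linear` stages are hypothetical stand-ins for `self.mid` and `self.up`, not repo code.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the decoder's mid and upsampling stages.
mid = nn.Linear(8, 8)                       # kept in float32 (TF32 on GPU)
up = nn.Linear(8, 8).to(torch.bfloat16)     # runs in BF16

h = mid(torch.randn(2, 8))                  # float32 activations
h = h.to(next(up.parameters()).dtype)       # explicit cast, as in the diff
print(up(h).dtype)                          # torch.bfloat16
```
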
Empty file added src/flux/trt/__init__.py
Empty file.
32 changes: 32 additions & 0 deletions src/flux/trt/engine/__init__.py
@@ -0,0 +1,32 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from flux.trt.engine.base_engine import BaseEngine, Engine
from flux.trt.engine.clip_engine import CLIPEngine
from flux.trt.engine.t5_engine import T5Engine
from flux.trt.engine.transformer_engine import TransformerEngine
from flux.trt.engine.vae_engine import VAEEngine, VAEDecoder, VAEEncoder

__all__ = [
"BaseEngine",
"Engine",
"CLIPEngine",
"TransformerEngine",
"T5Engine",
"VAEEngine",
"VAEDecoder",
"VAEEncoder",
]