Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a shortfin pipeline for flux #876

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions sharktank/sharktank/layers/configs/llm_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,8 @@ def __post_init__(self):
def from_gguf_properties(properties: dict[str, Any], **kwargs):
assert properties["general.architecture"] == "t5"
assert (
properties["t5.attention.layer_norm_epsilon"]
== properties["t5.attention.layer_norm_rms_epsilon"]
properties["t5.attention.layer_norm_epsilon"]
== properties["t5.attention.layer_norm_rms_epsilon"]
)

all_kwargs = {"vocab_size": None, "feed_forward_proj": None}
Expand Down Expand Up @@ -290,10 +290,10 @@ class ClipTextConfig:
output_hidden_states: bool = False
use_return_dict: bool = True
dtype: torch.dtype = torch.float32

@staticmethod
def from_hugging_face_clip_text_model_config(
config: "transformers.CLIPTextConfig",
config: "transformers.CLIPTextConfig", # type: ignore
) -> "ClipTextConfig":
return ClipTextConfig(
vocab_size=config.vocab_size,
Expand All @@ -314,7 +314,7 @@ def from_hugging_face_clip_text_model_config(
dtype=config.torch_dtype or torch.float32,
)

def to_hugging_face_clip_text_model_config(self) -> "transformers.CLIPTextConfig":
def to_hugging_face_clip_text_model_config(self) -> "transformers.CLIPTextConfig": # type: ignore
kwargs = self.to_properties()
kwargs["torch_dtype"] = kwargs["dtype"]
del kwargs["dtype"]
Expand Down
8 changes: 7 additions & 1 deletion sharktank/sharktank/models/clip/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import functools
from typing import Optional, Union
import transformers
from transformers.models.clip.modeling_clip import (
Expand All @@ -18,6 +19,7 @@
from ...layers.configs import ClipTextConfig
from .clip import ClipTextModel
from iree.turbine.aot import FxProgramsBuilder, export
from sharktank.transforms.dataset import set_float_dtype


def hugging_face_clip_attention_to_theta(model: HfCLIPAttention) -> Theta:
Expand Down Expand Up @@ -50,8 +52,12 @@ def clip_text_model_to_dataset(model: ClipTextModel) -> Dataset:
return Dataset(properties=model.config.to_properties(), root_theta=model.theta)


def export_clip_text_model_iree_parameters(model: ClipTextModel, output_path: PathLike):
def export_clip_text_model_iree_parameters(model: ClipTextModel, output_path: PathLike, dtype=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you put a type hint for the new dtype argument?
I think they are important even if the assumption is that everyone should know what dtype should be.
They also improve code navigation as parsers could do a better job. They can determine types of resulting expressions.

dataset = clip_text_model_to_dataset(model)
if dtype:
dataset.root_theta = dataset.root_theta.transform(
functools.partial(set_float_dtype, dtype=dtype)
)
dataset.save(output_path)


Expand Down
11 changes: 8 additions & 3 deletions sharktank/sharktank/models/flux/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import functools
from os import PathLike
import os
from pathlib import Path
Expand All @@ -14,6 +15,7 @@
from .flux import FluxModelV1, FluxParams
from ...types import Dataset
from ...utils.hf_datasets import get_dataset
from sharktank.transforms.dataset import set_float_dtype

flux_transformer_default_batch_sizes = [1]

Expand All @@ -27,11 +29,14 @@ def export_flux_transformer_model_mlir(


def export_flux_transformer_iree_parameters(
model: FluxModelV1, parameters_output_path: PathLike
model: FluxModelV1, parameters_output_path: PathLike, dtype = None
):
model.theta.rename_tensors_to_paths()
# TODO: export properties
dataset = Dataset(root_theta=model.theta, properties={})
dataset = Dataset(root_theta=model.theta, properties=model.params.to_hugging_face_properties())
if dtype:
dataset.root_theta = dataset.root_theta.transform(
functools.partial(set_float_dtype, dtype=dtype)
)
dataset.save(parameters_output_path)


Expand Down
16 changes: 15 additions & 1 deletion sharktank/sharktank/models/flux/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from typing import Any, Optional
from collections import OrderedDict
from copy import copy
from dataclasses import dataclass, asdict
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -49,6 +49,19 @@ class FluxParams:
qkv_bias: bool
guidance_embed: bool

def to_hugging_face_properties(self) -> dict[str, Any]:
hparams = {
"in_channels": self.in_channels,
"pooled_projection_dim": self.vec_in_dim,
"joint_attention_dim": self.context_in_dim,
"num_attention_heads": self.num_heads,
"num_layers": self.depth,
"num_single_layers": self.depth_single_blocks,
"attention_head_dim": sum(self.axes_dim),
"guidance_embeds": self.guidance_embed
}
return {"hparams": hparams}

@staticmethod
def from_hugging_face_properties(properties: dict[str, Any]) -> "FluxParams":
p = properties["hparams"]
Expand Down Expand Up @@ -175,6 +188,7 @@ def forward(
"Didn't get guidance strength for guidance distilled model."
)
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))

vec = vec + self.vector_in(y)

txt = self.txt_in(txt)
Expand Down
9 changes: 0 additions & 9 deletions sharktank/sharktank/models/vae/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,6 @@ def forward(
"latent_embeds": latent_embeds,
},
)
if not self.hp.use_post_quant_conv:
sample = rearrange(
sample,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(1024 / 16),
w=math.ceil(1024 / 16),
ph=2,
pw=2,
)
Comment on lines -77 to -85
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this removal intentional?

sample = sample / self.hp.scaling_factor + self.hp.shift_factor

if self.hp.use_post_quant_conv:
Expand Down
7 changes: 7 additions & 0 deletions sharktank/sharktank/pipelines/flux/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Flux text-to-image generation pipeline."""

from .flux_pipeline import FluxPipeline

__all__ = [
"FluxPipeline",
]
Loading
Loading