
Add a shortfin pipeline for flux #876

Open: wants to merge 17 commits into main

Conversation

KyleHerndon (Contributor)

No description provided.

@sogartar (Contributor) left a comment

I was not able to review all the code. I will continue later.

@@ -50,8 +52,12 @@ def clip_text_model_to_dataset(model: ClipTextModel) -> Dataset:
     return Dataset(properties=model.config.to_properties(), root_theta=model.theta)


-def export_clip_text_model_iree_parameters(model: ClipTextModel, output_path: PathLike):
+def export_clip_text_model_iree_parameters(model: ClipTextModel, output_path: PathLike, dtype=None):
Could you put a type hint on the new dtype argument?
I think type hints are important even if the assumption is that everyone should know what dtype should be.
They also improve code navigation, since parsers can do a better job and determine the types of resulting expressions.
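A minimal sketch of the suggested annotation. The concrete type is my assumption (`torch.dtype` seems to be the intended one, written here as a string forward reference so the sketch stays self-contained), and the `PathLike` alias is a guess at the project's definition:

```python
import os
from typing import Optional, Union

# Assumption: PathLike is a str / os.PathLike alias, as is common in sharktank.
PathLike = Union[str, os.PathLike]


def export_clip_text_model_iree_parameters(
    model: object,  # ClipTextModel in the real code; object keeps this sketch import-free
    output_path: PathLike,
    dtype: Optional["torch.dtype"] = None,  # annotated instead of a bare dtype=None
) -> None:
    """Stub illustrating only the annotated signature."""
```

With the annotation in place, tools like mypy and IDE navigation can resolve what `dtype` is without reading the call sites.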

Comment on lines -77 to -85
if not self.hp.use_post_quant_conv:
    sample = rearrange(
        sample,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(1024 / 16),
        w=math.ceil(1024 / 16),
        ph=2,
        pw=2,
    )
Is this removal intentional?
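For context, the removed block is the inverse "unpatchify" step. A plain-PyTorch sketch of what it computed (the standalone function and its name are mine; shapes assume a 1024x1024 output and 2x2 patches, as in the snippet):

```python
import math

import torch


def unpatchify(
    sample: torch.Tensor, height: int = 1024, width: int = 1024, ph: int = 2, pw: int = 2
) -> torch.Tensor:
    # Equivalent of:
    #   rearrange(sample, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
    #             h=math.ceil(height / 16), w=math.ceil(width / 16), ph=2, pw=2)
    h = math.ceil(height / 16)
    w = math.ceil(width / 16)
    b, hw, cpp = sample.shape
    c = cpp // (ph * pw)
    x = sample.reshape(b, h, w, c, ph, pw)
    x = x.permute(0, 3, 1, 4, 2, 5)  # -> b, c, h, ph, w, pw
    return x.reshape(b, c, h * ph, w * pw)
```

If the removal is intentional, it would help to note where this reshaping now happens instead.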

Comment on lines +52 to +54
# t5_dataset.root_theta = t5_dataset.root_theta.transform(
# functools.partial(set_float_dtype, dtype=dtype)
# )
This should not be commented out.


        self._rng = torch.Generator(device="cpu")

    def _get_noise(
This function is already available elsewhere. We should have just one copy and reuse it.

Comment on lines +64 to +66
# clip_dataset.root_theta = clip_dataset.root_theta.transform(
# functools.partial(set_float_dtype, dtype=dtype)
# )
This should not be commented out.

bs, c, h, w = img.shape

# Prepare image and position IDs
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
I know we are not going to export the pipeline itself, but are einops functions like rearrange exportable to MLIR?
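For what it's worth, this particular pattern can also be written with plain reshape/permute ops, which torch.export traces directly; a sketch (the function name is mine, and the spatial dimensions are assumed divisible by the patch size):

```python
import torch


def patchify(img: torch.Tensor, ph: int = 2, pw: int = 2) -> torch.Tensor:
    # Equivalent of:
    #   rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    b, c, H, W = img.shape
    h, w = H // ph, W // pw
    x = img.reshape(b, c, h, ph, w, pw)
    x = x.permute(0, 2, 4, 1, 3, 5)  # -> b, h, w, c, ph, pw
    return x.reshape(b, h * w, c * ph * pw)
```

That said, `einops.rearrange` on torch tensors lowers to ordinary reshape/permute calls under the hood, so export should generally work; the rewrite above just makes that explicit.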

atol: Optional[float] = None,
rtol: Optional[float] = None,
):
"""Compare pipeline outputs between different dtypes."""
Could you mention that this compares against the Hugging Face implementation?
The tests below could also reflect that in their names.

}

ARTIFACT_VERSION = "12032024"
SDXL_BUCKET = (
SDXL_BUCKET -> FLUX_BUCKET

SDXL_BUCKET = (
f"https://sharkpublic.blob.core.windows.net/sharkpublic/flux.1/{ARTIFACT_VERSION}/"
)
SDXL_WEIGHTS_BUCKET = (
SDXL -> FLUX

sfnp.bfloat16: "bf16",
}

ARTIFACT_VERSION = "12032024"
This date seems odd.
