Saurav Maheshkar edited this page Sep 13, 2024 · 5 revisions

Welcome to the flux-jax wiki!

Today I learnt

import jax
device = jax.devices("gpu")[0] if jax.default_backend() == "gpu" else jax.devices("cpu")[0]

| | JAX `jax.nn.dot_product_attention` | PyTorch `torch.nn.functional.scaled_dot_product_attention` |
|---|---|---|
| Query/key/value layout | (Batch, Target_Length, Num_Heads, Hidden_Dims) | (Batch, Num_Heads, Target_Length, Hidden_Dims) |
| Torch | JAX |
|---|---|
| `x.to(device)` | `jax.device_put(x, device)` |
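The sketch below shows `jax.device_put` as the counterpart of `x.to(device)`. It assumes the GPU-or-CPU selection shown earlier. Note that `device` has to be a `jax.Device` object rather than a string like `"gpu"`. JAX arrays are immutable, so `device_put` returns a new array committed to that device instead of moving `x` in place.

```python
import jax
import jax.numpy as jnp

# Prefer a GPU when it is JAX's default backend; otherwise fall back to the CPU.
device = jax.devices("gpu")[0] if jax.default_backend() == "gpu" else jax.devices("cpu")[0]

x = jnp.arange(6.0).reshape(2, 3)

# Rough analogue of PyTorch's x.to(device): a copy of x placed on `device`.
x_on_device = jax.device_put(x, device)

print(device.platform)
print(x_on_device.devices())
```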