Welcome to the flux-jax wiki!
import jax
device = jax.devices("gpu")[0] if jax.default_backend() == "gpu" else jax.devices("cpu")[0]
Batch, Target_Length, Num_Heads, Hidden_Dims
Batch, Num_Heads, Target_Length, Hidden_Dims
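The two layouts above differ only in the order of the head and sequence axes. For reference, `jax.nn.dot_product_attention` documents its inputs in the (Batch, Target_Length, Num_Heads, Hidden_Dims) layout, while PyTorch's `torch.nn.functional.scaled_dot_product_attention` puts the head axis before the sequence axis. The following is a minimal sketch of converting between them with `jnp.swapaxes`. The sizes are illustrative assumptions, not values from flux-jax. The `einsum` check shows that attention logits computed in either layout agree.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes (hypothetical, not taken from flux-jax).
B, T, N, H = 2, 16, 4, 8
key = jax.random.PRNGKey(0)

# Query in (Batch, Num_Heads, Target_Length, Hidden_Dims) layout.
q_bnth = jax.random.normal(key, (B, N, T, H))

# Swap axes 1 and 2 to get (Batch, Target_Length, Num_Heads, Hidden_Dims).
q_btnh = jnp.swapaxes(q_bnth, 1, 2)

# Swapping the same axes again restores the original layout.
q_back = jnp.swapaxes(q_btnh, 1, 2)

# Attention logits q @ k^T computed in each layout (using q as k);
# both einsums produce shape (B, N, T, T).
logits_bnth = jnp.einsum("bnth,bnsh->bnts", q_bnth, q_bnth)
logits_btnh = jnp.einsum("btnh,bsnh->bnts", q_btnh, q_btnh)

print(q_btnh.shape, logits_bnth.shape)
```

Because `swapaxes` only reorders axes, no data is lost. When porting, the transpose can be applied once at the boundary where tensors enter or leave the attention call.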
x.to(device)
jax.device_put(x, device)
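PyTorch's `x.to(device)` corresponds to `jax.device_put(x, device)` in JAX. Note that `jax.device_put` expects a `jax.Device` (or a sharding) rather than a platform string such as `"gpu"`. The following is a minimal sketch, assuming a host-side NumPy array `x`, that picks a device and places the array on it.

```python
import jax
import numpy as np

# Use the first GPU if JAX's default backend is a GPU, otherwise the first CPU device.
if jax.default_backend() == "gpu":
    device = jax.devices("gpu")[0]
else:
    device = jax.devices("cpu")[0]

# A host-side array (illustrative data).
x = np.arange(6, dtype=np.float32).reshape(2, 3)

# Transfer to the chosen device; this plays the role of PyTorch's x.to(device).
x_on_device = jax.device_put(x, device)

# jax.Array.devices() returns the set of devices that hold the array.
print(x_on_device.devices())
```

As with `Tensor.to`, the call returns a new array and leaves `x` unchanged.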