diff --git a/demo_gr.py b/demo_gr.py index 3b4d022b..4ca8667f 100644 --- a/demo_gr.py +++ b/demo_gr.py @@ -162,7 +162,7 @@ def generate_image( def create_demo( - model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False + model_name: str, device: str = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu", offload: bool = False ): generator = FluxGenerator(model_name, device, offload) is_schnell = model_name == "flux-schnell" @@ -237,7 +237,7 @@ def update_img2img(do_img2img): "--name", type=str, default="flux-schnell", choices=list(configs.keys()), help="Model name" ) parser.add_argument( - "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use" + "--device", type=str, default="cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu", help="Device to use" ) parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use") parser.add_argument("--share", action="store_true", help="Create a public link to your demo") diff --git a/demo_st.py b/demo_st.py index 74b23eb0..4afcdaa0 100644 --- a/demo_st.py +++ b/demo_st.py @@ -55,7 +55,7 @@ def get_image() -> torch.Tensor | None: @torch.inference_mode() def main( - device: str = "cuda" if torch.cuda.is_available() else "cpu", + device: str = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu", offload: bool = False, output_dir: str = "output", ): diff --git a/demo_st_fill.py b/demo_st_fill.py index ddba6688..c4423755 100644 --- a/demo_st_fill.py +++ b/demo_st_fill.py @@ -138,7 +138,7 @@ def downscale_image(img: Image.Image, scale_factor: float) -> Image.Image: @torch.inference_mode() def main( - device: str = "cuda" if torch.cuda.is_available() else "cpu", + device: str = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu", offload: bool = 
False, output_dir: str = "output", ): diff --git a/src/flux/cli_control.py b/src/flux/cli_control.py index cd83c89e..51b5859d 100644 --- a/src/flux/cli_control.py +++ b/src/flux/cli_control.py @@ -165,7 +165,7 @@ def main( height: int = 1024, seed: int | None = None, prompt: str = "a robot made out of gold", - device: str = "cuda" if torch.cuda.is_available() else "cpu", + device: str = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu", num_steps: int = 50, loop: bool = False, guidance: float | None = None, diff --git a/src/flux/cli_fill.py b/src/flux/cli_fill.py index 415c0420..bfc4925f 100644 --- a/src/flux/cli_fill.py +++ b/src/flux/cli_fill.py @@ -175,7 +175,7 @@ def parse_img_mask_path(options: SamplingOptions | None) -> SamplingOptions | No def main( seed: int | None = None, prompt: str = "a white paper cup", - device: str = "cuda" if torch.cuda.is_available() else "cpu", + device: str = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu", num_steps: int = 50, loop: bool = False, guidance: float = 30.0, diff --git a/src/flux/cli_redux.py b/src/flux/cli_redux.py index 6c03435a..fd002986 100644 --- a/src/flux/cli_redux.py +++ b/src/flux/cli_redux.py @@ -134,7 +134,7 @@ def main( width: int = 1360, height: int = 768, seed: int | None = None, - device: str = "cuda" if torch.cuda.is_available() else "cpu", + device: str = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu", num_steps: int | None = None, loop: bool = False, guidance: float = 2.5,