diff --git a/src/accelerate/launchers.py b/src/accelerate/launchers.py index 3aa0ca22b28..3b8dc5aef23 100644 --- a/src/accelerate/launchers.py +++ b/src/accelerate/launchers.py @@ -145,7 +145,7 @@ def train(*args): if num_processes is None: num_processes = 8 - launcher = PrepareForLaunch(function, distributed_type="TPU") + launcher = PrepareForLaunch(function, distributed_type="XLA") print(f"Launching a training on {num_processes} TPU cores.") xmp.spawn(launcher, args=args, nprocs=num_processes, start_method="fork") elif in_colab and get_gpu_info()[1] < 2: diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 921b86f3b9a..85e746dfe6f 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -503,30 +503,6 @@ def build(self) -> torch.profiler.profile: ) -class DeprecatedFieldDescriptor: - """ - Descriptor for deprecated fields in an enum class. - - Args: - field_name (`str`): - The name of the deprecated field. - replaced_with (`str`): - The name of the field that replaces the deprecated one. - """ - - def __init__(self, field_name, replaced_with): - self.field_name = field_name - self.replaced_with = replaced_with - - def __get__(self, instance, owner): - warnings.warn( - f"The `{self.field_name}` of `{owner}` is deprecated and will be removed in v1.0.0. " - f"Please use the `{self.replaced_with}` instead.", - FutureWarning, - ) - return getattr(owner, self.replaced_with) - - class DistributedType(str, enum.Enum): """ Represents a type of distributed environment. @@ -556,7 +532,6 @@ class DistributedType(str, enum.Enum): FSDP = "FSDP" XLA = "XLA" MEGATRON_LM = "MEGATRON_LM" - TPU = DeprecatedFieldDescriptor("TPU", "XLA") class SageMakerDistributedType(str, enum.Enum):