diff --git a/internlm/accelerator/abstract_accelerator.py b/internlm/accelerator/abstract_accelerator.py
index de395a0e..104a5176 100644
--- a/internlm/accelerator/abstract_accelerator.py
+++ b/internlm/accelerator/abstract_accelerator.py
@@ -42,19 +42,13 @@ def device_name(self, device_index=None):
         """
         raise NotImplementedError
 
-    def device(self, device_index=None):
-        """
-        Return the device object.
-        """
-        raise NotImplementedError
-
     def set_device(self, device_index):
         """
         Bind the current process to a device.
         """
         raise NotImplementedError
 
-    def current_device(self):
+    def get_device_id(self):
         """
         Return the current device index.
         """
diff --git a/internlm/accelerator/cuda_accelerator.py b/internlm/accelerator/cuda_accelerator.py
index ad6a4801..48a47165 100644
--- a/internlm/accelerator/cuda_accelerator.py
+++ b/internlm/accelerator/cuda_accelerator.py
@@ -40,19 +40,13 @@ def device_name(self, device_index=None):
             return "cuda"
         return "cuda:{}".format(device_index)
 
-    def device(self, device_index=None):
-        """
-        Return the device object.
-        """
-        return torch.cuda.device(device_index)
-
     def set_device(self, device_index):
         """
         Bind the current process to a device.
         """
         torch.cuda.set_device(device_index)
 
-    def current_device(self):
+    def get_device_id(self):
         """
         Return the current device index.
         """
diff --git a/internlm/accelerator/npu_accelerator.py b/internlm/accelerator/npu_accelerator.py
index e90642d0..e1bd3549 100644
--- a/internlm/accelerator/npu_accelerator.py
+++ b/internlm/accelerator/npu_accelerator.py
@@ -39,19 +39,13 @@ def device_name(self, device_index=None):
             return "npu"
         return "npu:{}".format(device_index)
 
-    def device(self, device_index=None):
-        """
-        Return the device object.
-        """
-        return torch.npu.device(device_index)
-
     def set_device(self, device_index):
         """
         Bind the current process to a device.
         """
         torch.npu.set_device(device_index)
 
-    def current_device(self):
+    def get_device_id(self):
         """
         Return the current device index.
         """
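Call-site impact of the rename, as a minimal sketch (not part of this diff); it assumes `get_accelerator()` returns the concrete CUDA/NPU accelerator singleton used elsewhere in the codebase, and uses only methods visible in the hunks above:

```python
from internlm.accelerator import get_accelerator

internlm_accelerator = get_accelerator()

device_id = internlm_accelerator.get_device_id()          # previously current_device()
device_str = internlm_accelerator.device_name(device_id)  # e.g. "cuda:0" or "npu:0"
internlm_accelerator.set_device(device_id)                # bind this process to that device
```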
""" diff --git a/internlm/core/communication/isp.py b/internlm/core/communication/isp.py index 7cebf35b..b821e994 100644 --- a/internlm/core/communication/isp.py +++ b/internlm/core/communication/isp.py @@ -9,14 +9,11 @@ from torch import distributed as dist from torch import nn -from internlm.accelerator import get_accelerator from internlm.core.context import global_context as gpc from internlm.core.naive_amp import NaiveAMPModel from internlm.model.ops.linear import ISPLinear from internlm.model.utils import all_gather_raw, reduce_scatter_raw -from internlm.utils.common import SchedulerHook - -internlm_accelerator = get_accelerator() +from internlm.utils.common import SchedulerHook, get_current_device @dataclass @@ -26,7 +23,7 @@ class ISPCommModelConfig: """ dtype: torch.dtype = torch.half - device: torch.device = internlm_accelerator.device() + device: torch.device = get_current_device() activation_checkpointing: float = 0.0 module_shapes: Dict[str, torch.Size] = None diff --git a/internlm/core/gradient_handler.py b/internlm/core/gradient_handler.py index cf208f48..c866be5b 100644 --- a/internlm/core/gradient_handler.py +++ b/internlm/core/gradient_handler.py @@ -7,12 +7,9 @@ import torch.distributed as dist from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from internlm.accelerator import get_accelerator from internlm.core.context import global_context as gpc from internlm.utils.common import get_current_device -internlm_accelerator = get_accelerator() - class BaseGradientHandler(ABC): """A basic helper class to handle all-reduce operations of gradients across different parallel groups diff --git a/internlm/core/naive_amp.py b/internlm/core/naive_amp.py index 46ba85b0..498c8026 100644 --- a/internlm/core/naive_amp.py +++ b/internlm/core/naive_amp.py @@ -202,7 +202,7 @@ def _post_forward_hook_for_fp32( sub_module.register_forward_hook(partial(_post_forward_hook_for_fp32)) if gpc.config.get("output_tf32", False) and module_is_output(sub_module): sub_module.to(fp32_dtype) - if get_accelerator().get_accelerator_backend() == AcceleratorType.GPU: + if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: torch.backends.cudnn.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32)) diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index 64625729..b8aefe78 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -8,7 +8,6 @@ import torch import torch.distributed as dist -from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.engine import Engine @@ -24,7 +23,6 @@ from .base_scheduler import BaseScheduler logger = get_logger(__file__) -internlm_accelerator = get_accelerator() class NonPipelineScheduler(BaseScheduler): diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py index 02498f80..a9e44cee 100644 --- a/internlm/model/modeling_moe.py +++ b/internlm/model/modeling_moe.py @@ -7,7 +7,6 @@ import torch from torch import nn -from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc from internlm.core.naive_amp import set_fp32_attr_to_module @@ -24,7 +23,7 @@ ) from internlm.solver.activation_checkpoint 
diff --git a/internlm/model/moe/base_layer.py b/internlm/model/moe/base_layer.py
index b36429f8..48a4d857 100644
--- a/internlm/model/moe/base_layer.py
+++ b/internlm/model/moe/base_layer.py
@@ -4,13 +4,10 @@
 from torch import Tensor
 from torch.nn import Module, ModuleList
 
-from internlm.accelerator import get_accelerator
 from internlm.core.context import global_context as gpc
 from internlm.model.moe.experts import Experts
 from internlm.utils.common import get_current_device
 
-internlm_accelerator = get_accelerator()
-
 if TYPE_CHECKING:
     Base = Module[Tensor]
 else:
diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py
index 21c93f58..d91070fc 100644
--- a/internlm/train/pipeline.py
+++ b/internlm/train/pipeline.py
@@ -418,24 +418,45 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):
     """Initialize and return the profiler context manager instance."""
 
     if profiling and gpc.get_local_rank(ParallelMode.DATA) == 0 and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
-        llm_profile = torch.profiler.profile
-        logger.info(f"Do profiling in rank {gpc.get_global_rank()}!")
+        schedule_config = {"wait": 1, "warmup": 1, "active": 1, "repeat": 1, "skip_first": 3}
+        trace_path = (
+            f"RUN/{gpc.config.JOB_NAME}/{start_time}/traces/rank{gpc.get_global_rank()}_"
+            f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
+            f"wp{gpc.get_local_rank(ParallelMode.WEIGHT)}_"
+            f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}"
+        )
+        if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU:
+            experimental_config = torch_npu.profiler._ExperimentalConfig(
+                aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
+                profiler_level=torch_npu.profiler.ProfilerLevel.Level1,
+                l2_cache=False,
+            )
+            llm_profile = torch_npu.profiler.profile(
+                activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU],
+                schedule=torch_npu.profiler.schedule(**schedule_config),
+                on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(trace_path),
+                record_shapes=True,
+                profile_memory=True,
+                with_stack=False,
+                with_flops=False,
+                with_modules=False,
+                experimental_config=experimental_config,
+            )
+            logger.info(f"Do profiling for NPU on rank {gpc.get_global_rank()}!")
+        else:
+            llm_profile = torch.profiler.profile(
+                activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
+                schedule=torch.profiler.schedule(**schedule_config),
+                on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_path),
+                with_stack=True,
+                with_modules=True,
+                profile_memory=True,
+            )
+            logger.info(f"Do profiling for GPU on rank {gpc.get_global_rank()}!")
     else:
-        llm_profile = DummyProfile
+        llm_profile = DummyProfile()
 
-    return llm_profile(
-        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
-        schedule=torch.profiler.schedule(skip_first=5, wait=1, warmup=1, active=1, repeat=1),
-        on_trace_ready=torch.profiler.tensorboard_trace_handler(
-            f"RUN/{gpc.config.JOB_NAME}/{start_time}/traces/rank{gpc.get_global_rank()}_"
-            + f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
-            + f"wp{gpc.get_local_rank(ParallelMode.WEIGHT)}_"
-            + f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}",
-        ),
-        with_stack=True,
-        with_modules=True,
-        profile_memory=True,
-    )
+    return llm_profile
 
 
 @llm_timeout(func_name="record_current_batch_training_metrics")
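`initialize_llm_profile()` now returns a ready-to-use profiler object rather than a class that the caller invokes later, so `DummyProfile` is instantiated and must mimic the profiler interface. A hedged sketch of the minimal interface assumed here (`DummyProfile`'s real definition lives elsewhere in the repo and may differ), plus a hypothetical caller:

```python
class DummyProfileSketch:
    """No-op stand-in mirroring torch.profiler.profile's context-manager API."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        return False

    def step(self):
        pass  # nothing to advance in the no-op case


# Hypothetical caller: identical for the CUDA, NPU, and dummy paths.
# with initialize_llm_profile(profiling=True, start_time=start_time) as prof:
#     for _ in range(total_steps):
#         train_one_step()  # hypothetical training step
#         prof.step()       # advance the profiler schedule
```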
diff --git a/internlm/utils/gputest.py b/internlm/utils/gputest.py
index 900d1231..d6be6359 100644
--- a/internlm/utils/gputest.py
+++ b/internlm/utils/gputest.py
@@ -90,7 +90,7 @@ def flops(batch, seqlen, headdim, nheads, time_f):
 
 def get_gpu_temperature():
     """Get current GPU temperature."""
     try:
-        gpu_id = internlm_accelerator.current_device()
+        gpu_id = internlm_accelerator.get_device_id()
     except AssertionError:
         gpu_id = -1
diff --git a/internlm/utils/writer.py b/internlm/utils/writer.py
index 8dd77980..7abb8ddd 100644
--- a/internlm/utils/writer.py
+++ b/internlm/utils/writer.py
@@ -82,7 +82,7 @@ def init_tb_writer(
         writer.add_text(
             tag=f"mapping_{tb_log_file_name}",
             text_string=f"file_path={tb_logdir} hostname={socket.gethostname()} \
-                device={internlm_accelerator.current_device()}",
+                device={internlm_accelerator.get_device_id()}",
             global_step=step_count,
         )
         writer.add_scaler = partial(writer.add_scalar, new_style=True)
diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py
index 13731db2..d70a2448 100644
--- a/tests/test_training/test_loss.py
+++ b/tests/test_training/test_loss.py
@@ -5,7 +5,6 @@
 import torch.distributed as dist
 
 import internlm
-from internlm.accelerator import get_accelerator
 from internlm.checkpoint import CheckpointManager
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
@@ -25,7 +24,6 @@
 from internlm.utils.gputest import empty_cache_and_diag
 from internlm.utils.megatron_timers import megatron_timer as timer
 
-internlm_accelerator = get_accelerator()
 CONFIG_FILE_PATH = os.getenv("CONFIG_FILE_PATH", "./configs/7B_sft.py")
 TOTAL_STEPS = 10
 LOSS_SPIKE_LIMIT = 1.5
diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py
index e982dcae..098bb8e0 100644
--- a/tests/test_training/train_CI.py
+++ b/tests/test_training/train_CI.py
@@ -12,9 +12,6 @@
 import torch
 import torch.distributed as dist
 
-from internlm.accelerator import get_accelerator
-
-internlm_accelerator = get_accelerator()
 script_dir = os.path.dirname(os.path.abspath(__file__))
 project_root = os.path.abspath(os.path.join(script_dir, "../../"))
 sys.path.append(project_root)
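For context on the `gputest.py` hunk above: `get_device_id()` returns an integer ordinal, which is what NVML-style per-device queries expect. A hedged sketch using the standard `pynvml` bindings (the repo's actual `get_gpu_temperature()` may be implemented differently); the negative-id fallback mirrors the `AssertionError` path in the diff:

```python
import pynvml


def read_gpu_temperature(gpu_id: int) -> int:
    """Return the GPU temperature in °C, or 0 when no device is bound."""
    if gpu_id < 0:
        return 0
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
        return pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
    finally:
        pynvml.nvmlShutdown()
```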