Skip to content

Commit

Permalink
Fix(device name): use consist way for get device (#139)
Browse files Browse the repository at this point in the history
  • Loading branch information
li126com authored Mar 28, 2024
1 parent f97ddc3 commit f8621ee
Show file tree
Hide file tree
Showing 14 changed files with 47 additions and 62 deletions.
8 changes: 1 addition & 7 deletions internlm/accelerator/abstract_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,13 @@ def device_name(self, device_index=None):
"""
raise NotImplementedError

def device(self, device_index=None):
"""
Return the device object.
"""
raise NotImplementedError

def set_device(self, device_index):
"""
Bind the current process to a device.
"""
raise NotImplementedError

def current_device(self):
def get_device_id(self):
"""
Return the current device index.
"""
Expand Down
8 changes: 1 addition & 7 deletions internlm/accelerator/cuda_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,13 @@ def device_name(self, device_index=None):
return "cuda"
return "cuda:{}".format(device_index)

def device(self, device_index=None):
"""
Return the device object.
"""
return torch.cuda.device(device_index)

def set_device(self, device_index):
"""
Bind the current process to a device.
"""
torch.cuda.set_device(device_index)

def current_device(self):
def get_device_id(self):
"""
Return the current device index.
"""
Expand Down
8 changes: 1 addition & 7 deletions internlm/accelerator/npu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,13 @@ def device_name(self, device_index=None):
return "npu"
return "npu:{}".format(device_index)

def device(self, device_index=None):
"""
Return the device object.
"""
return torch.npu.device(device_index)

def set_device(self, device_index):
"""
Bind the current process to a device.
"""
torch.npu.set_device(device_index)

def current_device(self):
def get_device_id(self):
"""
Return the current device index.
"""
Expand Down
7 changes: 2 additions & 5 deletions internlm/core/communication/isp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,11 @@
from torch import distributed as dist
from torch import nn

from internlm.accelerator import get_accelerator
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.model.ops.linear import ISPLinear
from internlm.model.utils import all_gather_raw, reduce_scatter_raw
from internlm.utils.common import SchedulerHook

internlm_accelerator = get_accelerator()
from internlm.utils.common import SchedulerHook, get_current_device


@dataclass
Expand All @@ -26,7 +23,7 @@ class ISPCommModelConfig:
"""

dtype: torch.dtype = torch.half
device: torch.device = internlm_accelerator.device()
device: torch.device = get_current_device()
activation_checkpointing: float = 0.0
module_shapes: Dict[str, torch.Size] = None

Expand Down
3 changes: 0 additions & 3 deletions internlm/core/gradient_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,9 @@
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from internlm.accelerator import get_accelerator
from internlm.core.context import global_context as gpc
from internlm.utils.common import get_current_device

internlm_accelerator = get_accelerator()


class BaseGradientHandler(ABC):
"""A basic helper class to handle all-reduce operations of gradients across different parallel groups
Expand Down
2 changes: 1 addition & 1 deletion internlm/core/naive_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def _post_forward_hook_for_fp32(
sub_module.register_forward_hook(partial(_post_forward_hook_for_fp32))
if gpc.config.get("output_tf32", False) and module_is_output(sub_module):
sub_module.to(fp32_dtype)
if get_accelerator().get_accelerator_backend() == AcceleratorType.GPU:
if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU:
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
2 changes: 0 additions & 2 deletions internlm/core/scheduler/no_pipeline_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import torch
import torch.distributed as dist

from internlm.accelerator import get_accelerator
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.engine import Engine
Expand All @@ -24,7 +23,6 @@
from .base_scheduler import BaseScheduler

logger = get_logger(__file__)
internlm_accelerator = get_accelerator()


class NonPipelineScheduler(BaseScheduler):
Expand Down
6 changes: 2 additions & 4 deletions internlm/model/modeling_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import torch
from torch import nn

from internlm.accelerator import get_accelerator
from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.core.naive_amp import set_fp32_attr_to_module
Expand All @@ -24,15 +23,14 @@
)
from internlm.solver.activation_checkpoint import activation_checkpoint
from internlm.solver.pipeline_utils import partition_uniform
from internlm.utils.common import filter_kwargs
from internlm.utils.common import filter_kwargs, get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.registry import MODEL_INITIALIZER

MODEL_TYPE = "INTERNLM_MoE"

logger = get_logger(__file__)
RMSNorm = try_import_RMSNorm()
internlm_accelerator = get_accelerator()


class PackedFlashBaseLayer1D(nn.Module):
Expand Down Expand Up @@ -456,7 +454,7 @@ def _build_generic_model_1d(num_layers, num_chunks, **kwargs):
device (Optional[Union[str, torch.device]]): The device will be used. internlm_accelerator.device() by default.
"""
device = internlm_accelerator.device()
device = get_current_device()
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

Expand Down
3 changes: 0 additions & 3 deletions internlm/model/moe/base_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@
from torch import Tensor
from torch.nn import Module, ModuleList

from internlm.accelerator import get_accelerator
from internlm.core.context import global_context as gpc
from internlm.model.moe.experts import Experts
from internlm.utils.common import get_current_device

internlm_accelerator = get_accelerator()

if TYPE_CHECKING:
Base = Module[Tensor]
else:
Expand Down
53 changes: 37 additions & 16 deletions internlm/train/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,24 +418,45 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):
"""Initialize and return the profiler context manager instance."""

if profiling and gpc.get_local_rank(ParallelMode.DATA) == 0 and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
llm_profile = torch.profiler.profile
logger.info(f"Do profiling in rank {gpc.get_global_rank()}!")
schedule_config = {"wait": 1, "warmup": 1, "active": 1, "repeat": 1, "skip_first": 3}
trace_path = (
f"RUN/{gpc.config.JOB_NAME}/{start_time}/traces/rank{gpc.get_global_rank()}_"
f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
f"wp{gpc.get_local_rank(ParallelMode.WEIGHT)}_"
f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}"
)
if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU:
experimental_config = torch_npu.profiler._ExperimentalConfig(
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
profiler_level=torch_npu.profiler.ProfilerLevel.Level1,
l2_cache=False,
)
llm_profile = torch_npu.profiler.profile(
activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU],
schedule=torch_npu.profiler.schedule(**schedule_config),
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(trace_path),
record_shapes=True,
profile_memory=True,
with_stack=False,
with_flops=False,
with_modules=False,
experimental_config=experimental_config,
)
logger.info(f"Do profiling for NPU on rank {gpc.get_global_rank()}!")
else:
llm_profile = torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(**schedule_config),
on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_path),
with_stack=True,
with_modules=True,
profile_memory=True,
)
logger.info(f"Do profiling for GPU on rank {gpc.get_global_rank()}!")
else:
llm_profile = DummyProfile
llm_profile = DummyProfile()

return llm_profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(skip_first=5, wait=1, warmup=1, active=1, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(
f"RUN/{gpc.config.JOB_NAME}/{start_time}/traces/rank{gpc.get_global_rank()}_"
+ f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
+ f"wp{gpc.get_local_rank(ParallelMode.WEIGHT)}_"
+ f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}",
),
with_stack=True,
with_modules=True,
profile_memory=True,
)
return llm_profile


@llm_timeout(func_name="record_current_batch_training_metrics")
Expand Down
2 changes: 1 addition & 1 deletion internlm/utils/gputest.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def flops(batch, seqlen, headdim, nheads, time_f):
def get_gpu_temperature():
"""Get current GPU temperature."""
try:
gpu_id = internlm_accelerator.current_device()
gpu_id = internlm_accelerator.get_device_id()
except AssertionError:
gpu_id = -1

Expand Down
2 changes: 1 addition & 1 deletion internlm/utils/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def init_tb_writer(
writer.add_text(
tag=f"mapping_{tb_log_file_name}",
text_string=f"file_path={tb_logdir} hostname={socket.gethostname()} \
device={internlm_accelerator.current_device()}",
device={internlm_accelerator.get_device_id()}",
global_step=step_count,
)
writer.add_scaler = partial(writer.add_scalar, new_style=True)
Expand Down
2 changes: 0 additions & 2 deletions tests/test_training/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch.distributed as dist

import internlm
from internlm.accelerator import get_accelerator
from internlm.checkpoint import CheckpointManager
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
Expand All @@ -25,7 +24,6 @@
from internlm.utils.gputest import empty_cache_and_diag
from internlm.utils.megatron_timers import megatron_timer as timer

internlm_accelerator = get_accelerator()
CONFIG_FILE_PATH = os.getenv("CONFIG_FILE_PATH", "./configs/7B_sft.py")
TOTAL_STEPS = 10
LOSS_SPIKE_LIMIT = 1.5
Expand Down
3 changes: 0 additions & 3 deletions tests/test_training/train_CI.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
import torch
import torch.distributed as dist

from internlm.accelerator import get_accelerator

internlm_accelerator = get_accelerator()
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(script_dir, "../../"))
sys.path.append(project_root)
Expand Down

0 comments on commit f8621ee

Please sign in to comment.