added veDeviceMesh #32

Merged
merged 6 commits on Apr 23, 2024
Changes from 3 commits
15 changes: 7 additions & 8 deletions python/example/nanogpt_4D_finetune/finetune_4D.py
```diff
@@ -30,11 +30,11 @@
 import numpy as np
 import torch
-from torch.distributed import broadcast, all_reduce, barrier, init_process_group, destroy_process_group
+from torch.distributed import broadcast, all_reduce, barrier, init_process_group, destroy_process_group, get_rank

 from model import GPTConfig, GPT
+from vescale.devicemesh_api.device_mesh_api import veDeviceMesh

 from vescale.dtensor.device_mesh import init_device_mesh
 from vescale import distribute_tensor
 from vescale.dmodule.api import parallelize_module
 from vescale.dtensor.placement_types import Replicate
@@ -113,8 +113,9 @@ def main():
         device = f"cuda:{rank}"
         torch.cuda.set_device(device)
         init_process_group(backend=backend, world_size=world_size, rank=rank)
-        mesh = init_device_mesh(device, (dp_size, tp_size), mesh_dim_names=["DP", "TP"])
-        ddp_rank = mesh.get_rank() // tp_size
+
+        mesh = veDeviceMesh.init_device_mesh(device, (dp_size, tp_size), mesh_dim_names=["DP", "TP"])
+        ddp_rank = get_rank() // tp_size
     else:
         rank = 0
         ddp_rank = 0
@@ -329,8 +330,7 @@ def get_lr(it):
     # Load checkpoint
     if load_checkpoint_path:
         checkpoint_state = {"model": model, "optimizer": optimizer}
-        with mesh:
-            vescale.checkpoint.load(load_checkpoint_path, checkpoint_state)
+        vescale.checkpoint.load(load_checkpoint_path, checkpoint_state)
     # training loop
     X, Y = get_batch("train")  # fetch the very first batch
     t0 = time.time()
@@ -363,8 +363,7 @@ def get_lr(it):
             # When iter_num == 0, the training does not start sotoptimizer state is empty,
             # Don't save checkpoint
             checkpoint_state = {"model": model, "optimizer": optimizer}
-            with mesh:
-                vescale.checkpoint.save(os.path.join(save_checkpoint_path, f"iter_{iter_num}"), checkpoint_state)
+            vescale.checkpoint.save(os.path.join(save_checkpoint_path, f"iter_{iter_num}"), checkpoint_state)
         if iter_num == 0 and eval_only:
             break
```
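Taken together, the finetune_4D.py changes replace the direct `init_device_mesh` call with the `veDeviceMesh` API and drop the `with mesh:` blocks around checkpointing. A minimal sketch of the resulting pattern is below; it is not a runnable script on its own and assumes `torch.distributed` is already initialized with `world_size == dp_size * tp_size`, with `model`, `optimizer`, and the checkpoint paths coming from the surrounding training code.

```python
import os
import vescale
from torch.distributed import get_rank
from vescale.devicemesh_api.device_mesh_api import veDeviceMesh

dp_size, tp_size = 2, 2  # illustrative mesh shape, not taken from the PR

# veDeviceMesh.init_device_mesh replaces the direct init_device_mesh call;
# judging by the rest of the PR, the mesh is tracked by the veDeviceMesh API,
# which is why the `with mesh:` context below is no longer needed.
mesh = veDeviceMesh.init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=["DP", "TP"])

# The data-parallel rank is now derived from the global rank instead of mesh.get_rank().
ddp_rank = get_rank() // tp_size

# Checkpoint save/load without the `with mesh:` wrapper.
checkpoint_state = {"model": model, "optimizer": optimizer}
vescale.checkpoint.save(os.path.join(save_checkpoint_path, f"iter_{iter_num}"), checkpoint_state)
vescale.checkpoint.load(load_checkpoint_path, checkpoint_state)
```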
2 changes: 0 additions & 2 deletions python/requirements.txt
```diff
@@ -7,5 +7,3 @@ tqdm
 optree
 accelerate
 transformers==4.37.2
-grpcio
-grpcio-tools
```
8 changes: 3 additions & 5 deletions python/vescale/checkpoint/planner/vescale/vescale_planner.py
```diff
@@ -35,7 +35,7 @@
     find_state_dict_object,
 )

-from vescale.dtensor.device_mesh import mesh_resources
+from vescale.devicemesh_api import veDeviceMesh

 logger: logging.Logger = logging.getLogger(__file__)

@@ -190,8 +190,6 @@ def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: b
     A function for creating local saving plan for saving checkpoint
     """
     requests = []
-    device_mesh = mesh_resources.get_current_mesh()
-    dp_device_mesh = device_mesh["DP"]
     for fqn, obj in state_dict.items():
         # Since DTensor supports submesh, adding extra check to ensure _create_write_items()
         # gets called only when the current rank is part of the mesh for the corresponding DTensor.
@@ -232,7 +230,7 @@ def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: b
                 op=dist.irecv,
                 tensor=recv_tensor,
                 peer=k,
-                group=dp_device_mesh.get_dim_groups(0),
+                group=veDeviceMesh.get_data_parallel_dim_groups(),
             )
             recv_tensors[k] = recv_tensor
             p2p_ops.append(recv_op)
@@ -243,7 +241,7 @@ def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: b
                 op=dist.isend,
                 tensor=obj.local_tensor,
                 peer=writer_rank,
-                group=dp_device_mesh.get_dim_groups(0),
+                group=veDeviceMesh.get_data_parallel_dim_groups(),
             )
             p2p_ops.append(send_op)
```
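In the planner, the process group for the irecv/isend ops now comes from `veDeviceMesh.get_data_parallel_dim_groups()` instead of a `"DP"` sub-mesh sliced from `mesh_resources.get_current_mesh()`. A hedged sketch of the new pattern, where `recv_tensor`, `k`, `obj`, and `writer_rank` stand in for the planner's local variables:

```python
import torch.distributed as dist
from vescale.devicemesh_api import veDeviceMesh

# The data-parallel process group is queried from the global veDeviceMesh API
# rather than from a locally fetched device mesh.
dp_group = veDeviceMesh.get_data_parallel_dim_groups()

# Stand-ins for the planner's local state; only the group argument changes.
recv_op = dist.P2POp(op=dist.irecv, tensor=recv_tensor, peer=k, group=dp_group)
send_op = dist.P2POp(op=dist.isend, tensor=obj.local_tensor, peer=writer_rank, group=dp_group)
p2p_ops = [recv_op, send_op]

# The planner accumulates such ops in p2p_ops; launching them with
# batch_isend_irecv is one way to execute the exchange (assumption, not shown in the diff).
dist.batch_isend_irecv(p2p_ops)
```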
4 changes: 2 additions & 2 deletions python/vescale/checkpoint/storage/checkpoint_adapter.py
```diff
@@ -125,7 +125,7 @@ def _get_megatron_tp_group(world_size, pp_size, tp_size, dp_size, cur_rank) -> t

 def _deduce_parallel_plan_by_device_mesh(mesh: DeviceMesh):
     """make rank to megatron tp_rank, pp_rank map"""
-    # FIXME(cery.69) : current only support data parallel is 1
+    # FIXME : current only support data parallel is 1
     # allways parallel in last dim
     tp_size = mesh.size()
     # for rank = pp_rank * tp_size + tp_rank
@@ -261,7 +261,7 @@ def find_device_mesh(st):
     torch.save(optim, os.path.join(megatron_optim_dict_path, "optim.pt"))
     del st["optim"]
     torch.save(st, megatron_save_file)
-    # FIXME(cery.69): support dp not 1
+    # FIXME: support dp not 1
     return st
```
18 changes: 18 additions & 0 deletions python/vescale/devicemesh_api/__init__.py
```diff
@@ -0,0 +1,18 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from .device_mesh_api import veDeviceMesh
```
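Since `vescale_planner.py` imports `veDeviceMesh` from the package (`vescale.devicemesh_api`) while the nanoGPT example imports it from the module (`vescale.devicemesh_api.device_mesh_api`), this new `__init__.py` is what makes both paths resolve to the same object. A quick sanity-check sketch, assuming the `vescale` package is importable:

```python
# Both import paths should yield the same veDeviceMesh instance,
# because the package __init__ simply re-exports the module attribute.
from vescale.devicemesh_api import veDeviceMesh as from_package
from vescale.devicemesh_api.device_mesh_api import veDeviceMesh as from_module

assert from_package is from_module
```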