Skip to content

Commit

Permalink
ctype load nccl pypi
Browse files Browse the repository at this point in the history
  • Loading branch information
MayDomine committed Aug 9, 2023
1 parent a23e151 commit 538cdc3
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 11 deletions.
6 changes: 6 additions & 0 deletions a.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def a():
print(b())
a()
def b():
return 1

23 changes: 12 additions & 11 deletions bmtrain/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,15 @@
from .utils import print_dict
import ctypes
from .global_var import config
from . import nccl

try:
from . import nccl
except:
from .utils import load_nccl_pypi
load_nccl_pypi()
from .synchronize import synchronize


def init_distributed(
init_method : str = "env://",
seed : int = 0,
Expand Down Expand Up @@ -52,16 +59,9 @@ def init_distributed(
port = os.environ["MASTER_PORT"]
master = addr+":"+port
timeout = datetime.timedelta(seconds=1800)
try:
rendezvous_iterator = dist.rendezvous(
init_method, rank, world_size, timeout=timeout
)
except RuntimeError:
import nvidia.nccl
path = os.path.join(os.path.dirname(nvidia.nccl.__file__), "lib")
for file_so in os.listdir(path):
if file_so.endswith(".so"):
ctypes.CDLL(os.path.join(path, file_so))
rendezvous_iterator = dist.rendezvous(
init_method, rank, world_size, timeout=timeout
)

store, rank, world_size = next(rendezvous_iterator)
store.set_timeout(timeout)
Expand Down Expand Up @@ -173,3 +173,4 @@ def get_group_rank(self,group_name):

def is_initialized() -> bool:
return config["initialized"]

13 changes: 13 additions & 0 deletions bmtrain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import sys
from typing import Any, Dict, Iterable, Optional
from .global_var import config
import os
import ctypes

ALIGN = 4
ROW_WIDTH = 60
Expand All @@ -18,6 +20,17 @@ def check_torch_version(version_str):
current_version_int = current_version_int_arr[0] * 10000 + current_version_int_arr[1] * 100 + current_version_int_arr[2]
return current_version_int - version_int

def load_nccl_pypi():
try:
import nvidia.nccl
except:
print("Run pip install nvidia-nccl-cu11 >=2.14.3 first")
path = os.path.join(os.path.dirname(nvidia.nccl.__file__), "lib")
for file_so in os.listdir(path):
if file_so.endswith(".so"):
ctypes.CDLL(os.path.join(path, file_so))


def round_up(x, d):
return (x + d - 1) // d * d

Expand Down

0 comments on commit 538cdc3

Please sign in to comment.