Commit 17e9425

rename distributed

a710128 committed Mar 25, 2022
1 parent ed6d54c commit 17e9425
Showing 8 changed files with 13 additions and 10 deletions.
3 changes: 2 additions & 1 deletion .dockerignore
@@ -143,4 +143,5 @@ cython_debug/
 **/.DS_Store
 
 **/log
-**/*.qdrep
+**/*.qdrep
+!bmtrain/dist
5 changes: 3 additions & 2 deletions Dockerfile
@@ -12,8 +12,9 @@ RUN apt install iputils-ping opensm libopensm-dev libibverbs1 libibverbs-dev -y
 ENV TORCH_CUDA_ARCH_LIST=6.1;7.0;7.5
 ENV BMP_AVX512=1
 ADD other_requirements.txt other_requirements.txt
-RUN pip3 install -r other_requirements.txt
-RUN pip3 install bmtrain
+RUN pip3 install --upgrade pip && pip3 install -r other_requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
+ADD . .
+RUN python3 setup.py install
 
 WORKDIR /root
 ADD example example
2 changes: 1 addition & 1 deletion bmtrain/__init__.py
@@ -19,4 +19,4 @@
 from . import inspect
 from . import lr_scheduler
 from . import loss
-from . import dist
+from . import distributed
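Since the submodule itself is renamed, downstream imports need the matching one-word change; a minimal migration sketch (the old spelling is shown only in comments):

    import bmtrain as bmt
    from bmtrain import distributed   # formerly: from bmtrain import dist

    # attribute access moves the same way: bmt.distributed replaces bmt.dist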
File renamed without changes.
5 changes: 2 additions & 3 deletions bmtrain/dist/ops.py → bmtrain/distributed/ops.py
@@ -1,4 +1,3 @@
-from typing import Literal
 import torch
 from ..global_var import config
 from ..nccl import allGather as ncclAllGather
@@ -30,7 +29,7 @@ def all_gather(x : torch.Tensor):
 
 class OpAllReduce(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, input : torch.Tensor, op: Literal['sum', 'prod', 'max', 'min', 'avg']):
+    def forward(ctx, input : torch.Tensor, op : str):
         if not input.contiguous():
             input = input.contiguous()
         output = torch.empty( input.size(), dtype=input.dtype, device=input.device)
@@ -64,7 +63,7 @@ def backward(ctx, grad_output):
         else:
             return grad_output * ctx.saved_tensors[0], None
 
-def all_reduce(x : torch.Tensor, op: Literal['sum', 'prod', 'max', 'min', 'avg']):
+def all_reduce(x : torch.Tensor, op : str = "sum"):
     assert x.is_cuda
     return OpAllReduce.apply(x, op)
 
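With the Literal annotation dropped, op is a plain string that now defaults to "sum". A short usage sketch under the new module name; it assumes all_gather is re-exported from bmtrain.distributed alongside all_reduce and that the script runs under a multi-process launch with one GPU per rank:

    import torch
    import bmtrain as bmt

    bmt.init_distributed()                      # sets up the communicator; one process per GPU

    x = torch.randn(4, device="cuda")
    s = bmt.distributed.all_reduce(x)           # op defaults to "sum" after this change
    p = bmt.distributed.all_reduce(x, "prod")   # "prod", "max", "min", "avg" remain valid strings
    g = bmt.distributed.all_gather(x)           # gathers x from every rank (assumed re-export)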
5 changes: 4 additions & 1 deletion other_requirements.txt
@@ -1,3 +1,6 @@
 tqdm
 cpm_kernels>=1.0.11
-jieba
+jieba
+tensorboard
+setuptools_rust
+transformers
1 change: 0 additions & 1 deletion setup.py
@@ -32,7 +32,6 @@ def get_avx_flags():
     install_requires=[
         "torch>=1.10",
         "numpy",
-        "tensorboard"
     ],
     ext_modules=[
         CUDAExtension('bmtrain.nccl._C', [
2 changes: 1 addition & 1 deletion tests/test_dist.py
@@ -4,7 +4,7 @@
 def main():
     bmt.init_distributed()
     x = torch.full((1,), bmt.rank() + 1, dtype=torch.half, device="cuda").requires_grad_(True)
-    y = bmt.dist.all_reduce(x, "prod").view(-1)
+    y = bmt.distributed.all_reduce(x, "prod").view(-1)
     bmt.print_rank(y)
     loss = (y * y).sum() / 2
     loss.backward()
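As a sanity check on what the renamed call should produce, the same arithmetic can be done locally; world_size = 2 is an assumed launch configuration, not something the test pins down:

    import math

    world_size = 2                                  # assumption: two ranks
    xs = [rank + 1 for rank in range(world_size)]   # each rank holds rank + 1
    y = math.prod(xs)                               # the "prod" all-reduce result, identical on every rank
    grads = [y * y / x for x in xs]                 # d/dx_i of (y*y)/2 = y * (y / x_i) = y**2 / x_i
    print(y, grads)                                 # 2 [4.0, 2.0]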
