layers.py
from typing import Any, Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn


class MLP(nn.Module):
    """
    Basic MLP (multi-layer perceptron) layer. Dropout is omitted.
    """

    def __init__(
        self,
        d_model: int,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        # Standard transformer-style MLP: expand to 4 * d_model, apply GELU, project back.
        self.lin_0 = nn.Linear(d_model, 4 * d_model, device=device, dtype=dtype)
        self.act_fn = nn.GELU()
        self.lin_1 = nn.Linear(4 * d_model, d_model, device=device, dtype=dtype)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        x = self.lin_0(inputs)
        x = self.act_fn(x)
        x = self.lin_1(x)
        return x
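

# A minimal usage sketch (not in the original file): the MLP maps (..., d_model)
# tensors to tensors of the same shape, since lin_1 projects the 4 * d_model
# hidden width back down to d_model. Sizes below are illustrative only.
#
#   mlp = MLP(d_model=8)
#   out = mlp(torch.randn(2, 3, 8))
#   assert out.shape == (2, 3, 8)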


class AllReduceFwdIdentityBwd(torch.autograd.Function):
    """
    All-reduce (sum) the activations across ``group`` in the forward pass; pass
    gradients through unchanged in the backward pass.
    """

    @staticmethod
    def forward(
        ctx: Any, inputs: torch.Tensor, group: Optional[dist.ProcessGroup] = None
    ) -> torch.Tensor:
        inputs = inputs.clone()
        dist.all_reduce(inputs, group=group)
        return inputs

    @staticmethod
    def backward(ctx: Any, grad_outputs: torch.Tensor) -> tuple[torch.Tensor, None]:
        return grad_outputs, None


class IdentityFwdAllReduceBwd(torch.autograd.Function):
    """
    Pass activations through unchanged in the forward pass; all-reduce (sum) the
    gradients across ``group`` in the backward pass.
    """

    @staticmethod
    def forward(
        ctx: Any, inputs: torch.Tensor, group: Optional[dist.ProcessGroup] = None
    ) -> torch.Tensor:
        ctx.group = group
        return inputs

    @staticmethod
    def backward(ctx: Any, grad_outputs: torch.Tensor) -> tuple[torch.Tensor, None]:
        grad_outputs = grad_outputs.clone()
        dist.all_reduce(grad_outputs, group=ctx.group)
        return grad_outputs, None
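

# The two autograd functions above are conjugates: one communicates in the forward
# pass, the other in the backward pass. LinearShardedInputs uses the first to sum
# partial matmul results across the group on the way forward, while
# LinearShardedOutputs uses the second to sum input gradients on the way backward,
# so each direction through the tensor-parallel MLP needs exactly one all-reduce.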


class LinearShardedOutputs(nn.Linear):
    """
    Linear layer whose output features are sharded across the ranks of ``group``
    (each rank holds ``out_features // group.size()`` of the outputs); the inputs
    are expected to be replicated on every rank.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        group: dist.ProcessGroup,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        sharded_out_features, remainder = divmod(out_features, group.size())
        assert not remainder, "out_features must be divisible by the ProcessGroup size"
        super().__init__(
            in_features=in_features,
            out_features=sharded_out_features,
            device=device,
            dtype=dtype,
        )
        self.group = group

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Wrap the replicated inputs so that their gradient contributions from every
        # shard are summed (all-reduced) in the backward pass.
        x = IdentityFwdAllReduceBwd.apply(inputs, self.group)
        x = super().forward(x)
        return x


class LinearShardedInputs(nn.Linear):
    """
    Linear layer whose input features are sharded across the ranks of ``group``
    (each rank holds ``in_features // group.size()`` of the inputs); the outputs
    are all-reduced so that every rank ends up with the full result.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        group: dist.ProcessGroup,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        sharded_in_features, remainder = divmod(in_features, group.size())
        assert not remainder, "in_features must be divisible by the ProcessGroup size"
        super().__init__(
            in_features=sharded_in_features,
            out_features=out_features,
            device=device,
            dtype=dtype,
        )
        self.group = group

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        x = inputs @ self.weight.T
        # All-reduce the partial matmul results for forward-pass correctness.
        x = AllReduceFwdIdentityBwd.apply(x, self.group)
        # Crucial: add the bias _after_ the all-reduce, so it is applied exactly
        # once rather than being summed group.size() times.
        x = x + self.bias
        return x


class MLPTP(MLP):
    """
    Basic tensor-parallel MLP (multi-layer perceptron) layer. Dropout is omitted.

    lin_0 shards its outputs and lin_1 shards its inputs across the same process
    group, so the intermediate 4 * d_model activations stay sharded and the
    forward pass needs only the single all-reduce inside lin_1.
    """

    def __init__(
        self,
        d_model: int,
        group: Optional[dist.ProcessGroup] = None,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        # Skip MLP.__init__ on purpose: it would build the unsharded nn.Linear layers.
        nn.Module.__init__(self)
        # Fall back to the WORLD process group if none is provided.
        group = group or dist.group.WORLD
        self.lin_0 = LinearShardedOutputs(
            d_model, 4 * d_model, group=group, device=device, dtype=dtype
        )
        self.act_fn = nn.GELU()
        self.lin_1 = LinearShardedInputs(
            4 * d_model, d_model, group=group, device=device, dtype=dtype
        )
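

# A hypothetical usage sketch (not part of the original file), assuming a
# distributed launcher such as `torchrun --nproc-per-node=2 layers.py` and an
# available "gloo" (CPU) or "nccl" (GPU) backend. All names and sizes below are
# illustrative only.
if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    torch.manual_seed(0)  # identical seed on every rank -> identical inputs

    d_model = 8
    mlp_tp = MLPTP(d_model)  # shards lin_0 / lin_1 across the WORLD group
    x = torch.randn(2, d_model)
    y = mlp_tp(x)  # every rank holds the full (d_model-wide) output
    y.sum().backward()  # gradients are all-reduced where needed

    if dist.get_rank() == 0:
        print("output shape:", tuple(y.shape))
    dist.destroy_process_group()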