diff --git a/bmtrain/nn/__init__.py b/bmtrain/nn/__init__.py
index 85f45a3e..60fed663 100644
--- a/bmtrain/nn/__init__.py
+++ b/bmtrain/nn/__init__.py
@@ -1,5 +1,5 @@
 from .linear import Linear, OpLinear
 from .column_parallel_linear import ColumnParallelLinear
 from .row_parallel_linear import RowParallelLinear
-from .parallel_embedding import Projection, VPProjection
-from .parallel_linear_func import OpParallelLinear
\ No newline at end of file
+from .parallel_embedding import VPEmbedding
+from .parallel_linear_func import OpParallelLinear
diff --git a/bmtrain/nn/parallel_embedding.py b/bmtrain/nn/parallel_embedding.py
index 39aa147b..43e7397d 100644
--- a/bmtrain/nn/parallel_embedding.py
+++ b/bmtrain/nn/parallel_embedding.py
@@ -8,35 +8,8 @@
 from bmtrain.distributed import all_reduce, all_gather
 from .parallel_linear_func import OpParallelLinear
 
-class Projection(bmt.DistributedModule):
-    def __init__(
-        self,
-        vocab_size: int,
-        embedding_size: int,
-        dtype: torch.dtype = torch.half,
-        init_mean: float = 0.0,
-        init_std: float = 1,
-    ):
-        super().__init__()
-
-        self.dim_model = embedding_size
-        self.weight = bmt.DistributedParameter(
-            torch.empty(vocab_size, embedding_size, dtype=dtype),
-            init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
-        )
-
-    def forward(self, x: torch.Tensor):
-        """
-        Projection based on embedding's weight. For example, embedding map vocab_size to embed_size, than projection map embed_size back to vocab_size.
-        Args:
-            x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection
-        Returns:
-            :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output.
-        """  # noqa: E501
-        logits = F.linear(x, self.weight)
-        return logits
 
-class VPProjection(bmt.DistributedModule):
+class VPEmbedding(bmt.DistributedModule):
     def __init__(
         self,
         vocab_size: int,
@@ -59,12 +32,11 @@ def __init__(
             tp_mode=True,
         )
 
-    def forward(self, x: torch.Tensor):
-        """
-        Projection based on embedding's weight. For example, embedding map vocab_size to embed_size, than projection map embed_size back to vocab_size.
-        Args:
-            x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection
-        Returns:
-            :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output.
-        """  # noqa: E501
-        return bmt.nn.OpParallelLinear.apply(x, self.weight, None, False, False, False, None, 1)
\ No newline at end of file
+    def forward(self, x: torch.Tensor, projection=False):
+        if not projection:
+            weight = all_gather(self.weight, comm=config['tp_comm']).flatten(0,1)
+            out = F.embedding(x, weight)
+            return out
+        else:
+            x = bmt.distributed.all_gather(x, comm=bmt.config['tp_comm']).view(x.shape[0], -1, x.shape[-1])
+            return bmt.nn.OpParallelLinear.apply(x, self.weight, None, False, False, False, None, 1)
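Note: the following is a minimal, single-process sketch of the vocabulary-parallel scheme that `VPEmbedding` implements, using plain tensors in place of bmtrain's communicator; `tp_size` and `shards` here are illustrative names, not BMTrain API. Lookup mode rebuilds the full table from the per-rank shards before `F.embedding`, while projection mode applies the local shard to produce one vocab slice of the logits (the real code routes this through `OpParallelLinear` and `all_gather`).

```python
import torch
import torch.nn.functional as F

# Illustrative sketch: vocab_size is split across tp_size "ranks"; each rank
# holds a (vocab_size // tp_size, dim_model) shard of the embedding weight.
tp_size, vocab_size, dim_model = 4, 32, 8
shards = [torch.randn(vocab_size // tp_size, dim_model) for _ in range(tp_size)]

# Lookup mode (projection=False): all-gather the shards to rebuild the full
# table, then do an ordinary embedding lookup on the local token ids.
ids = torch.randint(0, vocab_size, (2, 6))             # (batch, local_seq)
full_weight = torch.cat(shards, dim=0)                 # stands in for all_gather(...).flatten(0, 1)
hidden = F.embedding(ids, full_weight)                 # (batch, local_seq, dim_model)

# Projection mode (projection=True): each rank multiplies the hidden states by
# its local shard, yielding a vocab slice of the logits; concatenating the
# slices reproduces the full-vocab logits.
logits_slices = [F.linear(hidden, w) for w in shards]  # each (batch, local_seq, vocab_size // tp_size)
logits = torch.cat(logits_slices, dim=-1)              # (batch, local_seq, vocab_size)
assert logits.shape == (2, 6, vocab_size)
```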
diff --git a/example/layers/attention.py b/example/layers/attention.py
index 32497bcb..0f5155d4 100644
--- a/example/layers/attention.py
+++ b/example/layers/attention.py
@@ -41,8 +41,7 @@ def forward(self,
             mask : torch.BoolTensor,                        # (batch_size, seq_q, seq_kv)
             position_bias : Optional[torch.Tensor] = None,  # (batch, num_heads, seq_q, seq_kv)
         ) -> torch.Tensor:
-        batch_size, seq_q, dim_model = hidden_q.size()
-        seq_kv = hidden_kv.size(1)
+        batch_size = hidden_q.size()[0]
 
         assert hidden_q.data_ptr() == hidden_kv.data_ptr()
 
@@ -54,14 +53,16 @@ def forward(self,
                 True, False,
                 False, None
             )
+            hidden_q = hidden_q.view(batch_size, -1, hidden_q.shape[-1])
 
             h_q, h_k, h_v = hidden_q.chunk(3, dim=-1)
-            #batch_size will changed in TensorParallel
-            batch_size = h_v.shape[0]
         else:
             h_q : torch.Tensor = self.project_q(hidden_q)
             h_k : torch.Tensor = self.project_k(hidden_kv)
             h_v : torch.Tensor = self.project_v(hidden_kv)
 
+        seq_q = h_q.size()[1]
+        seq_kv = h_k.size(1)
+
         h_q = h_q.view(batch_size, seq_q, -1, self.dim_head)
         h_k = h_k.view(batch_size, seq_kv, -1, self.dim_head)
         h_v = h_v.view(batch_size, seq_kv, -1, self.dim_head)
@@ -84,10 +85,6 @@ def forward(self,
         if position_bias is not None:
             score = score + position_bias.view(batch_size, -1, seq_q, seq_kv)
 
-        if config['tp_size'] > 1:
-            with torch.no_grad():
-                mask = all_gather(mask, config['tp_comm']).flatten(0,1)
-
         score = torch.where(
             mask.view(batch_size, 1, seq_q, seq_kv),
             score,
@@ -108,8 +105,11 @@ def forward(self,
         h_out = h_out.view(batch_size, -1, seq_q, self.dim_head)
         h_out = h_out.permute(0, 2, 1, 3).contiguous()
         h_out = h_out.view(batch_size, seq_q, -1)
+        if config['tp_size'] > 1:
+            h_out = h_out.view(h_out.shape[0] * bmt.config["tp_size"], -1, h_out.shape[-1])
 
         attn_out = self.project_out(h_out)
+
         return attn_out
diff --git a/example/layers/transformer.py b/example/layers/transformer.py
index 7cda1bb9..4cbff59b 100644
--- a/example/layers/transformer.py
+++ b/example/layers/transformer.py
@@ -28,7 +28,7 @@ def forward(self,
 
         x = self.ln_ff(hidden)
         x = self.ff(x)
-        hidden = hidden + x
+        hidden = hidden + x.view_as(hidden)
 
         return hidden
diff --git a/example/models/gpt.py b/example/models/gpt.py
index 4596167c..ed604382 100644
--- a/example/models/gpt.py
+++ b/example/models/gpt.py
@@ -14,8 +14,8 @@ def __init__(self,
 
         self.max_distance = max_distance
 
-        if config['tp_size'] > 1:
-            self.word_emb = bmt.nn.ParallelEmbedding(vocab_size, dim_model, dtype=dtype)
+        if config["tp_size"] > 1:
+            self.word_emb = bmt.nn.VPEmbedding(vocab_size, dim_model, dtype=dtype)
         else:
             self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype)
         self.pos_emb = Embedding(max_distance, dim_model, dtype=dtype)
@@ -50,17 +50,15 @@ def forward(self,
         mask_2d = mask[:, None, :] & mask[:, :, None]   # (batch, seq_len, seq_len)
         mask_2d = mask_2d & (pos[:, None, :] >= pos[:, :, None])
-
+        if config["tp_size"] > 1:
+            input = input.chunk(config["tp_size"], dim=1)[config["tp_rank"]]
+            pos = pos.chunk(config["tp_size"], dim=1)[config["tp_rank"]]
         out = self.pos_emb(pos) + self.word_emb(input)
 
         # for layer in self.transformers:
         out = self.transformers(out, mask_2d, None)
 
         out = self.layernorm(out)
-
-        if config['tp_size'] > 1:
-            logits = self.word_emb.projection(out)
-        else:
-            logits = self.word_emb(out, projection=True)
+        logits = self.word_emb(out, projection=True)
 
         bmt.inspect.record_tensor(logits, "logits")
 
         return logits
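Note: gpt.py now shards the token and position ids along the sequence dimension before the embedding lookup. A small sketch of that slicing, with `tp_size` and `tp_rank` as plain integers rather than values read from bmtrain's config:

```python
import torch

tp_size = 4
batch_size, seq_len = 2, 512
input_ids = torch.randint(0, 10240, (batch_size, seq_len))
pos = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)

# Each tensor-parallel rank keeps only its contiguous slice of the sequence,
# so downstream activations have shape (batch, seq_len // tp_size, dim_model).
for tp_rank in range(tp_size):
    local_input = input_ids.chunk(tp_size, dim=1)[tp_rank]
    local_pos = pos.chunk(tp_size, dim=1)[tp_rank]
    assert local_input.shape == (batch_size, seq_len // tp_size)
    assert local_pos.shape == (batch_size, seq_len // tp_size)
```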
diff --git a/example/train.py b/example/train.py
index 8aaf65e4..d5906a06 100644
--- a/example/train.py
+++ b/example/train.py
@@ -36,8 +36,10 @@ def main():
 
     batch_size = 2
     seq_len = 512
+    world_size = bmt.config["world_size"] if bmt.config["tp_size"] == 1 else bmt.config["tp_zero_size"]
+    r = bmt.config["rank"] if bmt.config["tp_size"] == 1 else bmt.config["tp_zero_rank"]
 
-    for i in range(bmt.world_size()):
+    for i in range(world_size):
         sent = torch.randint(0, 10240, (batch_size, seq_len + 1))
         enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda()
         enc_input = sent[:, :-1].long().cuda()
@@ -49,7 +51,7 @@ def main():
             torch.full_like(targets, -100, dtype=torch.long)
         )
 
-        if i == bmt.rank():
+        if i == r:
             break
 
     if config['tp_size'] > 1:
@@ -82,7 +84,7 @@ def main():
     batch, seq_len, vocab_out_size = logits.size()
 
     if config['tp_size'] > 1:
-        loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets)
+        loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len))
     else:
         loss = loss_func(logits.float().view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len))
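Note: the train.py change makes every rank in a tensor-parallel group draw the same batch by indexing the synthetic data with the zero-group rank instead of the global rank. A rough sketch of that mapping, assuming tensor-parallel ranks are grouped contiguously (which may not match bmtrain's actual rank layout):

```python
# Illustrative only: tp ranks assumed contiguous within the global rank order.
world_size, tp_size = 8, 2
data_parallel_size = world_size // tp_size   # plays the role of tp_zero_size

for rank in range(world_size):
    data_index = rank // tp_size             # plays the role of tp_zero_rank
    print(f"rank {rank}: uses data shard {data_index}")
# Ranks in the same tensor-parallel group get the same shard, so they see
# identical inputs while holding different slices of the model.
```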