Skip to content

Commit

Permalink
Add support of Falcon models (7b, 40b, 180b) to DeepSpeed-FastGen (#4790
Browse files Browse the repository at this point in the history
)
  • Loading branch information
arashb authored Dec 12, 2023
1 parent b186816 commit a7900bc
Show file tree
Hide file tree
Showing 9 changed files with 390 additions and 1 deletion.
1 change: 1 addition & 0 deletions blogs/deepspeed-fastgen/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ We currently support the following model architectures in this alpha release of
* [LLaMA](https://huggingface.co/models?other=llama) and [LLaMA-2](https://huggingface.co/models?other=llama-2)
* [Mistral](https://huggingface.co/models?other=mistral)
* [OPT](https://huggingface.co/models?other=opt)
* [Falcon](https://huggingface.co/models?other=falcon)

All current models leverage [HuggingFace](https://github.com/huggingface) APIs in our backend to provide both the model weights and the model's corresponding tokenizer.

Expand Down
3 changes: 3 additions & 0 deletions deepspeed/inference/v2/engine_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
OPTPolicy,
Llama2Policy,
MistralPolicy,
FalconPolicy,
)
from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy
from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata
Expand Down Expand Up @@ -104,6 +105,8 @@ def build_hf_engine(path: str,
assert version.parse(transformers.__version__) >= version.parse("4.34.0"), \
f"Mistral requires transformers >= 4.34.0, you have version {transformers.__version__}"
policy = MistralPolicy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "falcon":
policy = FalconPolicy(model_config, checkpoint_engine=checkpoint_engine)
else:
raise ValueError(f"Unsupported model type {model_config.model_type}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,16 @@ void launch_kv_rotary_kernel(T* kv_cache,
DISPATCH_KV_ROTARY_IMPL(5, 128)
DISPATCH_KV_ROTARY_IMPL(8, 64)
DISPATCH_KV_ROTARY_IMPL(8, 128)
DISPATCH_KV_ROTARY_IMPL(16, 64)
DISPATCH_KV_ROTARY_IMPL(16, 128)
DISPATCH_KV_ROTARY_IMPL(29, 64)
DISPATCH_KV_ROTARY_IMPL(29, 128)
DISPATCH_KV_ROTARY_IMPL(35, 64)
DISPATCH_KV_ROTARY_IMPL(35, 128)
DISPATCH_KV_ROTARY_IMPL(36, 64)
DISPATCH_KV_ROTARY_IMPL(36, 128)
DISPATCH_KV_ROTARY_IMPL(71, 64)
DISPATCH_KV_ROTARY_IMPL(71, 128)
}

#define INSTANTIATE_KV_ROTARY_KERNEL(TYPE) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class BlockedRotaryEmbeddings(DSKernelBase):

supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
supported_head_sizes = [64, 128]
supported_q_ratios = [1, 2, 4, 5, 8]
supported_q_ratios = [1, 2, 4, 5, 8, 16, 29, 35, 36, 71]

def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None:
"""
Expand Down
1 change: 1 addition & 0 deletions deepspeed/inference/v2/model_implementations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
from .llama_v2 import *
from .opt import *
from .mistral import *
from .falcon import *
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .falcon_policy import FalconPolicy
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Create a container object to save model-specific tensors using the policy file above.

from ...model_implementations.common_parameters import *
from ...model_implementations.layer_container_base import LayerContainer
'''
# HF Falcon 7b model looks like this:
FalconForCausalLM(
(transformer): FalconModel(
(word_embeddings): Embedding(65024, 4544)
(h): ModuleList(
(0-31): 32 x FalconDecoderLayer(
(self_attention): FalconAttention(
(maybe_rotary): FalconRotaryEmbedding()
(query_key_value): FalconLinear(in_features=4544, out_features=4672, bias=False)
(dense): FalconLinear(in_features=4544, out_features=4544, bias=False)
(attention_dropout): Dropout(p=0.0, inplace=False)
)
(mlp): FalconMLP(
(dense_h_to_4h): FalconLinear(in_features=4544, out_features=18176, bias=False)
(act): GELU(approximate='none')
(dense_4h_to_h): FalconLinear(in_features=18176, out_features=4544, bias=False)
)
(input_layernorm): LayerNorm((4544,), eps=1e-05, elementwise_affine=True)
)
)
(ln_f): LayerNorm((4544,), eps=1e-05, elementwise_affine=True)
)
(lm_head): Linear(in_features=4544, out_features=65024, bias=False)
)
'''


class FalconTransformerContainer(LayerContainer):
"""
Transformer layer container for the Falcon model.
"""
qkv_w: FusedQKVParameter
attn_out_w: AttentionOutputParameter
mlp_1_w: MLP1Parameter
mlp_2_w: MLP2Parameter
ln_attn_gamma: NormParameter
ln_attn_beta: NormParameter

PARAM_MAPPING = {
"self_attention.query_key_value.weight": "qkv_w.params",
"self_attention.dense.weight": "attn_out_w.params",
"mlp.dense_h_to_4h.weight": "mlp_1_w.params",
"mlp.dense_4h_to_h.weight": "mlp_2_w.params",
"input_layernorm.weight": "ln_attn_gamma.params",
"input_layernorm.bias": "ln_attn_beta.params",
}


class FalconNonTransformerContainer(LayerContainer):
"""
Non-Transformer layer container for the Falcon model.
"""
word_emb: EmbeddingParameter
word_unembed: UnembedParameter
final_norm_gamma: NormParameter
final_norm_beta: NormParameter

PARAM_MAPPING = {
"transformer.word_embeddings.weight": "word_emb.params",
"transformer.ln_f.weight": "final_norm_gamma.params",
"transformer.ln_f.bias": "final_norm_beta.params",
"lm_head.weight": "word_unembed.params",
}


'''
# HF Falcon 40b model looks like this:
FalconForCausalLM(
(transformer): FalconModel(
(word_embeddings): Embedding(65024, 8192)
(h): ModuleList(
(0-59): 60 x FalconDecoderLayer(
(self_attention): FalconAttention(
(maybe_rotary): FalconRotaryEmbedding()
(query_key_value): FalconLinear(in_features=8192, out_features=9216, bias=False)
(dense): FalconLinear(in_features=8192, out_features=8192, bias=False)
(attention_dropout): Dropout(p=0.0, inplace=False)
)
(mlp): FalconMLP(
(dense_h_to_4h): FalconLinear(in_features=8192, out_features=32768, bias=False)
(act): GELU(approximate='none')
(dense_4h_to_h): FalconLinear(in_features=32768, out_features=8192, bias=False)
)
(ln_attn): LayerNorm((8192,), eps=1e-05, elementwise_affine=True)
(ln_mlp): LayerNorm((8192,), eps=1e-05, elementwise_affine=True)
)
)
(ln_f): LayerNorm((8192,), eps=1e-05, elementwise_affine=True)
)
(lm_head): Linear(in_features=8192, out_features=65024, bias=False)
)
'''


class FalconNewArchTransformerContainer(LayerContainer):
"""
Transformer layer container for the Falcon model.
"""
qkv_w: GQAMegatronQKVParameter
attn_out_w: AttentionOutputParameter
mlp_1_w: MLP1Parameter
mlp_2_w: MLP2Parameter
ln_attn_gamma: NormParameter
ln_attn_beta: NormParameter
ln_mlp_gamma: NormParameter
ln_mlp_beta: NormParameter

PARAM_MAPPING = {
"self_attention.query_key_value.weight": "qkv_w.params",
"self_attention.dense.weight": "attn_out_w.params",
"mlp.dense_h_to_4h.weight": "mlp_1_w.params",
"mlp.dense_4h_to_h.weight": "mlp_2_w.params",
"ln_attn.weight": "ln_attn_gamma.params",
"ln_attn.bias": "ln_attn_beta.params",
"ln_mlp.weight": "ln_mlp_gamma.params",
"ln_mlp.bias": "ln_mlp_beta.params",
}
206 changes: 206 additions & 0 deletions deepspeed/inference/v2/model_implementations/falcon/falcon_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Iterable, Optional, Tuple

import torch

import deepspeed.comm as dist

from ...allocator import empty_from
from ...inference_utils import ActivationType, DtypeEnum
from ...model_implementations import *
from ...modules.configs import *
from ...modules.interfaces import *
from ...ragged import RaggedBatchWrapper

from .falcon_containers import FalconNonTransformerContainer, FalconTransformerContainer


class FalconInferenceModel(DSTransformerModelBase):
"""
Inference model implementation for ragged batching for Llama-2 models.
"""

_non_transformer: Optional[FalconNonTransformerContainer]
"""
Embed + unembed container. Specializing the type annotation.
"""

_transformer: Optional[Iterable[FalconTransformerContainer]]
"""
Per-layer transformer container. Specializing the type annotation.
"""
"""
Properties inherited from `DSInferenceModelBase`
"""

@property
def max_sequence_length(self) -> int:
return self._config.max_seq_length

"""
Properties inherited from `DSTransformerModelBase`
"""

@property
def num_layers(self) -> int:
return self._config.num_hidden_layers

@property
def model_dim(self) -> int:
return self._config.hidden_size

@property
def vocab_size(self) -> int:
return self._config.vocab_size

@property
def head_size(self) -> int:
return self.model_dim // self.n_heads

@property
def n_heads(self) -> int:
return self._config.num_attention_heads

@property
def intermediate_dim(self) -> int:
return 4 * self._config.hidden_size

@property
def n_heads_kv(self) -> int:
return self._config.num_kv_heads if (self._config.new_decoder_architecture
or not self._config.multi_query) else 1

@property
def activation_dtype(self) -> DtypeEnum:
if self._config.torch_dtype == torch.float16:
return DtypeEnum.fp16
elif self._config.torch_dtype == torch.bfloat16:
return DtypeEnum.bf16
else:
raise NotImplementedError("Only fp16 and bf16 are supported")

@property
def mlp_activation_fn(self) -> ActivationType:
return ActivationType.GELU

@property
def norm_type(self) -> NormTypeEnum:
return NormTypeEnum.LayerNorm

@property
def positional_embedding_type(self) -> PositionalEmbeddingType:
return PositionalEmbeddingType.rotate_half

"""
Forward implementations
"""

def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor:
"""
Performs the embedding lookup prior to running the transformer of the model.
Arguments:
ragged_batch (RaggedBatchWrapper): The batch to embed.
Returns:
torch.Tensor: The embedded batch.
"""
embed = self.embed(ragged_batch, self._non_transformer.word_emb)

if embed.shape[-1] != self.model_dim:
raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}")

return embed

def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor,
ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Executes one (slightly offset) layer of the transformer. This implementation does a peak-ahead
optimization to fuse the layer norm of the next layer into the current layer.
Arguments:
layer_idx (int): The index of the layer to execute.
residual (torch.Tensor): The residual tensor from the previous layer.
hidden_states (torch.Tensor): The hidden states from the previous layer. This is the
hidden states after pre normalization.
ragged_batch_info (RaggedBatchWrapper): The batch metadata.
"""
assert self.config.parallel_attn, "Only parallel attention implementation is supported"

cur_params = self._transformer[layer_idx]
kv_cache = self.state_manager.get_cache(layer_idx)

attn_ln_out = hidden_states
attn_hidden_state = self.qkv(attn_ln_out, cur_params.qkv_w, b=None)
attn_hidden_state = self.attn(attn_hidden_state, kv_cache, ragged_batch_info)
attention_output = self.attn_out(attn_hidden_state, cur_params.attn_out_w, b=None)

if self.config.new_decoder_architecture:
residual, mlp_ln_out = self.norm(residual,
None,
gamma=cur_params.ln_mlp_gamma,
beta=cur_params.ln_mlp_beta)
else:
mlp_ln_out = hidden_states

mlp_hidden_state = self.mlp_1(mlp_ln_out, cur_params.mlp_1_w, b=None)
mlp_output = self.mlp_2(mlp_hidden_state, cur_params.mlp_2_w, b=None)

mlp_output.add_(attention_output)

if self.tp_size > 1:
dist.all_reduce(mlp_output, group=self._base_mp_group)

if layer_idx != self.num_layers - 1:
next_params = self._transformer[layer_idx + 1]
residual, mlp_output = self.norm(residual,
mlp_output,
next_params.ln_attn_gamma,
beta=next_params.ln_attn_beta)
else:
# On last layer, we just need to perform the residual add. Adding into the residual
# here is safe.
residual.add_(mlp_output)

return residual, mlp_output

def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor:
"""
Performs unembedding of the hidden states to logits. This will only sample the final
token of each sequence.
"""
logits = self.unembed(hidden_states,
self._non_transformer.word_unembed,
ragged_batch_info,
gamma=self._non_transformer.final_norm_gamma,
beta=self._non_transformer.final_norm_beta)

if self.tp_size > 1:
comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1]))
full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size))

dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group)

full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size))

return full_logits
else:
return logits

def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor:
residual = self._forward_embed(wrapped_batch)

residual, hidden_states = self.norm(residual,
None,
gamma=self._transformer[0].ln_attn_gamma,
beta=self._transformer[0].ln_attn_beta)

for layer_idx in range(self.num_layers):
residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states,
wrapped_batch)

return self._forward_unembed(residual, wrapped_batch)
Loading

0 comments on commit a7900bc

Please sign in to comment.