I have studied the TensorRT-LLM code, and honestly, its level of generalization makes it complex and quite confusing to me.
So I wanted to know which kind of attention I should use to deploy a new model that uses the following code. It looks like a slightly modified variant of HuggingFace's attention, but I'm not sure about that.
My observation is this: if I use BERT attention (BertAttention) in TensorRT-LLM and disable RoPE in the PyTorch code, the output values of PyTorch and TensorRT-LLM are identical. However, I do not know which attention in TensorRT-LLM corresponds to BertAttention with rotary embedding added. Can anybody please explain that to me?
class AttnProcessor:
    def __init__(self):
        pass

    def __call__(
        self,
        attn: Attention,
        x: float["b n d"],  # noised input x  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding
    ) -> torch.FloatTensor:
        batch_size = x.shape[0]

        # `sample` projections.
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)

        # apply rotary position embedding
        if rope is not None:
            freqs, xpos_scale = rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)

        # attention
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # mask. e.g. inference got a batch with different target durations, mask out the padding
        if mask is not None:
            attn_mask = mask
            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
        else:
            attn_mask = None

        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        x = x.to(query.dtype)

        # linear proj
        x = attn.to_out[0](x)
        # dropout
        x = attn.to_out[1](x)

        if mask is not None:
            mask = mask.unsqueeze(-1)
            x = x.masked_fill(~mask, 0.0)

        return x
This is the code for RoPE:
class RotaryEmbedding(Module):
    def __init__(
        self,
        dim,
        use_xpos=False,
        scale_base=512,
        interpolation_factor=1.,
        base=10000,
        base_rescale_factor=1.
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        base *= base_rescale_factor ** (dim / (dim - 2))

        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if not use_xpos:
            self.register_buffer('scale', None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = scale_base
        self.register_buffer('scale', scale)

    def forward_from_seq_len(self, seq_len):
        device = self.inv_freq.device
        t = torch.arange(seq_len, device=device)
        return self.forward(t)

    @autocast('cuda', enabled=False)
    def forward(self, t):
        max_pos = t.max() + 1

        if t.ndim == 1:
            t = rearrange(t, 'n -> 1 n')

        freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
        freqs = torch.stack((freqs, freqs), dim=-1)
        freqs = rearrange(freqs, '... d r -> ... (d r)')

        if not exists(self.scale):
            return freqs, 1.

        power = (t - (max_pos // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, 'n -> n 1')
        scale = torch.stack((scale, scale), dim=-1)
        scale = rearrange(scale, '... d r -> ... (d r)')

        return freqs, scale


def rotate_half(x):
    x = rearrange(x, '... (d r) -> ... d r', r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, '... d r -> ... (d r)')


@autocast('cuda', enabled=False)
def apply_rotary_pos_emb(t, freqs, scale=1):
    rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype

    freqs = freqs[:, -seq_len:, :]
    scale = scale[:, -seq_len:, :] if isinstance(scale, torch.Tensor) else scale

    if t.ndim == 4 and freqs.ndim == 3:
        freqs = rearrange(freqs, 'b n d -> b 1 n d')

    # partial rotary embeddings, Wang et al. GPT-J
    t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
    out = torch.cat((t, t_unrotated), dim=-1)

    return out.type(orig_dtype)
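For reference, the `rotate_half` above pairs adjacent channels (`r=2`), which is the interleaved layout usually associated with GPT-J, as opposed to the half-split layout usually associated with GPT-NeoX. The NumPy sketch below only illustrates the difference between the two layouts. The function names are hypothetical, and whether `PositionEmbeddingType.rope_gptj` matches the interleaved variant exactly is an assumption that should be checked against the TensorRT-LLM source.

```python
import numpy as np


def rotate_half_interleaved(x):
    """Rotate adjacent channel pairs: (x0, x1, x2, x3) -> (-x1, x0, -x3, x2).

    Mirrors the einops-based rotate_half in the RoPE code above.
    """
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    x1, x2 = pairs[..., 0], pairs[..., 1]
    return np.stack((-x2, x1), axis=-1).reshape(x.shape)


def rotate_half_split(x):
    """Rotate first half against second half: (x0, x1, x2, x3) -> (-x2, -x3, x0, x1).

    Shown only for contrast; this is NOT what the code above does.
    """
    half = x.shape[-1] // 2
    return np.concatenate((-x[..., half:], x[..., :half]), axis=-1)


x = np.arange(4, dtype=np.float32)
interleaved = rotate_half_interleaved(x)
split = rotate_half_split(x)
print(interleaved, split)
```

The two layouts give different results for the same input, so picking the wrong position embedding type would produce mismatched outputs even with identical weights.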
I tried the Attention class with the following parameters:
self.attn = Attention(
    local_layer_idx=0,  # For testing
    hidden_size=hidden_size,
    num_attention_heads=num_heads,
    rotary_embedding_base=10000.0,  # default
    rotary_embedding_percentage=1.0,  # portion of the channels using rope
    attention_mask_type=AttentionMaskType.causal,
    position_embedding_type=PositionEmbeddingType.rope_gptj,
    tp_group=None,
    tp_size=mapping.tp_size,
    tp_rank=mapping.tp_rank,
    dtype=dtype,
    bias=False,
    # other parameters
)
But it fails with:
AssertionError: rotary_inv_freq and embed_positions_for_gpt_attention must be provided.
Of these two, I only have a weight for rotary_inv_freq. I am not sure how to supply the model with embed_positions_for_gpt_attention.
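For context, a table of per-position cos/sin values can be derived from the same inverse frequencies that `RotaryEmbedding` computes. The NumPy sketch below shows that derivation under stated assumptions: `build_rope_tables` is a hypothetical helper, it takes `base_rescale_factor` and `interpolation_factor` as 1, and the shape, flattening, and dtype that TensorRT-LLM actually expects for `embed_positions_for_gpt_attention` are not confirmed here and should be checked against the installed version.

```python
import numpy as np


def build_rope_tables(max_positions, rot_dim, base=10000.0):
    """Return (inv_freq, cos_sin) for positions 0 .. max_positions - 1.

    inv_freq has shape (rot_dim // 2,), using the same formula as
    RotaryEmbedding.__init__ above.
    cos_sin has shape (max_positions, rot_dim // 2, 2), where
    [..., 0] holds cos(angle) and [..., 1] holds sin(angle).
    """
    inv_freq = 1.0 / (base ** (np.arange(0, rot_dim, 2, dtype=np.float32) / rot_dim))
    positions = np.arange(max_positions, dtype=np.float32)
    angles = np.einsum("i,j->ij", positions, inv_freq)  # (max_positions, rot_dim // 2)
    cos_sin = np.stack((np.cos(angles), np.sin(angles)), axis=-1)
    return inv_freq.astype(np.float32), cos_sin.astype(np.float32)


inv_freq, cos_sin = build_rope_tables(max_positions=8, rot_dim=4)
# Assumption: a single flattened row; the layout the plugin needs may differ.
flat = cos_sin.reshape(1, -1)
print(inv_freq.shape, cos_sin.shape, flat.shape)
```

Since both tensors come from the same `base` and rotary dimension, having the weight for `rotary_inv_freq` should in principle be enough to compute the second one offline rather than load it from a checkpoint.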
Thanks in advance