
Commit

temp2
Giuseppe5 committed Jan 15, 2025
1 parent de1ffad commit cca6a90
Showing 2 changed files with 10 additions and 10 deletions.
3 changes: 1 addition & 2 deletions src/brevitas/graph/equalize.py
@@ -10,7 +10,6 @@
 from typing import Callable, Dict, List, Optional, Set, Tuple, Union
 import warnings
 
-from brevitas.nn import ScaledDotProductAttention
 import packaging
 import packaging.version
 import torch
@@ -31,6 +30,7 @@
 from brevitas.graph.hadamard import matmul_hadU_cuda
 from brevitas.graph.utils import get_module
 from brevitas.graph.utils import get_node
+from brevitas.nn import ScaledDotProductAttention
 from brevitas.nn.equalized_layer import EqualizedModule
 from brevitas.nn.equalized_layer import functional_rotate_input
 from brevitas.nn.equalized_layer import INPUT_NAMES
@@ -1509,7 +1509,6 @@ def find_sink(node):
             if isinstance(m, ScaledDotProductAttention):
                 m.pre_process_q = functional_rotate_input
                 m.pre_process_k = functional_rotate_input
-                # m.pre_process_v = partial(functional_rotate_input, transpose=True)
         return regions
 
     def apply(self,
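
The hunk above sits in the rotation-equalization region search: every ScaledDotProductAttention found in the traced model gets functional_rotate_input attached as its query/key pre-processing callable, while the commented-out value variant is dropped (value handling moves into quant_sdpa.py below). A minimal standalone sketch of that attachment pattern follows; the helper name attach_rotation_hooks and the generic model argument are illustrative assumptions, not brevitas API.

import torch.nn as nn

from brevitas.nn import ScaledDotProductAttention
from brevitas.nn.equalized_layer import functional_rotate_input


def attach_rotation_hooks(model: nn.Module) -> nn.Module:
    """Hypothetical helper mirroring the pattern in the hunk above."""
    for m in model.modules():
        if isinstance(m, ScaledDotProductAttention):
            # Rotate queries and keys before the fused SDPA call; the value
            # path is left untouched at this point in the pass.
            m.pre_process_q = functional_rotate_input
            m.pre_process_k = functional_rotate_input
    return model
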
17 changes: 9 additions & 8 deletions src/brevitas/nn/quant_sdpa.py
@@ -43,14 +43,14 @@
 import math
 from typing import Optional, Tuple, Union
 
-from brevitas.core.function_wrapper.misc import Identity
-from brevitas.function import identity
 import torch
 from torch import Tensor
 from torch.nn import Module
 from torch.nn import Parameter
 import torch.nn.functional as F
 
+from brevitas.core.function_wrapper.misc import Identity
+from brevitas.function import identity
 from brevitas.quant.scaled_int import Int8ActPerTensorFloat
 from brevitas.quant.scaled_int import Uint8ActPerTensorFloat
 
@@ -59,7 +59,7 @@
 
 class ScaledDotProductAttention(Module):
 
-    def __init__(self, pre_process_q = identity, pre_process_k = identity, pre_process_v = identity):
+    def __init__(self, pre_process_q=identity, pre_process_k=identity, pre_process_v=identity):
         super().__init__()
         self.pre_process_q = pre_process_q
         self.pre_process_k = pre_process_k
@@ -113,7 +113,7 @@ def forward(
         return F.scaled_dot_product_attention(
             query=self.pre_process_q(query),
             key=self.pre_process_k(key),
-            value=value,#self.pre_process_v(value),
+            value=self.pre_process_v(value),
             attn_mask=attn_mask,
             dropout_p=dropout_p,
             is_causal=is_causal,
@@ -124,7 +124,9 @@ class QuantScaledDotProductAttention(Module):
 
     def __init__(
             self,
-            pre_process_q = Identity(), pre_process_k = Identity(), pre_process_v = Identity(),
+            pre_process_q=identity,
+            pre_process_k=identity,
+            pre_process_v=identity,
             softmax_input_quant=None,
             attn_output_weights_quant=Uint8ActPerTensorFloat,
             q_scaled_quant=Int8ActPerTensorFloat,
@@ -211,15 +213,14 @@ def forward(
             else:
                 attn_bias += attn_mask
         query, key, value = self.pre_process_q(query), self.pre_process_k(key), self.pre_process_v(value)
-        q_scaled = query * scale_factor#self.q_scaled_quant(query * scale_factor)
+        q_scaled = self.q_scaled_quant(query * scale_factor)
         k_transpose = self.k_transposed_quant(key.transpose(-2, -1))
         attn_weight = q_scaled @ k_transpose
         attn_weight += attn_bias
         attn_weight = self.softmax_input_quant(attn_weight)
         attn_weight = torch.softmax(attn_weight, dim=-1)
         attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
-        # attn_weight = self.pre_process_q(attn_weight)
-        # attn_weight = self.attn_output_weights_quant(attn_weight)
+        attn_weight = self.attn_output_weights_quant(attn_weight)
         attn_output = attn_weight @ self.v_quant(value)
         attn_output = self.sdpa_output_quant(attn_output)
         return attn_output
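
Taken together, the quant_sdpa.py changes route the value tensor through pre_process_v in the float wrapper, switch the QuantScaledDotProductAttention pre-processing defaults from Identity() instances to the plain identity function, and re-enable the q_scaled and attention-weight quantizers in the unfused path. Below is a hedged usage sketch, assuming both modules accept (query, key, value) as their leading forward arguments, mirroring torch.nn.functional.scaled_dot_product_attention; the tensor sizes are illustrative.

import torch

from brevitas.nn.quant_sdpa import QuantScaledDotProductAttention, ScaledDotProductAttention

# Illustrative (batch, heads, sequence, head_dim) tensors.
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

# Float wrapper: after this commit the value tensor also passes through
# pre_process_v (identity by default) before the fused SDPA call.
sdpa = ScaledDotProductAttention()
out_float = sdpa(q, k, v)

# Quantized variant: q_scaled_quant and attn_output_weights_quant are applied
# inside the unfused attention computation instead of being commented out.
quant_sdpa = QuantScaledDotProductAttention()
out_quant = quant_sdpa(q, k, v)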
