Skip to content

Commit

Permalink
Feat (ptq): linear learned round for LLMs (#1064)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Giuseppe Franco <[email protected]>
  • Loading branch information
pablomlago and Giuseppe5 authored Dec 3, 2024
1 parent 1ca8d7a commit 8e0c399
Show file tree
Hide file tree
Showing 25 changed files with 2,552 additions and 267 deletions.
1 change: 1 addition & 0 deletions requirements/requirements-vision.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
accelerate
torchvision
tqdm
55 changes: 41 additions & 14 deletions src/brevitas/core/function_wrapper/learned_round.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@

import brevitas
from brevitas import config
from brevitas.core.function_wrapper.ops_ste import TensorClampSte
from brevitas.core.utils import SliceTensor
from brevitas.function.ops_ste import floor_ste
from brevitas.function.ops_ste import round_ste


class LearnedRoundHardSigmoid(brevitas.jit.ScriptModule):
Expand All @@ -28,12 +30,17 @@ def __init__(self, learned_round_zeta: float = 1.1, learned_round_gamma: float =
self.learned_round_gamma = learned_round_gamma

@brevitas.jit.script_method
def forward(self, x: torch.Tensor) -> torch.Tensor:
p = torch.sigmoid(x)
def forward(self, p: torch.Tensor) -> torch.Tensor:
p = torch.sigmoid(p)
p = p * (self.learned_round_zeta - self.learned_round_gamma) + self.learned_round_gamma
p = torch.clamp(p, 0.0, 1.0)
if not self.training:
return p > 0.5
return p

def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
return floor_ste(x) + p


class LearnedRoundSigmoid(brevitas.jit.ScriptModule):
"""
Expand All @@ -47,10 +54,37 @@ def __init__(self, learned_round_temperature: float = 1.) -> None:
self.learned_round_temperature = learned_round_temperature

@brevitas.jit.script_method
def forward(self, x: torch.Tensor) -> torch.Tensor:
p = torch.sigmoid(x / self.learned_round_temperature)
def forward(self, p: torch.Tensor) -> torch.Tensor:
if not self.training:
return p > 0
p = torch.sigmoid(p / self.learned_round_temperature)
return p

@brevitas.jit.script_method
def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
return floor_ste(x) + p


class LearnedRoundIdentity(brevitas.jit.ScriptModule):
    """
    LearnedRound parameter implementation without a squashing function.

    The learned value is passed through ``TensorClampSte`` with symmetric
    bounds and is added to the input before ``round_ste`` is applied.
    Adapted from https://arxiv.org/abs/2309.05516
    """

    def __init__(self) -> None:
        super(LearnedRoundIdentity, self).__init__()
        self.tensor_clamp = TensorClampSte()
        # Magnitude shared by the lower (-0.5) and upper (+0.5) clamp bounds.
        self.upper_lower_bound = brevitas.jit.Attribute(0.5, float)

    def forward(self, p: torch.Tensor) -> torch.Tensor:
        """Pass ``p`` through the clamp module with bounds ``-/+ upper_lower_bound``."""
        # The bounds are rebuilt on every call and matched to ``p`` via ``type_as``.
        lower = torch.tensor(-self.upper_lower_bound).type_as(p)
        upper = torch.tensor(self.upper_lower_bound).type_as(p)
        return self.tensor_clamp(p, min_val=lower, max_val=upper)

    def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
        """Apply ``round_ste`` to ``x`` offset by the learned value ``p``."""
        shifted = x + p
        return round_ste(shifted)


class LearnedRoundSte(brevitas.jit.ScriptModule):
"""
Expand All @@ -72,17 +106,10 @@ def __init__(

@brevitas.jit.script_method
def forward(self, x: torch.Tensor) -> torch.Tensor:
p = self.p_forward()
p = self.learned_round_impl(self.value)
p = self.tensor_slicer(p)
return floor_ste(x) + p.to(x.dtype)

def p_forward(self):
# In eval mode, performs true quantization, otherwise "soft" quantization
if not self.training:
p = (self.value > 0)
else:
p = self.learned_round_impl(self.value)
return p
p = (p.to(x.dtype)).view_as(x)
return self.learned_round_impl.round_forward(x, p)

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/graph/gpfq.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
from brevitas.graph.calibrate import restore_return_quant_tensor
from brevitas.graph.gpxq import GPxQ
from brevitas.graph.gpxq import gpxq_mode
from brevitas.graph.gpxq import StopFwdException
from brevitas.graph.gpxq import SUPPORTED_CONV_OP
from brevitas.graph.gpxq import SUPPORTED_TCONV_OP
import brevitas.nn as qnn
from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.utils.torch_utils import StopFwdException


class GPFQ(GPxQ):
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/graph/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from brevitas import torch_version
from brevitas.graph.gpxq import GPxQ
from brevitas.graph.gpxq import gpxq_mode
from brevitas.graph.gpxq import StopFwdException
from brevitas.graph.gpxq import SUPPORTED_CONV_OP
import brevitas.nn as qnn
from brevitas.utils.torch_utils import StopFwdException


class GPTQ(GPxQ):
Expand Down
4 changes: 0 additions & 4 deletions src/brevitas/graph/gpxq.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@
SUPPORTED_CONV_OP = (qnn.QuantConv1d, qnn.QuantConv2d, qnn.QuantConv3d, *SUPPORTED_TCONV_OP)


class StopFwdException(Exception):
    # NOTE(review): presumably raised to cut a forward pass short; confirm against the call sites.
    # The neighbouring gpfq.py/gptq.py hunks import this name from both brevitas.graph.gpxq and
    # brevitas.utils.torch_utils.
    pass


@dataclass
class LayerHandler:
layer_names: Set = field(default_factory=set)
Expand Down
1 change: 1 addition & 0 deletions src/brevitas/inject/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class LearnedRoundImplType(AutoName):
"""
HARD_SIGMOID = auto()
SIGMOID = auto()
IDENTITY = auto()


class ScalingImplType(AutoName):
Expand Down
Loading

0 comments on commit 8e0c399

Please sign in to comment.