Skip to content

Commit

Permalink
Feat (brevitas_examples): Po2 per channel float OCP weight quantization
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 3, 2024
1 parent 8fff8ea commit 1ab9a0e
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/brevitas/quant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,9 @@ class PerChannelPoTScaling8bit(ExtendedInjector):
"""
"""
scaling_per_output_type = ScalingPerOutputType.CHANNEL
restrict_scaling_type = RestrictValueType.FP
restrict_scaling_type = RestrictValueType.POWER_OF_TWO
bit_width = 8
restrict_value_float_to_int_impl = CeilSte


class PerTensorPoTScaling8bit(ExtendedInjector):
Expand Down
3 changes: 3 additions & 0 deletions src/brevitas_examples/common/generative/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear
from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat
from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFixedPoint
from brevitas_examples.common.generative.quantizers import Fp8e4m3OCPWeightPerChannelFixedPointMSE
from brevitas_examples.common.generative.quantizers import Fp8e4m3WeightSymmetricGroupQuant
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerGroupFloat
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFixedPoint
Expand Down Expand Up @@ -131,6 +132,8 @@
'per_group': {
'sym': MXFloat8e4m3Weight}},
'mse': {
'per_channel': {
'sym': Fp8e4m3OCPWeightPerChannelFixedPointMSE},
'per_group': {
'sym': MXFloat8e4m3WeightMSE}}}},
'float_fnuz': {
Expand Down
9 changes: 9 additions & 0 deletions src/brevitas_examples/common/generative/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@
from brevitas.proxy.groupwise_int_runtime_quant import GroupwiseActQuantProxyFromInjector
from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector
from brevitas.quant.base import HQOWeightZeroPoint
from brevitas.quant.base import MSESymmetricScale
from brevitas.quant.base import PerChannelPoTScaling8bit
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Int8WeightPerChannelFloat
from brevitas.quant.scaled_int import Int8WeightPerChannelFloatHQO
Expand Down Expand Up @@ -141,3 +144,9 @@ class FP8e4m3OCPDynamicActPerRowFixedPoint(Fp8e4m3ActPerTensorFloat):
restrict_scaling_type = RestrictValueType.POWER_OF_TWO
restrict_value_float_to_int_impl = FloorSte
proxy_class = ActFloatQuantProxyFromInjector


class Fp8e4m3OCPWeightPerChannelFixedPointMSE(
        MSESymmetricScale,
        PerChannelPoTScaling8bit,
        Fp8e4m3OCPWeightPerChannelFloat,
):
    """FP8 e4m3 OCP per-channel weight quantizer assembled purely from mixins.

    The class adds no attributes of its own; all configuration comes from the
    three bases. They are listed from most to least specific, so under Python's
    left-to-right resolution an attribute set by an earlier base takes
    precedence over the same attribute on a later one:

    * ``MSESymmetricScale`` -- presumably selects an MSE-based search for a
      symmetric scale; confirm against its definition in
      ``brevitas.quant.base``.
    * ``PerChannelPoTScaling8bit`` -- per-channel scaling settings; judging by
      the name, the scale is restricted to powers of two (TODO confirm in
      ``brevitas.quant.base``).
    * ``Fp8e4m3OCPWeightPerChannelFloat`` -- the underlying FP8 e4m3 OCP
      per-channel weight quantizer being specialised.
    """

0 comments on commit 1ab9a0e

Please sign in to comment.