
Commit

fix
Giuseppe5 committed Jan 23, 2025
1 parent 8db7cf6 commit 4c896d8
Showing 4 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/brevitas/graph/base.py
@@ -7,7 +7,6 @@
from inspect import getcallargs
from typing import Any, Callable, Dict, Optional, Type, Union

- from brevitas.utils.torch_utils import WeightBiasWrapper
import torch
from torch import Tensor
from torch.nn import Module
@@ -21,6 +20,7 @@
from brevitas.graph.utils import *
from brevitas.utils.python_utils import islambda
from brevitas.utils.rotation_utils import RotationWeightParametrization
+ from brevitas.utils.torch_utils import WeightBiasWrapper

__all__ = [
'Transform',
14 changes: 8 additions & 6 deletions src/brevitas/graph/equalize.py
@@ -43,7 +43,8 @@
from brevitas.utils.python_utils import recurse_getattr
from brevitas.utils.rotation_utils import RotationWeightParametrization
from brevitas.utils.rotation_utils import ScaleWeightParametrization
- from brevitas.utils.torch_utils import KwargsForwardHook, WeightBiasWrapper
+ from brevitas.utils.torch_utils import KwargsForwardHook
+ from brevitas.utils.torch_utils import WeightBiasWrapper

# External optional dependency
try:
@@ -647,7 +648,7 @@ def _no_equalize():
module=module,
tensor_name="bias",
transform_module=ScaleWeightParametrization(
- scaling_factor=partial_scale.view_as(module.bias),is_sink=False)))
+ scaling_factor=partial_scale.view_as(module.bias), is_sink=False)))
src_broadcast_size = [1] * module.weight.ndim
src_broadcast_size[axis] = module.weight.size(axis)
if fuse_scaling:
@@ -656,8 +657,8 @@
module=module,
tensor_name="weight",
transform_module=ScaleWeightParametrization(
- scaling_factor=torch.reshape(partial_scale,
- src_broadcast_size),is_sink=False)))
+ scaling_factor=torch.reshape(partial_scale, src_broadcast_size),
+ is_sink=False)))
for name, (module, axis) in sink_axes.items():
module_device = module.weight.device
sink_broadcast_size = [1] * module.weight.ndim
@@ -677,7 +678,8 @@
module=module,
tensor_name="weight",
transform_module=ScaleWeightParametrization(
- scaling_factor=torch.reshape(partial_scaling, sink_broadcast_size), is_sink=True)))
+ scaling_factor=torch.reshape(partial_scaling, sink_broadcast_size),
+ is_sink=True)))

# If a module has `offload_params` attribute, we must offload the weights following that method
for name in (region.srcs_names + region.sinks_names):
@@ -722,7 +724,7 @@ def _equalize(
scale_factor_max = torch.max(scale_factor_max, scale_factor_region_max)
else:
scale_factor_max = scale_factor_region_max

for r in rewriters:
r.apply(model)
if threshold is not None and scale_factor_max < threshold:
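The reshapes above build a broadcast shape that is 1 on every dimension except the equalization axis, so the per-channel scale multiplies the weight along that axis only. A minimal standalone sketch of the same pattern; the weight shape, axis, and scale values below are hypothetical and not taken from the diff:

import torch

# Hypothetical weight of shape (out_channels, in_channels) with a
# per-output-channel scale applied along axis 0.
weight = torch.randn(4, 8)
partial_scale = torch.rand(4) + 0.5
axis = 0

# Same pattern as in equalize.py: all-ones broadcast shape except on the scaled axis.
src_broadcast_size = [1] * weight.ndim
src_broadcast_size[axis] = weight.size(axis)

scaled = weight * torch.reshape(partial_scale, src_broadcast_size)
assert scaled.shape == weight.shape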
8 changes: 2 additions & 6 deletions src/brevitas/utils/rotation_utils.py
@@ -57,16 +57,12 @@ class ScaleWeightParametrization(torch.nn.Module):
the tensor
"""

- def __init__(
- self,
- scaling_factor: Tensor,
- is_sink : bool
- ) -> None:
+ def __init__(self, scaling_factor: Tensor, is_sink: bool) -> None:
super().__init__()
self.scaling_factor = scaling_factor
self.is_sink = is_sink

def forward(self, tensor: torch.Tensor) -> torch.Tensor:
# Reciprocal is done on the fly as to preserve the tie between scale and its reciprocal
scale = torch.reciprocal(self.scaling_factor) if self.is_sink else self.scaling_factor
- return tensor * scale
\ No newline at end of file
+ return tensor * scale
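The forward comment above captures the invariant: a single scaling_factor tensor is shared between a source (is_sink=False) and a sink (is_sink=True), and the sink takes the reciprocal on the fly so the two stay tied if the scale is later updated. Below is a minimal sketch of that tie using torch.nn.utils.parametrize.register_parametrization; the direct registration call and the layer sizes are illustrative assumptions, since the equalize.py changes attach the module through Brevitas' rewriter machinery instead:

import torch
from torch.nn.utils import parametrize

from brevitas.utils.rotation_utils import ScaleWeightParametrization

# Hypothetical source/sink pair; bias is disabled on the source so the
# cancellation is exact without also scaling the bias.
src = torch.nn.Linear(16, 8, bias=False)
sink = torch.nn.Linear(8, 4)
scale = torch.rand(8) + 0.5

x = torch.randn(2, 16)
out_ref = sink(src(x))

# Source weight rows (output channels) are multiplied by the scale...
parametrize.register_parametrization(
    src, "weight", ScaleWeightParametrization(scaling_factor=scale.view(8, 1), is_sink=False))
# ...and sink weight columns (input channels) are divided by the same scale.
parametrize.register_parametrization(
    sink, "weight", ScaleWeightParametrization(scaling_factor=scale.view(1, 8), is_sink=True))

out_scaled = sink(src(x))
assert torch.allclose(out_ref, out_scaled, atol=1e-5)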
2 changes: 2 additions & 0 deletions src/brevitas/utils/torch_utils.py
@@ -17,12 +17,14 @@ class StopFwdException(Exception):
"""Used to throw and catch an exception to stop traversing the graph."""
pass

+
# Required for being hashable
@dataclass(eq=True, frozen=True)
class WeightBiasWrapper:
weight: torch.Tensor = None
bias: torch.Tensor = None

+
class TupleSequential(Sequential):

def output(self, mod, input):
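The two added blank lines appear to be PEP 8 spacing between top-level definitions. The surrounding `# Required for being hashable` comment refers to `@dataclass(eq=True, frozen=True)` generating `__hash__`, which is what lets WeightBiasWrapper instances serve as dictionary keys or set members. A minimal sketch of that property; the cache dict is illustrative, not part of the diff:

import torch

from brevitas.utils.torch_utils import WeightBiasWrapper

# frozen=True makes the dataclass immutable and, together with eq=True, gives it
# a generated __hash__, so instances are valid dict keys / set members.
wrapper = WeightBiasWrapper(weight=torch.randn(4, 4), bias=torch.randn(4))
cache = {wrapper: "already processed"}
assert wrapper in cache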
