From 52cfffd892ca60688d396dd52b3a3104c48b820e Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago <44771380+pablomlago@users.noreply.github.com> Date: Tue, 14 Jan 2025 14:16:09 +0000 Subject: [PATCH] Feat (equalize): enable parametrized rotations (#1148) --- src/brevitas/graph/base.py | 129 +++++++++ src/brevitas/graph/equalize.py | 166 +++++++++--- src/brevitas/graph/hadamard.py | 3 +- src/brevitas/utils/rotation_utils.py | 49 ++++ src/brevitas_examples/llm/main.py | 7 +- tests/brevitas/graph/equalization_fixtures.py | 60 +++++ tests/brevitas/graph/test_equalization.py | 252 ++++++++++++++++++ tests/brevitas/graph/test_transforms.py | 76 ++++++ tests/brevitas_examples/test_llm.py | 85 ++++++ 9 files changed, 781 insertions(+), 46 deletions(-) create mode 100644 src/brevitas/utils/rotation_utils.py diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index def3f7070..d1631f34e 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -5,9 +5,13 @@ from abc import abstractmethod import inspect from inspect import getcallargs +from typing import Any, Callable, Dict, Optional, Type, Union import torch +from torch import Tensor from torch.nn import Module +from torch.nn import Parameter +import torch.nn.utils.parametrize as parametrize from torch.overrides import get_testing_overrides from brevitas.fx import GraphModule @@ -15,6 +19,7 @@ from brevitas.fx import Node from brevitas.graph.utils import * from brevitas.utils.python_utils import islambda +from brevitas.utils.rotation_utils import RotationWeightParametrization __all__ = [ 'Transform', @@ -174,6 +179,130 @@ def apply(self, graph_model: GraphModule) -> GraphModule: return graph_model +class ModuleInstanceRegisterParametrization(Transform): + r"""Transform to register a parametrization to a given parameter of a + module. 
+ + Args: + module (nn.Module): module on which to register the + parametrization + tensor_name: (str): name of the :class:`torch.nn.Parameter` of + module which is to be parametrized + transform_module (nn.Module): the parametrization to + register + """ + + def __init__(self, module: Module, tensor_name: str, transform_module: Module) -> None: + self.module = module + self.tensor_name = tensor_name + self.transform_module = transform_module + + # TODO: Unify inferfaces with ModuleInstanceToModuleInstance for + # compatibility with fix_rewriter + @property + def old_module_instance(self): + return self.module + + @old_module_instance.setter + def old_module_instance(self, old_module_instance): + self.module = old_module_instance + + def apply(self, model: GraphModule) -> GraphModule: + for module in model.modules(): + if module is self.module: + # register the parametrization to module + parametrize.register_parametrization( + module, self.tensor_name, self.transform_module) + break + return model + + +class ModuleInstanceTransformTensor(Transform): + r"""Transform to transform in-place a given parameter of a module + + Args: + module (nn.Module): parent module of the parameter to be transformed + tensor_name (str): name of the :class:`torch.nn.Parameter` of + module which is to be transformed + transform_module (nn.Module): module defining the transformation to apply + to the tensor + """ + + def __init__( + self, + module: Module, + tensor_name: str, + transform_module: Module, + ): + self.module = module + self.tensor_name = tensor_name + self.transform_module = transform_module + + # TODO: Unify inferfaces with ModuleInstanceToModuleInstance for + # compatibility with fix_rewriter + @property + def old_module_instance(self): + return self.module + + @old_module_instance.setter + def old_module_instance(self, old_module_instance): + self.module = old_module_instance + + def apply(self, model: GraphModule) -> GraphModule: + for module in model.modules(): + if 
module is self.module: + # This check is needed to apply the change in the parameters + # when the model is offloaded + # TODO: Move outside the apply function + if hasattr(module, 'allocate_params'): + module.allocate_params(module) + tensor = getattr(module, self.tensor_name).data + tensor = self.transform_module(tensor) + # Modify the weights in-place + setattr(module, self.tensor_name, torch.nn.Parameter(tensor)) + + if hasattr(module, 'offload_params'): + module.offload_params(module) + break + return model + + +class ModuleInstanceWrapModule(Transform): + r"""Transform to replace a module by a wrapper module which has the original + one as a submodule + + Args: + old_module_instance (nn.Module): module to be wrapped + wrapper_class (type): class of the wrapper for old_module_instance + module_attribute (str): name of the parameter to pass the original + module to the constructor of wrapper_class + kwargs_wrapper (dict, optional): dictionary with the constructor + arguments for wrapper_class + """ + + def __init__( + self, + old_module_instance: Module, + wrapper_class: Type[Module], + module_attribute: str, + kwargs_wrapper: Dict[str, Any]): + self.old_module_instance = old_module_instance + self.wrapper_class = wrapper_class + self.module_attribute = module_attribute + self.kwargs_wrapper = kwargs_wrapper + + def apply(self, model: GraphModule) -> GraphModule: + for old_module in model.modules(): + if old_module is self.old_module_instance: + kwargs = {self.module_attribute: self.old_module_instance} + kwargs.update(self.kwargs_wrapper) + new_module_instance = self.wrapper_class(**kwargs) + # init the new module based on the old one + replace_module(model, old_module, new_module_instance) + break + return model + + class ModuleInstanceToModuleInstance(Transform): def __init__(self, old_module_instance, new_module_instance): diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 7cbe38c6a..744130664 100644 --- 
a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -15,19 +15,23 @@ import torch from torch.fx import GraphModule as TorchGraphModule import torch.nn as nn +import torch.nn.utils.parametrize as parametrize from brevitas import torch_version from brevitas.fx import GraphModule from brevitas.fx import Node -from brevitas.graph import ModuleToModuleByClass from brevitas.graph import ModuleToModuleByInstance from brevitas.graph.base import GraphTransform from brevitas.graph.base import InsertModuleCallAfter +from brevitas.graph.base import ModuleInstanceRegisterParametrization from brevitas.graph.base import ModuleInstanceToModuleInstance +from brevitas.graph.base import ModuleInstanceTransformTensor +from brevitas.graph.base import ModuleInstanceWrapModule from brevitas.graph.base import Transform from brevitas.graph.hadamard import get_hadK from brevitas.graph.hadamard import matmul_hadU from brevitas.graph.hadamard import matmul_hadU_cuda +from brevitas.graph.hadamard import random_hadamard_matrix from brevitas.graph.utils import get_module from brevitas.graph.utils import get_node from brevitas.nn.equalized_layer import EqualizedModule @@ -35,6 +39,8 @@ from brevitas.nn.equalized_layer import INPUT_NAMES from brevitas.nn.equalized_layer import RotatedModule from brevitas.nn.quant_scale_bias import ScaleBias +from brevitas.utils.python_utils import recurse_getattr +from brevitas.utils.rotation_utils import RotationWeightParametrization from brevitas.utils.torch_utils import KwargsForwardHook # External optional dependency @@ -1299,8 +1305,19 @@ def random_orthogonal_matrix(size): return q -def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method='had'): +def _apply_rotate( + model: nn.Module, + regions: List[Region], + full_rotation_method='had', + fuse_rotations: bool = True, + apply_inplace_rotations: bool = True): rewriters = [] + # First, rotations on orphan sinks are applied so the order in which rotations are + # 
applied is consistent, irrespective of the value of fuse_rotations. This is due to + # the fact that parametrizations need to be registered, once all the in-place + # operations have taken place + regions = [region for region in regions if len(region.srcs) == 0] + [ + region for region in regions if len(region.srcs) > 0] for region in regions: insert_rotation_module = len(region.srcs) == 0 @@ -1311,6 +1328,14 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= rot_mat = random_orthogonal_matrix(hidden_dim) K = None rot_func = _apply_ort_device + elif not insert_rotation_module and not fuse_rotations: + # If the model is distributed across GPUs, the device will be + # not be the same for all of the parameters, so explicit moves + # to the same device as the weights need to be added + device = next(model.parameters()).device + rot_mat = random_hadamard_matrix(hidden_dim, device) + K = None + rot_func = _apply_ort_device else: try: # Build hadamard rotation matrix @@ -1326,57 +1351,112 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= print("Skipping layers") continue + # Cast rotation matrix to the weight dtype + if rot_mat is not None: + dtype = next(model.parameters()).dtype + rot_mat = rot_mat.to(dtype=dtype) + # If the rotation is not fused, redefine as a Parameter, to enable its optimization + if not insert_rotation_module and not fuse_rotations: + rot_mat = torch.nn.Parameter(rot_mat) + for name, indexes in region.srcs.items(): module = region.get_module_from_name(name) - if hasattr(module, 'allocate_params'): - module.allocate_params(module) - axis = _get_output_axis(module) - weight = module.weight.data - - if axis == 0: - rotated_weight = rot_func(weight.t(), rot_mat, K).t() - _update_weights(module, rotated_weight, 'weight') - elif axis == 1: - rotated_weight = rot_func(weight, rot_mat, K) - _update_weights(module, rotated_weight, 'weight') - else: - raise RuntimeError("Not supported yet") - 
- if getattr(module, 'bias', None) is not None: - bias = module.bias.data - bias = rot_func(bias, rot_mat, K) - module.bias.data = bias - if hasattr(module, 'offload_params'): - module.offload_params(module) + # Rotate "bias" if present + tensor_names_axis = [("weight", _get_output_axis(module))] + ([ + ("bias", 1)] if getattr(module, 'bias', None) is not None else []) + # If rotations are fused, transform is applied directly onto the tensor + rewriter_class = ModuleInstanceTransformTensor if fuse_rotations else ModuleInstanceRegisterParametrization + # Obtain rewriters for applying the rotations + for tensor_name, axis in tensor_names_axis: + rewriter = rewriter_class( + module=module, + tensor_name=tensor_name, + transform_module=RotationWeightParametrization( + rot_mat=rot_mat, + rot_func=rot_func, + axis=axis, + K=K, + )) + rewriters.append(rewriter) for name, indexes in region.sinks.items(): module = region.get_module_from_name(name) - if hasattr(module, 'allocate_params'): - module.allocate_params(module) - axis = _get_input_axis(module) - weight = module.weight.data - - if axis == 1: - rotated_weight = rot_func(weight, rot_mat, K) - _update_weights(module, rotated_weight, 'weight') - elif axis == 0: - rotated_weight = rot_func(weight.t(), rot_mat, K).t() - _update_weights(module, rotated_weight, 'weight') - else: - raise RuntimeError("Not supported yet") - - if hasattr(module, 'offload_params'): - module.offload_params(module) - + # Only "weight" is rotated + tensor_names_axis = [("weight", _get_input_axis(module))] + # If rotations are fused or if the module is an orphan sink, transform is applied directly onto the tensor + rewriter_class = ModuleInstanceTransformTensor if insert_rotation_module or fuse_rotations else ModuleInstanceRegisterParametrization + # Obtain rewriters for applying the rotations + for tensor_name, axis in tensor_names_axis: + rewriter = rewriter_class( + module=module, + tensor_name=tensor_name, + 
transform_module=RotationWeightParametrization( + rot_mat=rot_mat, + rot_func=rot_func, + axis=axis, + K=K, + )) + rewriters.append(rewriter) + # Replace by RotatedModule in orphan sink if insert_rotation_module and len(region.srcs) == 0: - rewriter = ModuleInstanceToModuleInstance( - module, RotatedModule(had_mat=rot_mat, k=K, layer=module)) + rewriter = ModuleInstanceWrapModule( + module, RotatedModule, "layer", { + "had_mat": rot_mat, "k": K}) rewriters.append(rewriter) - for r in rewriters: - model = r.apply(model) + if apply_inplace_rotations: + for r in rewriters: + # The parametrizations need to be registered after the potential HF hooks have been + # removed, as otherwise the device maps will not match the structure of the + # model's state_dict after the registration of the parametrizations. + if not isinstance(r, ModuleInstanceRegisterParametrization): + model = r.apply(model) return rewriters +# This function is adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/modeling.py +def _untie_parameters_with_parametrizations(model: torch.nn.Module): + # get ALL model parameters and their names + all_named_parameters = { + name: param for name, param in model.named_parameters(remove_duplicate=False)} + + # get ONLY unique named parameters, + # if parameter is tied and have multiple names, it will be included only once + no_duplicate_named_parameters = { + name: param for name, param in model.named_parameters(remove_duplicate=True)} + + # the difference of the two sets will give us the tied parameters + tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys()) + + for tied_param_name in tied_param_names: + tied_param_name_split = tied_param_name.split(".") + # The names of the original parameters after registering the parametrization + # have the format "prefix.parametrizations.tensor_name.original", e.g. + # "model.layer.parametrizations.weight.original". 
This allows to identify + # which subset of tied parameters are original tied parameters of the module + if len(tied_param_name_split) >= 3 and tied_param_name_split[ + -3] == "parametrizations" and tied_param_name_split[-1] == "original": + # If that is the case, retrieve the parent module + parent_module = recurse_getattr(model, ".".join(tied_param_name_split[:-1])) + # And set to a new parameter, thus breaking the tie + setattr(parent_module, "original", nn.Parameter(all_named_parameters[tied_param_name])) + + return model + + +def _fuse_rotations(model: nn.Module) -> nn.Module: + # First of all, parameters that have parametrizations need to be untied + model = _untie_parameters_with_parametrizations(model) + # Then, parametrizations can be safely removed + for module in model.modules(): + # Names of the tensors that can potentially be parametrized + tensor_names = ["weight", "bias"] + # Remove parametrizations from each tensor + for tensor_name in tensor_names: + if parametrize.is_parametrized(module) and tensor_name in module.parametrizations: + parametrize.remove_parametrizations(module, tensor_name, leave_parametrized=True) + return model + + def _replace_bias(next_module, new_bias): new_bias = new_bias.view(-1) if next_module.bias is not None: diff --git a/src/brevitas/graph/hadamard.py b/src/brevitas/graph/hadamard.py index 27bf1b4ae..59430057c 100644 --- a/src/brevitas/graph/hadamard.py +++ b/src/brevitas/graph/hadamard.py @@ -106,7 +106,8 @@ def random_hadamard_matrix(size, device): Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64) Q = Q * 2 - 1 Q = torch.diag(Q) - return matmul_hadU(Q).to(device) + # Set to float32 for consistency with random_orthogonal_matrix and get_hadK + return matmul_hadU(Q).to(device).float() def matmul_hadU_cuda(X, hadK, K): diff --git a/src/brevitas/utils/rotation_utils.py b/src/brevitas/utils/rotation_utils.py new file mode 100644 index 000000000..6a79d1cc3 --- /dev/null +++ b/src/brevitas/utils/rotation_utils.py 
@@ -0,0 +1,49 @@
+# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+from typing import Callable, Optional
+
+import torch
+from torch import Tensor
+
+
+class RotationWeightParametrization(torch.nn.Module):
+    r"""Rotates a tensor along a specified axis
+
+    Args:
+        rot_mat (Tensor, optional): orthogonal matrix by which to rotate the tensor
+        rot_func (Callable): function to apply the rotation. The first
+            argument corresponds to the tensor to be rotated, while the
+            second specifies the rotation matrix. The third argument (K) is
+            useful when rotating by a Hadamard matrix and it corresponds
+            to the dimensionality of the matrix up to a power of two,
+            i.e. dim=(2**p)*K. See brevitas.graph.hadamard.get_hadK for details
+        axis (int): axis along which to rotate the tensor (0 or 1)
+        K (int, optional): if rot_mat is a Hadamard matrix, K is the highest
+            divisor of the dimensionality of the matrix, such that K, itself,
+            is not divisible by 2
+    """
+
+    def __init__(
+        self,
+        rot_mat: Optional[Tensor],
+        rot_func: Callable[[Tensor, Optional[Tensor], Optional[int]], Tensor],
+        axis: int,
+        K: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+        self.rot_mat = rot_mat
+        self.rot_func = rot_func
+        self.axis = axis
+        self.K = K
+
+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        """Rotate ``tensor`` along ``self.axis`` (0 or 1); other axes raise RuntimeError."""
+        if self.axis == 0:
+            tensor = self.rot_func(tensor.t(), self.rot_mat, self.K).t()
+        elif self.axis == 1:
+            tensor = self.rot_func(tensor, self.rot_mat, self.K)
+        else:
+            raise RuntimeError("Not supported yet")
+
+        return tensor
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 4f03ba087..21641a819 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -22,6 +22,7 @@
 from brevitas.export.inference.manager import quant_inference_mode
 from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
 from brevitas.graph import load_quant_model_mode
+from brevitas.graph.base import
ModuleInstanceWrapModule from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import LayerwiseActivationRotation from brevitas.graph.quantize import functional_quantization_mode @@ -97,9 +98,11 @@ def fused_rotation_no_fx(model, calibration_loader, args): sdpa_regions=args.rotation_sdpa_regions) new_model, rewriters = eq.apply(new_model) rewriters = fix_rewriter(rewriters, model, 'weight') - for r in rewriters: - r.apply(model) + # The weights between model and new_model are tied, so this check prevents + # rotating the weights twice + if isinstance(r, ModuleInstanceWrapModule): + r.apply(model) remove_hooks(new_model) diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index 035cdaadd..014637414 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -1,6 +1,8 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +import functools + from packaging import version import pytest import pytest_cases @@ -528,3 +530,61 @@ def forward(self, x): rotation_fixtures = fixture_union( 'rotation_fixtures', list_of_rotation_mixtures, ids=list_of_rotation_mixtures) + +IN_FEATURES = 12 +RESIDUAL_MODEL_REGION_DICTS = [ + { + "srcs": ["embedding", "block1_linear2", "block2_linear2"], + "sinks": ["block1_linear1", "block2_linear1", "head"],}, + { + "srcs": ["block1_linear1"], "sinks": ["block1_linear2"]}, + { + "srcs": [], "sinks": ["block2_linear2"]},] + + +class BlockResidualModel(nn.Module): + + def __init__(self, is_tied: bool = False) -> None: + super().__init__() + self.embedding = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False) + + self.block1_linear1 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=True) + self.block1_linear2 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False) + + self.block2_linear1 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False) + self.act = nn.SiLU() + 
self.block2_linear2 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=True) + + self.head = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False) + if is_tied: + self.head.weight = self.embedding.weight + + def forward(self, x): + x = self.embedding(x) + r = x + x = self.block1_linear1(x) + x = self.block1_linear2(x) + r + r = x + x = self.block2_linear1(x) + x = self.act(x) + x = self.block2_linear2(x) + r + x = self.head(x) + return x + + +@pytest_cases.fixture +def block_residual_model(): + return functools.partial(BlockResidualModel, is_tied=False) + + +@pytest_cases.fixture +def block_residual_model_tied(): + return functools.partial(BlockResidualModel, is_tied=True) + + +list_of_rotation_fixtures = [ + "block_residual_model", + "block_residual_model_tied",] + +rotation_model = fixture_union( + 'rotation_model', list_of_rotation_fixtures, ids=list_of_rotation_fixtures) diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index afb8636e4..41edf0752 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -2,21 +2,41 @@ # SPDX-License-Identifier: BSD-3-Clause import copy +from functools import partial +from functools import reduce +import itertools +from unittest.mock import patch +import pytest import torch +import torch.nn.utils.parametrize as parametrize from torchvision import models +from brevitas import torch_version from brevitas.fx import symbolic_trace +from brevitas.graph.base import ModuleInstanceRegisterParametrization +from brevitas.graph.equalize import _apply_had_device +from brevitas.graph.equalize import _apply_ort_device +from brevitas.graph.equalize import _apply_rotate from brevitas.graph.equalize import _batch_norm from brevitas.graph.equalize import _extract_regions +from brevitas.graph.equalize import _fuse_rotations +from brevitas.graph.equalize import _get_input_axis +from brevitas.graph.equalize import _get_output_axis from brevitas.graph.equalize import 
_is_supported_module from brevitas.graph.equalize import _supported_layers from brevitas.graph.equalize import activation_equalization_mode +from brevitas.graph.equalize import EqualizationIndexes from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import MergeLnAffine +from brevitas.graph.equalize import random_orthogonal_matrix +from brevitas.graph.equalize import Region +from brevitas.graph.hadamard import get_hadK from brevitas.graph.standardize import DuplicateSharedStatelessModule from brevitas.graph.standardize import TorchFunctionalToModule from brevitas.graph.utils import get_module +from brevitas.nn.equalized_layer import RotatedModule +from brevitas.utils.rotation_utils import RotationWeightParametrization from tests.marker import requires_pt_ge from .equalization_fixtures import * @@ -276,3 +296,235 @@ def test_models(rotation_fixtures, partial_had): if partial_had: last_weight_new = model.linear_2.layer.weight.data assert not torch.allclose(last_weight, last_weight_new) + + +@pytest_cases.parametrize('N', [1, 2, 3], ids=lambda x: f"N={x}") +def test_composition_unfused_rotations(N): + torch.manual_seed(SEED) + + for rotation_flags in itertools.product([False, True], repeat=N): + + in_features = 5 + module = nn.Linear(in_features=in_features, out_features=in_features) + rot_module = copy.deepcopy(module) + + # Sample input to pass through the block + sample_input = torch.rand((1, in_features),) + # Composite rotation matrices + rot_mat_input = torch.eye(in_features) + rot_mat_output = torch.eye(in_features) + + for is_source in rotation_flags: + # Generate a random matrix + rot_mat = random_orthogonal_matrix(in_features).to(dtype=torch.float32) + + # Aggregate rotation matrices + if is_source: + rot_mat_output = rot_mat_output @ rot_mat + else: + rot_mat_input = rot_mat_input @ rot_mat + + # Compose rotation modules + parametrize.register_parametrization( + rot_module, + "weight", + RotationWeightParametrization( + 
+ rot_mat=rot_mat, + rot_func=_apply_ort_device, + axis=_get_output_axis(rot_module) if is_source else _get_input_axis(rot_module), + )) + if is_source: + parametrize.register_parametrization( + rot_module, + "bias", + RotationWeightParametrization( + rot_mat=rot_mat, + rot_func=_apply_ort_device, + axis=1, + )) + + # If the node is a sink, the input is multiplied by the inverse of the rotation matrix x <- xQ^{-1} + # If the node is a source, the output is multiplied by the rotation matrix o <- oQ + gt_output = module(sample_input @ rot_mat_input.t()) @ rot_mat_output + rot_output = rot_module(sample_input) + + # Verify that the rotation operations were computed correctly + assert torch.allclose(gt_output, rot_output, atol=ATOL) + + +# This method is almost the same as brevitas.graph.equalize.random_orthogonal_matrix, except for the +# possibility of passing a generator, which enables controlling the random matrices that are generated. +# Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/eval_utils/rotation_utils.py#L26 +# This function needs to be patched in to enable passing the generator, ensuring that the orthogonal +# matrices generated are the same. +def _random_orthogonal_matrix(size, generator): + """ + Generate a random orthogonal matrix of the specified size. + First, we generate a random matrix with entries from a standard distribution. + Then, we use QR decomposition to obtain an orthogonal matrix. + Finally, we multiply by a diagonal matrix with diag r to adjust the signs. + Args: + size (int): The size of the matrix (size x size). + Returns: + torch.Tensor: An orthogonal matrix of the specified size. 
+ """ + torch.cuda.empty_cache() + random_matrix = torch.randn(size, size, dtype=torch.float64, generator=generator) + q, r = torch.linalg.qr(random_matrix) + q *= torch.sign(torch.diag(r)).unsqueeze(0).float() + return q + + +# Auxiliar method to convert a dictionary of sources/sinks into a valid region +def _instantiate_region(region_dict, model) -> Region: + if len(region_dict["srcs"]) > 0: + sorted_srcs = dict( + sorted({src: EqualizationIndexes(0, IN_FEATURES, 0) for src in region_dict["srcs"] + }.items())) + sorted_sinks = dict( + sorted({sink: EqualizationIndexes(0, IN_FEATURES, 0) for sink in region_dict["sinks"] + }.items())) + else: + sorted_srcs = dict() + sorted_sinks = dict( + sorted({sink: EqualizationIndexes(0, IN_FEATURES, 0) for sink in region_dict["sinks"] + }.items())) + sorted_acts = tuple() + return Region( + srcs=sorted_srcs, sinks=sorted_sinks, acts=sorted_acts, name_to_module=model._modules) + + +# Auxiliar function to compare the weights of module instances belonging to classes_to_compare +def compare_model_weights(model_fused, model_unfused, classes_to_compare=(nn.Linear,)): + tensor_names = ["weight", "bias"] + for name_module_fused, module_fused in model_fused.named_modules(): + if isinstance(module_fused, classes_to_compare): + module_unfused = reduce(getattr, [model_unfused] + name_module_fused.split(".")) + for tensor_name in tensor_names: + if hasattr(module_fused, tensor_name) and getattr(module_fused, + tensor_name) is not None: + assert torch.allclose(getattr(module_fused, tensor_name), getattr(module_unfused, tensor_name), atol=0.0, rtol=0.0), f"Tensor {tensor_name} does not match for module {name_module_fused}" + + +@requires_pt_ge('2.3.1') +@pytest_cases.parametrize( + 'mask', + itertools.product([False, True], repeat=3), + ids=lambda mask: "-".join([rot for mask_el, rot in zip(mask, ["R1", "R2", "R3"]) if mask_el])) +@pytest_cases.parametrize('full_rotation_method', ['ort', 'had']) +@pytest_cases.parametrize('device', ['cpu', 
'cuda'] if torch.cuda.is_available() else ['cpu']) +@pytest_cases.parametrize('fuse_rotations', [False, True], ids=["unfused", "fused"]) +@pytest_cases.parametrize('use_fx', [True, False], ids=["fx", "no-fx"]) +def test_apply_rotate(rotation_model, mask, full_rotation_method, device, fuse_rotations, use_fx): + # Instantiate a residual model for which a collection of regions is available + model = rotation_model() + device = torch.device("cuda") if device == 'cuda' else torch.device("cpu") + model.to(device) + # Sample input to pass through the models + sample_inputs = torch.rand(size=(5, IN_FEATURES)).to(device) + # Collect only a subset of regions to be applied + regions_dicts = [ + region_dict for mask_element, + region_dict in zip(mask, RESIDUAL_MODEL_REGION_DICTS) if mask_element] + # Use FX model if requested + if use_fx: + graph_model, _ = torch._dynamo.export(model)(sample_inputs) + # The module names in the original model need to be mapped to the ones + # in graph_model + map_model_graph = {} + assigned_graph_modules = set() + for graph_module_name, graph_module in graph_model.named_modules(): + if hasattr(graph_module, "weight"): + for name, module in model.named_modules(): + # The check name not in map_model_graph prevents the assignment to the same module + # when tied parameters are present + if name not in map_model_graph and graph_module_name not in assigned_graph_modules and hasattr( + module, "weight") and graph_module.weight is module.weight: + map_model_graph[name] = graph_module_name + assigned_graph_modules.add(graph_module_name) + # Replace the names of the modules in sources/sinks by the names of the modules in the FX model + regions_dicts = [{ + k: list(map(lambda x: map_model_graph[x], v)) + for k, v in region_dict.items()} + for region_dict in regions_dicts] + # Rotation will be applied on the FX model + model = graph_model + + # Deepcopy the models as parameters are going to be modified in-place + rotated_model_unfused = 
copy.deepcopy(model) + rotated_model_fused = copy.deepcopy(model) + + # Generator to control the random orthogonal matrices generated + generator = torch.Generator() + generator.manual_seed(SEED) + # Clone generator to make sure we can use the same rotation matrices + generator_clone = torch.Generator() + generator_clone.set_state(generator.get_state()) + + # Apply rotations on the model with unfused rotations + regions_unfused = list( + map(lambda x: _instantiate_region(x, rotated_model_unfused), regions_dicts)) + if full_rotation_method == 'had': + # _apply_ort_device is patched to ensure that the hadamard matrices in hadamard.pt are used, instead of + # the random ones generated by random_hadamard_matrices + with patch('brevitas.graph.equalize._apply_ort_device', + lambda tensor, + had_K, + K: _apply_had_device( + tensor, get_hadK(had_K.shape[0])[0], get_hadK(had_K.shape[0])[1])): + rewriters = _apply_rotate( + rotated_model_unfused, + regions_unfused, + full_rotation_method=full_rotation_method, + fuse_rotations=False) + elif full_rotation_method == 'ort': + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator)): + rewriters = _apply_rotate( + rotated_model_unfused, + regions_unfused, + full_rotation_method=full_rotation_method, + fuse_rotations=False) + # Register parametrizations after calling _apply_rotate, as these are not inmediately registered since they alter the structure of the + # model, thus potentially causing a crash if the model is offloaded + for r in rewriters: + if isinstance(r, ModuleInstanceRegisterParametrization): + rotated_model_unfused = r.apply(rotated_model_unfused) + # Apply rotations on the model with fused rotations + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator_clone)): + regions_fused = list( + map(lambda x: _instantiate_region(x, rotated_model_fused), regions_dicts)) + _apply_rotate( + 
rotated_model_fused, + regions_fused, + full_rotation_method=full_rotation_method, + fuse_rotations=True) + + # Compute outputs for each model + model_output = model(sample_inputs) + rotated_model_unfused_output = rotated_model_unfused(sample_inputs) + rotated_model_fused_output = rotated_model_fused(sample_inputs) + + # Verify that the correct number of unique rotation matrices were included. Orphan sinks (len(region_dict["srcs"]) == 0) do not + # an attached parametrization + assert sum([len(region_dict["srcs"]) > 0 for region_dict in regions_dicts]) == sum([ + "rot_mat" in name for name, + _ in rotated_model_unfused.named_parameters(remove_duplicate=True)]) + # Verify that RotatedModules were added appropiately + for rotated_model in [rotated_model_fused, rotated_model_unfused]: + assert sum([len(region_dict["srcs"]) == 0 for region_dict in regions_dicts]) == sum([ + isinstance(module, RotatedModule) for module in rotated_model.modules()]) + # Optionally fuse the rotations + if fuse_rotations: + rotated_model_unfused = _fuse_rotations(rotated_model_unfused) + # Verify that no parametrizations remain after fusing + for module in rotated_model_unfused.modules(): + assert not parametrize.is_parametrized(module) + # Outputs should match for rotated and unrotated models + assert torch.allclose(model_output, rotated_model_fused_output, atol=ATOL) + assert torch.allclose( + rotated_model_unfused_output, rotated_model_fused_output, atol=0.0, rtol=0.0) + # Verify that the weights have changed with respect to the unrotated module for the modules that have received parametrizations + # Verify that weights match between the fused and unfused model + compare_model_weights(rotated_model_fused, rotated_model_unfused) diff --git a/tests/brevitas/graph/test_transforms.py b/tests/brevitas/graph/test_transforms.py index 875d5a52c..2d5c7a78f 100644 --- a/tests/brevitas/graph/test_transforms.py +++ b/tests/brevitas/graph/test_transforms.py @@ -5,6 +5,7 @@ from packaging import 
version import pytest +import pytest_cases import torch from torch import nn from torchvision import models @@ -16,10 +17,15 @@ from brevitas.graph import MeanMethodToAdaptiveAvgPool2d from brevitas.graph import MergeBatchNorm from brevitas.graph import MethodToModule +from brevitas.graph.base import ModuleInstanceRegisterParametrization +from brevitas.graph.base import ModuleInstanceTransformTensor +from brevitas.graph.base import ModuleInstanceWrapModule from brevitas.graph.base import ModuleToModuleByInstance from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d from brevitas.nn import QuantConv3d +from brevitas.nn.equalized_layer import RotatedModule +from brevitas.utils.rotation_utils import RotationWeightParametrization SEED = 123456 INPUT_SIZE = (1, 3, 224, 224) @@ -290,3 +296,73 @@ def forward(self, x): kwargs = {'stride': lambda module, name: 2 if module.in_channels == 3 else 1} model = ModuleToModuleByInstance(model.conv, nn.Conv2d, **kwargs).apply(model) assert model.conv.stride == (2, 2) + + +def test_module_instance_register_parametrization(): + + class TestModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(2, 2, bias=False) + + def forward(self, x): + return self.linear(x) + + class ZeroParametrization(nn.Module): + + def forward(self, x): + return torch.zeros_like(x) + + model = TestModel() + model = ModuleInstanceRegisterParametrization(model.linear, "weight", + ZeroParametrization()).apply(model) + assert torch.all(model.linear.weight == 0.) 
+ + +def test_module_instance_wrap_module(): + + class TestModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(2, 2, bias=False) + + def forward(self, x): + return self.linear(x) + + model = TestModel() + model = ModuleInstanceWrapModule( + model.linear, RotatedModule, "layer", { + "had_mat": None, "k": None}).apply(model) + assert isinstance(model.linear, RotatedModule) + + +@pytest_cases.parametrize("axis", [0, 1], ids=lambda axis: f"axis={axis}") +def test_fuse_rotation_weights(axis): + + def rot_func(weight, ort, K): + return torch.matmul(weight, ort) + + class TestModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(2, 2, bias=False) + + def forward(self, x): + return self.linear(x) + + rot_mat = torch.rand(2, 2) + model_fused = TestModel() + model_unfused = TestModel() + model_unfused.linear.weight.data = model_fused.linear.weight.data + + model_fused = ModuleInstanceTransformTensor( + model_fused.linear, "weight", RotationWeightParametrization(rot_mat, rot_func, axis, + None)).apply(model_fused) + model_unfused = ModuleInstanceRegisterParametrization( + model_unfused.linear, + "weight", + RotationWeightParametrization(rot_mat, rot_func, axis, None)).apply(model_unfused) + assert torch.all(model_fused.linear.weight == model_unfused.linear.weight) diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index ca1c7cda7..f6db73924 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -740,3 +740,88 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl): quant_ppl = quant_ppl.detach().cpu().numpy() assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" + + +@pytest_cases.fixture( + ids=[ + 
"llama_fused_rotation_ort", + "llama_fused_rotation_ort_no_orphan", + "llama_fused_rotation_had", + "llama_fused_rotation_had_no_orphan", + "llama_layerwise",], + params=[ + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "act_calibration": False, + "weight_bit_width": 4, + "input_bit_width": None, + "replace_rmsnorm": True, + "rotation": "fused_no_fx", + "rotation_orphan_sink": True, + "rotation_mode": "ort", + "float_ppl": 33238.8984375, + "quant_ppl": 33232.65234375}, + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "act_calibration": False, + "weight_bit_width": 4, + "input_bit_width": None, + "replace_rmsnorm": True, + "rotation": "fused_no_fx", + "rotation_orphan_sink": False, + "rotation_mode": "ort", + "float_ppl": 33238.8984375, + "quant_ppl": 33420.65234375}, + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "act_calibration": False, + "weight_bit_width": 4, + "input_bit_width": None, + "replace_rmsnorm": True, + "rotation": "fused_no_fx", + "rotation_orphan_sink": True, + "rotation_mode": "had", + "float_ppl": 33238.8984375, + "quant_ppl": 33290.48046875}, + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "act_calibration": False, + "weight_bit_width": 4, + "input_bit_width": None, + "replace_rmsnorm": True, + "rotation": "fused_no_fx", + "rotation_orphan_sink": False, + "rotation_mode": "had", + "float_ppl": 33238.8984375, + "quant_ppl": 33204.80859375}, + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "act_calibration": False, + "weight_bit_width": 4, + "input_bit_width": None, + "replace_rmsnorm": True, + "rotation": "layerwise", + "float_ppl": 33238.8984375, + "quant_ppl": 33446.734375},]) +def rotation_ppl_args_and_ppl(default_run_args, request): + args = default_run_args + run_dict = request.param + float_ppl = run_dict["float_ppl"] + quant_ppl = run_dict["quant_ppl"] + del run_dict["float_ppl"] + del run_dict["quant_ppl"] + args.update(**run_dict) + 
yield args, float_ppl, quant_ppl + + +@requires_pt_ge('2.4') +def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl): + if platform.system() == "Windows": + pytest.skip("Skipping dynamo + windows") + caplog.set_level(logging.INFO) + args, exp_float_ppl, exp_quant_ppl = rotation_ppl_args_and_ppl + float_ppl, quant_ppl, model = validate_args_and_run_main(args) + float_ppl = float_ppl.detach().cpu().numpy() + quant_ppl = quant_ppl.detach().cpu().numpy() + assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" + assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"