Commit

update docstring
qiauil committed Aug 15, 2024
1 parent f63d074 commit fa5adb8
Showing 21 changed files with 3,294 additions and 147 deletions.
2 changes: 1 addition & 1 deletion ConFIG/__init__.py
@@ -1,4 +1,4 @@
#!/usr/bin/python3
# -*- coding: UTF-8 -*-
import torch
from typing import Optional,Sequence,Union
from typing import Optional,Sequence,Union,Tuple
159 changes: 118 additions & 41 deletions ConFIG/grad_operator.py
@@ -10,19 +10,40 @@
def ConFIG_update_double(grad_1:torch.Tensor,grad_2:torch.Tensor,
weight_model:WeightModel=EqualWeight(),
length_model:LengthModel=ProjectionLength(),
losses:Optional[Sequence]=None):
losses:Optional[Sequence]=None)-> torch.Tensor:
"""
ConFIG update for two gradients where no inverse calculation is needed.
Args:
grad_1 (torch.Tensor): The first gradient.
grad_2 (torch.Tensor): The second gradient.
weight_model (WeightModel, optional): The weight model to determine the coefficients. Defaults to EqualWeight().
length_model (LengthModel, optional): The length model to rescale the target vector. Defaults to ProjectionLength().
losses (Optional[Sequence], optional): The losses associated with the gradients. Defaults to None.
weight_model (WeightModel, optional): The weight model for calculating the direction weights.
Defaults to EqualWeight(), which keeps the final update gradient unbiased towards any individual gradient.
length_model (LengthModel, optional): The length model for rescaling the length of the final gradient.
Defaults to ProjectionLength(), which projects each gradient vector onto the final gradient direction and sums the projections to obtain the final length.
losses (Optional[Sequence], optional): The losses associated with the gradients.
The losses are passed to the weight and length models. If your weight/length model does not require loss information,
you can leave this value as None. Defaults to None.
Returns:
torch.Tensor: The rescaled length of the best direction.
torch.Tensor: The final update gradient.
Examples:
```python
import torch
from ConFIG.grad_operator import ConFIG_update_double
from ConFIG.utils import get_gradient_vector,apply_gradient_vector

optimizer=torch.optim.Adam(network.parameters(),lr=1e-3)
for input_i in dataset:
    grads=[] # we record gradients rather than losses
    for loss_fn in [loss_fn1, loss_fn2]:
        optimizer.zero_grad()
        loss_i=loss_fn(input_i)
        loss_i.backward()
        grads.append(get_gradient_vector(network)) # get the loss-specific gradient
    g_config=ConFIG_update_double(grads[0],grads[1]) # calculate the conflict-free direction
    apply_gradient_vector(network,g_config) # apply the conflict-free direction to the network
    optimizer.step()
```
"""
with torch.no_grad():
@@ -49,20 +70,42 @@ def ConFIG_update(
weight_model:WeightModel=EqualWeight(),
length_model:LengthModel=ProjectionLength(),
use_latest_square:bool=True,
losses:Optional[Sequence]=None
):
losses:Optional[Sequence]=None)-> torch.Tensor:
"""
Performs the standard ConFIG update step.
Args:
grads (Sequence[torch.Tensor]): The gradients to update.
weight_model (WeightModel, optional): The weight model to use for calculating weights. Defaults to EqualWeight().
length_model (LengthModel, optional): The length model to use for rescaling the length of the target vector. Defaults to ProjectionLength().
use_latest_square (bool, optional): Whether to use the latest square method for calculating the best direction. Defaults to True.
losses (Optional[Sequence], optional): The losses associated with the gradients. Defaults to None.
grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
weight_model (WeightModel, optional): The weight model for calculating the direction weights.
Defaults to EqualWeight(), which keeps the final update gradient unbiased towards any individual gradient.
length_model (LengthModel, optional): The length model for rescaling the length of the final gradient.
Defaults to ProjectionLength(), which projects each gradient vector onto the final gradient direction and sums the projections to obtain the final length.
use_latest_square (bool, optional): Whether to use the least-squares method for calculating the best direction.
If set to False, the pseudo-inverse of the gradient matrix is calculated directly. Setting this to True is recommended. Defaults to True.
losses (Optional[Sequence], optional): The losses associated with the gradients.
The losses are passed to the weight and length models. If your weight/length model does not require loss information,
you can leave this value as None. Defaults to None.
Returns:
torch.Tensor: The rescaled length of the target vector.
torch.Tensor: The final update gradient.
Examples:
```python
import torch
from ConFIG.grad_operator import ConFIG_update
from ConFIG.utils import get_gradient_vector,apply_gradient_vector

optimizer=torch.optim.Adam(network.parameters(),lr=1e-3)
for input_i in dataset:
    grads=[] # we record gradients rather than losses
    for loss_fn in loss_fns:
        optimizer.zero_grad()
        loss_i=loss_fn(input_i)
        loss_i.backward()
        grads.append(get_gradient_vector(network)) # get the loss-specific gradient
    g_config=ConFIG_update(grads) # calculate the conflict-free direction
    apply_gradient_vector(network,g_config) # apply the conflict-free direction to the network
    optimizer.step()
```
"""
if not isinstance(grads,torch.Tensor):
grads=torch.stack(grads)
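The remainder of `ConFIG_update` is truncated in this hunk. As a rough sketch of the idea the docstring describes, and not the commit's exact code, the conflict-free direction is the vector whose projections onto all unit gradients are equal (obtained here via a pseudo-inverse) and is then rescaled by the length model; the helper name `config_direction_sketch` is hypothetical:

```python
import torch

def config_direction_sketch(grads: torch.Tensor) -> torch.Tensor:
    # grads: (k, n) stack of k loss-specific gradient vectors
    units = grads / grads.norm(dim=1, keepdim=True)  # unit gradient matrix
    weights = torch.ones(grads.shape[0])             # EqualWeight: uniform direction weights
    # Direction with equal weighted projections onto every unit gradient.
    # (A least-squares solve can replace the explicit pseudo-inverse.)
    direction = torch.linalg.pinv(units) @ weights
    best_direction = direction / direction.norm()
    # ProjectionLength: sum of the projections of each gradient onto the direction.
    length = torch.sum(grads @ best_direction)
    return length * best_direction
```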
@@ -80,7 +123,7 @@ def ConFIG_update(

class GradientOperator:
"""
A class that represents a gradient operator.
A base class that represents a gradient operator.
Methods:
calculate_gradient: Calculates the gradient based on the given gradients and losses.
@@ -91,13 +134,16 @@ class GradientOperator:
def __init__(self):
pass
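The base-class interface is minimal; as a hedged illustration (not part of this commit), a hypothetical subclass that simply averages the loss-specific gradients could look like:

```python
import torch
from typing import Optional, Sequence, Union
from ConFIG.grad_operator import GradientOperator

class MeanGradientOperator(GradientOperator):
    """Hypothetical operator: average all loss-specific gradients."""

    def calculate_gradient(self,
                           grads: Union[torch.Tensor, Sequence[torch.Tensor]],
                           losses: Optional[Sequence] = None) -> torch.Tensor:
        if not isinstance(grads, torch.Tensor):
            grads = torch.stack(grads)  # stack gradient vectors at dim 0
        return grads.mean(dim=0)        # plain average; ignores losses
```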

def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None):
def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None)-> torch.Tensor:
"""
Calculates the gradient based on the given gradients and losses.
Args:
grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients.
losses (Optional[Sequence]): The losses (default: None).
grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
losses (Optional[Sequence], optional): The losses associated with the gradients.
The losses are passed to the weight and length models. If your weight/length model does not require loss information,
you can leave this value as None. Defaults to None.
Returns:
torch.Tensor: The calculated gradient.
@@ -108,14 +154,17 @@ def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]],
"""
raise NotImplementedError("calculate_gradient method must be implemented")

def update_gradient(self, network:torch.nn.Module, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None):
def update_gradient(self, network: torch.nn.Module, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None)-> None:
"""
Updates the gradient of the network based on the calculated gradient.
Calculates the gradient and applies it to the network.
Args:
network: The network.
grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients.
losses (Optional[Sequence]): The losses (default: None).
network (torch.nn.Module): The target network.
grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
losses (Optional[Sequence], optional): The losses associated with the gradients.
The losses are passed to the weight and length models. If your weight/length model does not require loss information,
you can leave this value as None. Defaults to None.
Returns:
None
@@ -126,13 +175,36 @@ def update_gradient(self, network:torch.nn.Module, grads: Union[torch.Tensor,Seq

class ConFIGOperator(GradientOperator):
"""
ConFIGOperator class represents a gradient operator for ConFIG algorithm.
Operator for the ConFIG algorithm.
Args:
weight_model (WeightModel, optional): The weight model to be used for calculating the gradient. Defaults to EqualWeight().
length_model (LengthModel, optional): The length model to be used for calculating the gradient. Defaults to ProjectionLength().
allow_simplified_model (bool, optional): Whether to allow simplified model for calculating the gradient. Defaults to True.
use_latest_square (bool, optional): Whether to use the latest square for calculating the gradient. Defaults to True.
weight_model (WeightModel, optional): The weight model for calculating the direction weights.
Defaults to EqualWeight(), which keeps the final update gradient unbiased towards any individual gradient.
length_model (LengthModel, optional): The length model for rescaling the length of the final gradient.
Defaults to ProjectionLength(), which projects each gradient vector onto the final gradient direction and sums the projections to obtain the final length.
allow_simplified_model (bool, optional): Whether to allow a simplified formulation for calculating the gradient.
If set to True, the simplified two-gradient form of the ConFIG method (ConFIG_update_double) is used when there are only two losses. Defaults to True.
use_latest_square (bool, optional): Whether to use the least-squares method for calculating the best direction.
If set to False, the pseudo-inverse of the gradient matrix is calculated directly. Setting this to True is recommended. Defaults to True.
Examples:
```python
import torch
from ConFIG.grad_operator import ConFIGOperator
from ConFIG.utils import get_gradient_vector,apply_gradient_vector

optimizer=torch.optim.Adam(network.parameters(),lr=1e-3)
operator=ConFIGOperator() # initialize operator
for input_i in dataset:
    grads=[]
    for loss_fn in loss_fns:
        optimizer.zero_grad()
        loss_i=loss_fn(input_i)
        loss_i.backward()
        grads.append(get_gradient_vector(network))
    g_config=operator.calculate_gradient(grads) # calculate the conflict-free direction
    apply_gradient_vector(network,g_config) # or simply use `operator.update_gradient(network,grads)` to calculate and apply the conflict-free direction in one call
    optimizer.step()
```
"""

def __init__(self,
@@ -146,13 +218,16 @@ def __init__(self,
self.allow_simplified_model = allow_simplified_model
self.use_latest_square = use_latest_square

def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None):
def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None)->torch.Tensor:
"""
Calculates the gradient using the ConFIG algorithm.
Args:
grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to be used for calculating the gradient.
losses (Optional[Sequence], optional): The losses to be used for calculating the gradient. Defaults to None.
grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
losses (Optional[Sequence], optional): The losses associated with the gradients.
The losses are passed to the weight and length models. If your weight/length model does not require loss information,
you can leave this value as None. Defaults to None.
Returns:
torch.Tensor: The calculated gradient.
@@ -186,16 +261,17 @@ class PCGradOperator(GradientOperator):
"""


def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None):
def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None)->torch.Tensor:
"""
Calculates the gradient using the ConFIG algorithm.
Calculates the gradient using the PCGrad algorithm.
Args:
grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to be used for calculating the gradient.
losses (Optional[Sequence], optional): The losses to be used for calculating the gradient. Defaults to None.
grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
losses (Optional[Sequence], optional): Not used by this operator; it should be left as None. Defaults to None.
Returns:
torch.Tensor: The calculated gradient.
torch.Tensor: The gradient calculated with the PCGrad method.
"""
if not isinstance(grads,torch.Tensor):
grads=torch.stack(grads)
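The rest of `calculate_gradient` is truncated here. For orientation, a minimal sketch of the standard PCGrad projection rule from the cited paper, not this commit's exact code (`pcgrad_sketch` is a hypothetical helper name):

```python
import torch

def pcgrad_sketch(grads: torch.Tensor) -> torch.Tensor:
    # grads: (k, n) stack of k loss-specific gradient vectors
    projected = grads.clone()
    num_grads = grads.shape[0]
    for i in range(num_grads):
        for j in range(num_grads):  # the original method visits j in random order
            if i == j:
                continue
            dot = torch.dot(projected[i], grads[j])
            if dot < 0:  # conflict: drop the component of gradient i along gradient j
                projected[i] = projected[i] - dot / grads[j].norm() ** 2 * grads[j]
    return projected.sum(dim=0)  # aggregate the de-conflicted gradients
```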
@@ -212,7 +288,7 @@ def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]],

class IMTLGOperator(GradientOperator):
"""
PCGradOperator class represents a gradient operator for IMTLG algorithm.
IMTLGOperator class represents a gradient operator for the IMTL-G algorithm.
@inproceedings{
liu2021towards,
Expand All @@ -226,16 +302,17 @@ class IMTLGOperator(GradientOperator):
"""


def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None):
def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None) ->torch.Tensor:
"""
Calculates the gradient using the ConFIG algorithm.
Calculates the gradient using the IMTL-G algorithm.
Args:
grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to be used for calculating the gradient.
losses (Optional[Sequence], optional): The losses to be used for calculating the gradient. Defaults to None.
grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
losses (Optional[Sequence], optional): Not used by this operator; it should be left as None. Defaults to None.
Returns:
torch.Tensor: The calculated gradient.
torch.Tensor: The gradient calculated with the IMTL-G method.
"""
if not isinstance(grads,torch.Tensor):
grads=torch.stack(grads)
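Both PCGradOperator and IMTLGOperator expose the same GradientOperator interface, so they can be used exactly like ConFIGOperator in the docstring example above; a hedged sketch, reusing the hypothetical `network`, `dataset`, and `loss_fns` placeholders from those examples:

```python
import torch
from ConFIG.grad_operator import PCGradOperator  # IMTLGOperator works the same way
from ConFIG.utils import get_gradient_vector

operator = PCGradOperator()
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
for input_i in dataset:
    grads = []
    for loss_fn in loss_fns:
        optimizer.zero_grad()
        loss_i = loss_fn(input_i)
        loss_i.backward()
        grads.append(get_gradient_vector(network))
    operator.update_gradient(network, grads)  # calculate and apply in one call
    optimizer.step()
```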
49 changes: 33 additions & 16 deletions ConFIG/length_model.py
@@ -4,7 +4,7 @@

class LengthModel:
"""
This class represents a length model.
The base class for length models.
Methods:
get_length: Calculates the length based on the given parameters.
@@ -17,35 +17,36 @@ def get_length(self,
target_vector:Optional[torch.Tensor]=None,
unit_target_vector:Optional[torch.Tensor]=None,
gradients:Optional[torch.Tensor]=None,
losses:Optional[Sequence]=None):
losses:Optional[Sequence]=None)-> Union[torch.Tensor, float]:
"""
Calculates the length based on the given parameters.
Calculates the length based on the given parameters. Not all parameters are required.
Args:
target_vector: The target vector.
unit_target_vector: The unit target vector.
gradients: The gradients.
losses: The losses.
target_vector (Optional[torch.Tensor]): The final update gradient vector.
unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
gradients (Optional[torch.Tensor]): The loss-specific gradients matrix.
losses (Optional[Sequence]): The losses.
Returns:
The calculated length.
Union[torch.Tensor, float]: The calculated length.
"""
raise NotImplementedError("This method must be implemented by the subclass.")

def rescale_length(self,
target_vector:torch.Tensor,
gradients:Optional[torch.Tensor]=None,
losses:Optional[Sequence]=None):
losses:Optional[Sequence]=None)->torch.Tensor:
"""
Rescales the length of the target vector based on the given parameters.
It calls the get_length method to calculate the length and then rescales the target vector.
Args:
target_vector: The target vector.
gradients: The gradients.
losses: The losses.
target_vector (torch.Tensor): The final update gradient vector.
gradients (Optional[torch.Tensor]): The loss-specific gradients matrix.
losses (Optional[Sequence]): The losses.
Returns:
The rescaled length.
torch.Tensor: The rescaled target vector.
"""
unit_target_vector = unit_vector(target_vector)
return self.get_length(target_vector=target_vector,
@@ -64,7 +65,23 @@ def __init__(self):
def get_length(self, target_vector:Optional[torch.Tensor]=None,
unit_target_vector:Optional[torch.Tensor]=None,
gradients:Optional[torch.Tensor]=None,
losses:Optional[Sequence]=None):
losses:Optional[Sequence]=None)->torch.Tensor:
"""
Calculates the length based on the given parameters. Not all parameters are required.
Args:
target_vector (Optional[torch.Tensor]): The final update gradient vector.
Either the `target_vector` or the `unit_target_vector` parameter needs to be provided.
unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
Either the `target_vector` or the `unit_target_vector` parameter needs to be provided.
gradients (Optional[torch.Tensor]): The loss-specific gradients matrix.
losses (Optional[Sequence]): The losses. Not used in this model.
Returns:
torch.Tensor: The calculated length.
"""
if gradients is None:
raise ValueError("The ProjectionLength model requires gradient information.")
if unit_target_vector is None:
unit_target_vector = unit_vector(target_vector)
return torch.sum(torch.stack([torch.dot(grad_i,unit_target_vector) for grad_i in gradients]))
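As a final illustration of the LengthModel interface (hypothetical, not part of this commit), a subclass that keeps the target vector's own norm, so that rescale_length returns the vector unchanged, could look like:

```python
import torch
from typing import Optional, Sequence
from ConFIG.length_model import LengthModel

class IdentityLength(LengthModel):
    """Hypothetical model: keep the final gradient's original norm."""

    def get_length(self,
                   target_vector: Optional[torch.Tensor] = None,
                   unit_target_vector: Optional[torch.Tensor] = None,
                   gradients: Optional[torch.Tensor] = None,
                   losses: Optional[Sequence] = None) -> torch.Tensor:
        if target_vector is None:
            raise ValueError("IdentityLength requires target_vector.")
        return target_vector.norm()  # rescale_length then reproduces target_vector
```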