diff --git a/torchdms/model.py b/torchdms/model.py
index 7890df15..3e169a85 100644
--- a/torchdms/model.py
+++ b/torchdms/model.py
@@ -362,17 +362,21 @@ def beta_coefficients(self):
 
     def regularization_loss(self):
         """L1 penalize single mutant effects, and pre-latent interaction weights."""
-        penalty = 0.0
-        if self.beta_l1_coefficient > 0.0:
-            penalty += self.beta_l1_coefficient * self.latent_layer.weight[
+        if self.beta_l1_coefficient > 0:
+            # NOTE: slice excludes interaction weights in latent layer
+            beta_l1_penalty = self.beta_l1_coefficient * self.latent_layer.weight[
                 :, : self.input_size
             ].norm(1)
+        else:
+            beta_l1_penalty = 0.0
         if self.interaction_l1_coefficient > 0.0:
-            for interaction_layer in self.layers[: self.latent_idx]:
-                penalty += self.interaction_l1_coefficient * getattr(
-                    self, interaction_layer
-                ).weight.norm(1)
-        return penalty
+            interaction_l1_penalty = self.interaction_l1_coefficient * sum(
+                getattr(self, interaction_layer).weight.norm(1)
+                for interaction_layer in self.layers[: self.latent_idx]
+            )
+        else:
+            interaction_l1_penalty = 0.0
+        return beta_l1_penalty + interaction_l1_penalty
 
 
 class Independent(TorchdmsModel):
diff --git a/torchdms/test/test_model.py b/torchdms/test/test_model.py
new file mode 100644
index 00000000..ac69c80a
--- /dev/null
+++ b/torchdms/test/test_model.py
@@ -0,0 +1,36 @@
+"""
+Testing model module
+"""
+import torch
+from torchdms.model import FullyConnected, identity
+
+
+def test_regularization_loss():
+    """Test regularization loss gradient."""
+
+    model = FullyConnected(
+        10,
+        [10, 2],
+        [None, identity],
+        [None],
+        None,
+        beta_l1_coefficient=torch.rand(1),
+        interaction_l1_coefficient=torch.rand(1),
+    )
+
+    loss = model.regularization_loss()
+    loss.backward()
+
+    for layer in model.layers:
+        layer_type = layer.split("_")[0]
+        if layer_type == "latent":
+            weight = model.latent_layer.weight[:, : model.input_size]
+            grad_weight = model.latent_layer.weight.grad[:, : model.input_size]
+            penalty = model.beta_l1_coefficient
+        elif layer_type == "interaction":
+            weight = getattr(model, layer).weight
+            grad_weight = getattr(model, layer).weight.grad
+            penalty = model.interaction_l1_coefficient
+        else:
+            continue
+        assert torch.equal(grad_weight, penalty * weight.sign())