101 no regularization grad #104

Open
wants to merge 12 commits into main
20 changes: 12 additions & 8 deletions torchdms/model.py
@@ -362,17 +362,21 @@ def beta_coefficients(self):
     def regularization_loss(self):
         """L1 penalize single mutant effects, and pre-latent interaction
         weights."""
-        penalty = 0.0
-        if self.beta_l1_coefficient > 0.0:
-            penalty += self.beta_l1_coefficient * self.latent_layer.weight[
+        if self.beta_l1_coefficient > 0:
+            # NOTE: slice excludes interaction weights in latent layer
+            beta_l1_penalty = self.beta_l1_coefficient * self.latent_layer.weight[
                 :, : self.input_size
             ].norm(1)
+        else:
+            beta_l1_penalty = 0.0
         if self.interaction_l1_coefficient > 0.0:
-            for interaction_layer in self.layers[: self.latent_idx]:
-                penalty += self.interaction_l1_coefficient * getattr(
-                    self, interaction_layer
-                ).weight.norm(1)
-        return penalty
+            interaction_l1_penalty = self.interaction_l1_coefficient * sum(
+                getattr(self, interaction_layer).weight.norm(1)
+                for interaction_layer in self.layers[: self.latent_idx]
+            )
+        else:
+            interaction_l1_penalty = 0.0
+        return beta_l1_penalty + interaction_l1_penalty
 
 
 class Independent(TorchdmsModel):
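For reference, the property the new test asserts can be checked in plain PyTorch, independently of torchdms: for nonzero weights, the gradient of an L1 penalty c * ||W||_1 with respect to W is c * sign(W). A minimal standalone sketch (illustrative only, not torchdms code):

import torch

# Standalone check: d/dw (c * ||w||_1) = c * sign(w) for nonzero entries of w.
w = torch.randn(3, 4, requires_grad=True)
coeff = 0.5
penalty = coeff * w.norm(1)
penalty.backward()
assert torch.equal(w.grad, coeff * w.sign())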
36 changes: 36 additions & 0 deletions torchdms/test/test_model.py
@@ -0,0 +1,36 @@
"""
Testing model module
"""
import torch
from torchdms.model import FullyConnected, identity


def test_regularization_loss():
"""Test regularization loss gradient."""

model = FullyConnected(
10,
[10, 2],
[None, identity],
[None],
None,
beta_l1_coefficient=torch.rand(1),
interaction_l1_coefficient=torch.rand(1),
)

loss = model.regularization_loss()
loss.backward()

for layer in model.layers:
layer_type = layer.split("_")[0]
if layer_type == "latent":
weight = model.latent_layer.weight[:, : model.input_size]
grad_weight = model.latent_layer.weight.grad[:, : model.input_size]
penalty = model.beta_l1_coefficient
elif layer_type == "interaction":
weight = getattr(model, layer).weight
grad_weight = getattr(model, layer).weight.grad
penalty = model.interaction_l1_coefficient
else:
continue
assert torch.equal(grad_weight, penalty * weight.sign())
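As context for why the missing gradient mattered: a regularization term only influences training when it is added to the objective that is backpropagated. A generic plain-PyTorch sketch of that pattern (not torchdms's actual training loop; the linear model, MSE data loss, and SGD optimizer are placeholders):

import torch

# Generic pattern, not torchdms's training code: the L1 term only produces
# gradients if it is part of the loss that is backpropagated.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x, y = torch.rand(8, 10), torch.rand(8, 1)

optimizer.zero_grad()
data_loss = torch.nn.functional.mse_loss(model(x), y)
l1_penalty = 0.1 * model.weight.norm(1)  # stands in for model.regularization_loss()
(data_loss + l1_penalty).backward()  # both terms now contribute to .grad
optimizer.step()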