Skip to content

Commit

Permalink
Merge pull request #303 from VectorInstitute/masked_layers_bug_fix
Browse files Browse the repository at this point in the history
Converting to masked layers bug fix
  • Loading branch information
yc7z authored Dec 11, 2024
2 parents 2ebb636 + 6380629 commit 4a09d5a
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 47 deletions.
71 changes: 37 additions & 34 deletions fl4health/model_bases/masked_layers/masked_layers_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,43 +22,46 @@

def convert_to_masked_model(original_model: nn.Module) -> nn.Module:
"""
Given a model, convert every one of its linear or convolutional layer to a masked layer
of the same kind, as defined in the classes above.
Given a model, convert every one of its layers to a masked layer
of the same kind, if applicable.
"""
masked_model = copy.deepcopy(original_model)
for name, module in original_model.named_modules():
# Mask nn.Linear modules
if isinstance(module, nn.Linear) and not isinstance(module, MaskedLinear):
setattr(masked_model, name, MaskedLinear.from_pretrained(module))

# Mask convolutional modules (1d, 2d, and 3d)
elif isinstance(module, nn.Conv1d) and not isinstance(module, MaskedConv1d):
setattr(masked_model, name, MaskedConv1d.from_pretrained(module))
elif isinstance(module, nn.Conv2d) and not isinstance(module, MaskedConv2d):
setattr(masked_model, name, MaskedConv2d.from_pretrained(module))
elif isinstance(module, nn.Conv3d) and not isinstance(module, MaskedConv3d):
setattr(masked_model, name, MaskedConv3d.from_pretrained(module))

# Mask transposed convolutional modules (1d, 2d, 3d)
elif isinstance(module, nn.ConvTranspose1d) and not isinstance(module, MaskedConvTranspose1d):
setattr(masked_model, name, MaskedConvTranspose1d.from_pretrained(module))
elif isinstance(module, nn.ConvTranspose2d) and not isinstance(module, MaskedConvTranspose2d):
setattr(masked_model, name, MaskedConvTranspose2d.from_pretrained(module))
elif isinstance(module, nn.ConvTranspose3d) and not isinstance(module, MaskedConvTranspose3d):
setattr(masked_model, name, MaskedConvTranspose3d.from_pretrained(module))

# Mask nn.LayerNorm module
elif isinstance(module, nn.LayerNorm) and not isinstance(module, MaskedLayerNorm):
setattr(masked_model, name, MaskedLayerNorm.from_pretrained(module))

# Mask batch norm modules (1d, 2d, and 3d)
elif isinstance(module, nn.BatchNorm1d):
setattr(masked_model, name, MaskedBatchNorm1d.from_pretrained(module))
elif isinstance(module, nn.BatchNorm2d):
setattr(masked_model, name, MaskedBatchNorm2d.from_pretrained(module))
elif isinstance(module, nn.BatchNorm3d):
setattr(masked_model, name, MaskedBatchNorm3d.from_pretrained(module))
def replace_with_masked(module: nn.Module) -> None:
# Replace layers with their masked versions.
for name, child in module.named_children():
# Linear layers
if isinstance(child, nn.Linear) and not isinstance(child, MaskedLinear):
setattr(module, name, MaskedLinear.from_pretrained(child))
# 1d, 2d, 3d convolutional layers and transposed convolutional layers
elif isinstance(child, nn.Conv1d) and not isinstance(child, MaskedConv1d):
setattr(module, name, MaskedConv1d.from_pretrained(child))
elif isinstance(child, nn.Conv2d) and not isinstance(child, MaskedConv2d):
setattr(module, name, MaskedConv2d.from_pretrained(child))
elif isinstance(child, nn.Conv3d) and not isinstance(child, MaskedConv3d):
setattr(module, name, MaskedConv3d.from_pretrained(child))
elif isinstance(child, nn.ConvTranspose1d) and not isinstance(child, MaskedConvTranspose1d):
setattr(module, name, MaskedConvTranspose1d.from_pretrained(child))
elif isinstance(child, nn.ConvTranspose2d) and not isinstance(child, MaskedConvTranspose2d):
setattr(module, name, MaskedConvTranspose2d.from_pretrained(child))
elif isinstance(child, nn.ConvTranspose3d) and not isinstance(child, MaskedConvTranspose3d):
setattr(module, name, MaskedConvTranspose3d.from_pretrained(child))
# LayerNorm
elif isinstance(child, nn.LayerNorm) and not isinstance(child, MaskedLayerNorm):
setattr(module, name, MaskedLayerNorm.from_pretrained(child))
# 1d, 2d, and 3d BatchNorm
elif isinstance(child, nn.BatchNorm1d):
setattr(module, name, MaskedBatchNorm1d.from_pretrained(child))
elif isinstance(child, nn.BatchNorm2d):
setattr(module, name, MaskedBatchNorm2d.from_pretrained(child))
elif isinstance(child, nn.BatchNorm3d):
setattr(module, name, MaskedBatchNorm3d.from_pretrained(child))
# Recursively process the submodules of child
else:
replace_with_masked(child)

# Deepcopy the model to avoid modifying the original
masked_model = copy.deepcopy(original_model)
replace_with_masked(masked_model)
return masked_model


Expand Down
32 changes: 19 additions & 13 deletions tests/models/test_masked_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,16 +246,22 @@ def test_masked_batch_norm_3d_from_pretrained() -> None:


def test_convert_to_masked_model() -> None:
model = CompositeConvNet()
masked_model = convert_to_masked_model(original_model=model)
assert isinstance(masked_model.conv1d, MaskedConv1d)
assert isinstance(masked_model.conv2d, MaskedConv2d)
assert isinstance(masked_model.conv3d, MaskedConv3d)
assert isinstance(masked_model.linear, MaskedLinear)
assert isinstance(masked_model.conv_transpose1d, MaskedConvTranspose1d)
assert isinstance(masked_model.conv_transpose2d, MaskedConvTranspose2d)
assert isinstance(masked_model.conv_transpose3d, MaskedConvTranspose3d)
assert isinstance(masked_model.bn1d, MaskedBatchNorm1d)
assert isinstance(masked_model.bn2d, MaskedBatchNorm2d)
assert isinstance(masked_model.bn3d, MaskedBatchNorm3d)
assert isinstance(masked_model.layer_norm, MaskedLayerNorm)
model1 = CompositeConvNet()
masked_model1 = convert_to_masked_model(original_model=model1)
assert isinstance(masked_model1.conv1d, MaskedConv1d)
assert isinstance(masked_model1.conv2d, MaskedConv2d)
assert isinstance(masked_model1.conv3d, MaskedConv3d)
assert isinstance(masked_model1.linear, MaskedLinear)
assert isinstance(masked_model1.conv_transpose1d, MaskedConvTranspose1d)
assert isinstance(masked_model1.conv_transpose2d, MaskedConvTranspose2d)
assert isinstance(masked_model1.conv_transpose3d, MaskedConvTranspose3d)
assert isinstance(masked_model1.bn1d, MaskedBatchNorm1d)
assert isinstance(masked_model1.bn2d, MaskedBatchNorm2d)
assert isinstance(masked_model1.bn3d, MaskedBatchNorm3d)
assert isinstance(masked_model1.layer_norm, MaskedLayerNorm)

# Test that convert_to_masked_model properly added the score parameters
# to all relevant modules by trying to load state_dict.
model2 = CompositeConvNet()
masked_model2 = convert_to_masked_model(model2)
masked_model1.load_state_dict(masked_model2.state_dict(), strict=True)

0 comments on commit 4a09d5a

Please sign in to comment.