Skip to content

Commit

Permalink
Extra arguments for instantiating convolutional layers
Browse files Browse the repository at this point in the history
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>

Testing extra kwargs

Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
  • Loading branch information
Joao-L-S-Almeida committed Jul 21, 2023
1 parent abd5f12 commit 627b22a
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 1 deletion.
2 changes: 2 additions & 0 deletions simulai/models/_pytorch_models/_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,7 @@ def __init__(
scale: float = 1e-3,
devices: Union[str, list] = "cpu",
name: str = None,
**kwargs,
) -> None:
"""
Constructor method.
Expand Down Expand Up @@ -1172,6 +1173,7 @@ def __init__(
shallow=shallow,
use_batch_norm=use_batch_norm,
name=self.name,
**kwargs
)

self.encoder = encoder.to(self.device)
Expand Down
5 changes: 4 additions & 1 deletion simulai/templates/_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,7 @@ def cnn_autoencoder_auto(
use_batch_norm: bool = False,
shallow: bool = False,
name: str = None,
**kwargs,
) -> Tuple[NetworkTemplate, ...]:

"""
Expand Down Expand Up @@ -737,7 +738,7 @@ def cnn_autoencoder_auto(

autogen_cnn = NetworkInstanceGen(
architecture="cnn", dim=case, use_batch_norm=use_batch_norm,
kernel_size=kernel_size,
kernel_size=kernel_size, **kwargs,
)

autogen_dense = NetworkInstanceGen(architecture="dense", shallow=shallow)
Expand Down Expand Up @@ -798,6 +799,7 @@ def autoencoder_auto(
use_batch_norm: bool = False,
case: str = None,
name: str = None,
**kwargs,
) -> Tuple[Union[NetworkTemplate, None], ...]:

"""
Expand Down Expand Up @@ -864,6 +866,7 @@ def autoencoder_auto(
shallow=shallow,
use_batch_norm=use_batch_norm,
name=name,
**kwargs,
)

return encoder, decoder, bottleneck_encoder, bottleneck_decoder
Expand Down
2 changes: 2 additions & 0 deletions tests/network/test_template_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ def test_autoencoder_kernel_size_shallow(self) -> None:
architecture="cnn",
case="2d",
shallow=True,
padding_mode='replicate',
)

estimated_data = autoencoder.eval(input_data=input_data)
Expand All @@ -392,6 +393,7 @@ def test_autoencoder_multiscaleautoencoder(self) -> None:
case="2d",
shallow=True,
name="model",
padding_mode='replicate',
)

estimated_data = autoencoder.reconstruction_forward(input_data=input_data)
Expand Down

0 comments on commit 627b22a

Please sign in to comment.