diff --git a/simulai/models/_pytorch_models/_autoencoder.py b/simulai/models/_pytorch_models/_autoencoder.py index 2026f6c9..8d5e42c7 100644 --- a/simulai/models/_pytorch_models/_autoencoder.py +++ b/simulai/models/_pytorch_models/_autoencoder.py @@ -17,6 +17,7 @@ import numpy as np import torch +from simulai import ARRAY_DTYPE from simulai.regression import ConvolutionalNetwork, DenseNetwork, Linear from simulai.templates import ( NetworkTemplate, @@ -508,7 +509,7 @@ def eval(self, input_data: Union[np.ndarray, torch.Tensor] = None) -> np.ndarray """ if isinstance(input_data, np.ndarray): - input_data = torch.from_numpy(input_data.astype("float32")) + input_data = torch.from_numpy(input_data.astype(ARRAY_DTYPE)) input_data = input_data.to(self.device) @@ -995,7 +996,7 @@ def predict( """ if isinstance(input_data, np.ndarray): - input_data = torch.from_numpy(input_data.astype("float32")) + input_data = torch.from_numpy(input_data.astype(ARRAY_DTYPE)) predictions = list() latent = self.projection(input_data=input_data) @@ -1694,7 +1695,7 @@ def project(self, input_data: Union[np.ndarray, torch.Tensor] = None) -> np.ndar >>> projected_data = autoencoder.project(input_data=input_data) """ if isinstance(input_data, np.ndarray): - input_data = torch.from_numpy(input_data.astype("float32")) + input_data = torch.from_numpy(input_data.astype(ARRAY_DTYPE)) input_data = input_data.to(self.device) @@ -1725,7 +1726,7 @@ def reconstruct( >>> reconstructed_data = autoencoder.reconstruct(input_data=input_data) """ if isinstance(input_data, np.ndarray): - input_data = torch.from_numpy(input_data.astype("float32")) + input_data = torch.from_numpy(input_data.astype(ARRAY_DTYPE)) input_data = input_data.to(self.device) @@ -1754,7 +1755,7 @@ def eval(self, input_data: Union[np.ndarray, torch.Tensor] = None) -> np.ndarray >>> reconstructed_data = autoencoder.eval(input_data=input_data) """ if isinstance(input_data, np.ndarray): - input_data = torch.from_numpy(input_data.astype("float32")) + input_data = torch.from_numpy(input_data.astype(ARRAY_DTYPE)) input_data = input_data.to(self.device) diff --git a/tests/PINN/test_deep_operator_pinn.py b/tests/PINN/test_deep_operator_pinn.py index 9dffb573..85e750e7 100644 --- a/tests/PINN/test_deep_operator_pinn.py +++ b/tests/PINN/test_deep_operator_pinn.py @@ -16,6 +16,9 @@ import numpy as np +from tests.config import configure_dtype +torch = configure_dtype() + from simulai.optimization import Optimizer from simulai.residuals import SymbolicOperator diff --git a/tests/PINN/test_vanilla_pinn.py b/tests/PINN/test_vanilla_pinn.py index 6c6f5bc0..98679611 100644 --- a/tests/PINN/test_vanilla_pinn.py +++ b/tests/PINN/test_vanilla_pinn.py @@ -17,6 +17,9 @@ import matplotlib.pyplot as plt import numpy as np +from tests.config import configure_dtype +torch = configure_dtype() + from simulai.optimization import Optimizer from simulai.residuals import SymbolicOperator diff --git a/tests/config.py b/tests/config.py new file mode 100644 index 00000000..429c3a73 --- /dev/null +++ b/tests/config.py @@ -0,0 +1,40 @@ +# (C) Copyright IBM Corp. 2019, 2020, 2021, 2022. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# (C) Copyright IBM Corporation 2017, 2018, 2019 +# U.S. Government Users Restricted Rights: Use, duplication or disclosure restricted +# by GSA ADP Schedule Contract with IBM Corp. +# +# Author: Joao Lucas S. Almeida + +import os +import torch + +def configure_dtype(): + + test_dtype_var = os.environ.get("TEST_DTYPE") + + if test_dtype_var is not None: + test_dtype = getattr(torch, test_dtype_var) + else: + test_dtype = torch.float32 + + torch.set_default_dtype(test_dtype) + + print(f"Using dtype {test_dtype} in tests.") + + return torch + + + diff --git a/tests/metrics/test_mahalanobis.py b/tests/metrics/test_mahalanobis.py index 6a726612..8ba6e325 100644 --- a/tests/metrics/test_mahalanobis.py +++ b/tests/metrics/test_mahalanobis.py @@ -1,6 +1,6 @@ from unittest import TestCase - -import torch +from tests.config import configure_dtype +torch = configure_dtype() from simulai.metrics import MahalanobisDistance diff --git a/tests/metrics/test_pointwise.py b/tests/metrics/test_pointwise.py index 68d1371a..45acab1f 100644 --- a/tests/metrics/test_pointwise.py +++ b/tests/metrics/test_pointwise.py @@ -15,7 +15,9 @@ from unittest import TestCase import numpy as np -import torch + +from tests.config import configure_dtype +torch = configure_dtype() from simulai.metrics import PointwiseError diff --git a/tests/network/test_conv_1d.py b/tests/network/test_conv_1d.py index 7295c7c4..f1ccfad4 100644 --- a/tests/network/test_conv_1d.py +++ b/tests/network/test_conv_1d.py @@ -16,9 +16,12 @@ from unittest import TestCase import numpy as np -import torch +from tests.config import configure_dtype +torch = configure_dtype() + from utils import configure_device +from simulai import ARRAY_DTYPE from simulai.file import SPFile from simulai.optimization import Optimizer @@ -34,8 +37,8 @@ def generate_data( input_data = np.random.rand(n_samples, n_inputs, vector_size) output_data = np.random.rand(n_samples, n_outputs) - return torch.from_numpy(input_data.astype("float32")), torch.from_numpy( - output_data.astype("float32") + return torch.from_numpy(input_data.astype(ARRAY_DTYPE)), torch.from_numpy( + output_data.astype(ARRAY_DTYPE) ) diff --git a/tests/network/test_conv_2d.py b/tests/network/test_conv_2d.py index 53d1b3f2..54e4367d 100644 --- a/tests/network/test_conv_2d.py +++ b/tests/network/test_conv_2d.py @@ -16,9 +16,12 @@ from unittest import TestCase import numpy as np -import torch +from tests.config import configure_dtype +torch = configure_dtype() + from utils import configure_device +from simulai import ARRAY_DTYPE from simulai.file import SPFile from simulai.optimization import Optimizer @@ -34,8 +37,8 @@ def generate_data( input_data = np.random.rand(n_samples, n_inputs, *image_size) output_data = np.random.rand(n_samples, n_outputs) - return torch.from_numpy(input_data.astype("float32")), torch.from_numpy( - output_data.astype("float32") + return torch.from_numpy(input_data.astype(ARRAY_DTYPE)), torch.from_numpy( + output_data.astype(ARRAY_DTYPE) ) diff --git a/tests/network/test_deeponet.py b/tests/network/test_deeponet.py index 5b4d2c2a..69aeda2d 100644 --- a/tests/network/test_deeponet.py +++ b/tests/network/test_deeponet.py @@ -15,7 +15,8 @@ from unittest import TestCase import numpy as np -import torch +from tests.config import configure_dtype +torch = configure_dtype() from utils import configure_device DEVICE = configure_device() diff --git a/tests/network/test_flexible_deeponet.py b/tests/network/test_flexible_deeponet.py index 279d4162..71ec3ac5 100644 --- a/tests/network/test_flexible_deeponet.py +++ b/tests/network/test_flexible_deeponet.py @@ -15,7 +15,9 @@ from unittest import TestCase import numpy as np -import torch +from tests.config import configure_dtype +torch = configure_dtype() + from utils import configure_device DEVICE = configure_device() diff --git a/tests/network/test_improved_deeponet.py b/tests/network/test_improved_deeponet.py index e66d6fec..6326b1ca 100644 --- a/tests/network/test_improved_deeponet.py +++ b/tests/network/test_improved_deeponet.py @@ -15,7 +15,9 @@ from unittest import TestCase import numpy as np -import torch +from tests.config import configure_dtype +torch = configure_dtype() + from utils import configure_device DEVICE = configure_device() diff --git a/tests/network/test_residual_cnn.py b/tests/network/test_residual_cnn.py index f9e9caf3..b5f32300 100644 --- a/tests/network/test_residual_cnn.py +++ b/tests/network/test_residual_cnn.py @@ -19,7 +19,9 @@ import matplotlib.pyplot as plt import numpy as np -import torch + +from tests.config import configure_dtype +torch = configure_dtype() torch.autograd.set_detect_anomaly(True) diff --git a/tests/network/test_template_gen.py b/tests/network/test_template_gen.py index 0f014cd5..28c7b5d1 100644 --- a/tests/network/test_template_gen.py +++ b/tests/network/test_template_gen.py @@ -16,11 +16,14 @@ from unittest import TestCase import numpy as np -import torch +from tests.config import configure_dtype +torch = configure_dtype() + from utils import configure_device DEVICE = configure_device() +from simulai import ARRAY_DTYPE def generate_data_2d( n_samples: int = None, @@ -31,8 +34,8 @@ def generate_data_2d( input_data = np.random.rand(n_samples, n_inputs, *image_size) output_data = np.random.rand(n_samples, n_outputs) - return torch.from_numpy(input_data.astype("float32")), torch.from_numpy( - output_data.astype("float32") + return torch.from_numpy(input_data.astype(ARRAY_DTYPE)), torch.from_numpy( + output_data.astype(ARRAY_DTYPE) ) @@ -45,8 +48,8 @@ def generate_data_1d( input_data = np.random.rand(n_samples, n_inputs, vector_size) output_data = np.random.rand(n_samples, n_outputs) - return torch.from_numpy(input_data.astype("float32")), torch.from_numpy( - output_data.astype("float32") + return torch.from_numpy(input_data.astype(ARRAY_DTYPE)), torch.from_numpy( + output_data.astype(ARRAY_DTYPE) ) diff --git a/tests/network/utils.py b/tests/network/utils.py index cb040774..6a273770 100644 --- a/tests/network/utils.py +++ b/tests/network/utils.py @@ -7,7 +7,8 @@ def configure_device(): if not simulai_network_gpu: device = "cpu" else: - import torch + from tests.config import configure_dtype + torch = configure_dtype() if not torch.cuda.is_available(): raise Exception("There is no gpu available to execute the tests.") diff --git a/tests/residuals/test_symbolicoperator.py b/tests/residuals/test_symbolicoperator.py index 75a2ec04..1ca6453c 100644 --- a/tests/residuals/test_symbolicoperator.py +++ b/tests/residuals/test_symbolicoperator.py @@ -16,7 +16,9 @@ from unittest import TestCase import numpy as np -import torch +from tests.config import configure_dtype +torch = configure_dtype() + from simulai.residuals import SymbolicOperator diff --git a/tests/rom/test_cnn_autoencoder.py b/tests/rom/test_cnn_autoencoder.py index 5670e1ce..73e01738 100644 --- a/tests/rom/test_cnn_autoencoder.py +++ b/tests/rom/test_cnn_autoencoder.py @@ -3,6 +3,9 @@ import numpy as np +from tests.config import configure_dtype +torch = configure_dtype() + from simulai.file import SPFile from simulai.optimization import Optimizer