diff --git a/mushroom_rl/approximators/parametric/torch_approximator.py b/mushroom_rl/approximators/parametric/torch_approximator.py index 3a6533de..ce742bea 100644 --- a/mushroom_rl/approximators/parametric/torch_approximator.py +++ b/mushroom_rl/approximators/parametric/torch_approximator.py @@ -49,6 +49,7 @@ def __init__(self, input_shape, output_shape, network, optimizer=None, loss=None self._n_fit_targets = n_fit_targets self.network = network(input_shape, output_shape, dropout=dropout, **params) + self.network.to(TorchUtils.get_device()) if self._dropout: self.network.eval()