From 42872ecd583932b10db131d697c5dbb3e8e4d42d Mon Sep 17 00:00:00 2001 From: boris-il-forte Date: Mon, 4 Dec 2023 18:11:46 +0100 Subject: [PATCH] Torch approximator now uses correct device --- mushroom_rl/approximators/parametric/torch_approximator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mushroom_rl/approximators/parametric/torch_approximator.py b/mushroom_rl/approximators/parametric/torch_approximator.py index 3a6533de..ce742bea 100644 --- a/mushroom_rl/approximators/parametric/torch_approximator.py +++ b/mushroom_rl/approximators/parametric/torch_approximator.py @@ -49,6 +49,7 @@ def __init__(self, input_shape, output_shape, network, optimizer=None, loss=None self._n_fit_targets = n_fit_targets self.network = network(input_shape, output_shape, dropout=dropout, **params) + self.network.to(TorchUtils.get_device()) if self._dropout: self.network.eval()