Skip to content

Commit

Permalink
Torch approximator now uses correct device
Browse files Browse the repository at this point in the history
  • Loading branch information
boris-il-forte committed Dec 4, 2023
1 parent 74c4bf5 commit 42872ec
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions mushroom_rl/approximators/parametric/torch_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(self, input_shape, output_shape, network, optimizer=None, loss=None
self._n_fit_targets = n_fit_targets

self.network = network(input_shape, output_shape, dropout=dropout, **params)
self.network.to(TorchUtils.get_device())

if self._dropout:
self.network.eval()
Expand Down

0 comments on commit 42872ec

Please sign in to comment.