diff --git a/mushroom_rl/core/__init__.py b/mushroom_rl/core/__init__.py
index 0dfbe45a..7a2f4635 100644
--- a/mushroom_rl/core/__init__.py
+++ b/mushroom_rl/core/__init__.py
@@ -1,5 +1,5 @@
 from .core import Core
-from .dataset import Dataset
+from .dataset import Dataset, VectorizedDataset
 from .environment import Environment, MDPInfo
 from .agent import Agent, AgentInfo
 from .serialization import Serializable
diff --git a/mushroom_rl/core/_impl/__init__.py b/mushroom_rl/core/_impl/__init__.py
index 9eeb5fbd..15858809 100644
--- a/mushroom_rl/core/_impl/__init__.py
+++ b/mushroom_rl/core/_impl/__init__.py
@@ -1,6 +1,6 @@
 from .numpy_dataset import NumpyDataset
 from .torch_dataset import TorchDataset
 from .list_dataset import ListDataset
-from .type_conversions import DataConversion, NumpyConversion, TorchConversion, ListConversion
+from .array_backend import ArrayBackend, NumpyBackend, TorchBackend, ListBackend
 from .core_logic import CoreLogic
 from .vectorized_core_logic import VectorizedCoreLogic
diff --git a/mushroom_rl/core/_impl/type_conversions.py b/mushroom_rl/core/_impl/array_backend.py
similarity index 68%
rename from mushroom_rl/core/_impl/type_conversions.py
rename to mushroom_rl/core/_impl/array_backend.py
index ad6828f3..ed82091d 100644
--- a/mushroom_rl/core/_impl/type_conversions.py
+++ b/mushroom_rl/core/_impl/array_backend.py
@@ -4,15 +4,19 @@
 from mushroom_rl.utils.torch import TorchUtils
 
 
-class DataConversion(object):
+class ArrayBackend(object):
     @staticmethod
-    def get_converter(backend):
+    def get_backend_name():
+        raise NotImplementedError
+
+    @staticmethod
+    def get_array_backend(backend):
         if backend == 'numpy':
-            return NumpyConversion
+            return NumpyBackend
         elif backend == 'torch':
-            return TorchConversion
+            return TorchBackend
         else:
-            return ListConversion
+            return ListBackend
 
     @classmethod
     def convert(cls, *arrays, to='numpy'):
@@ -55,8 +59,16 @@ def ones(*dims, dtype):
     def copy(array):
         raise NotImplementedError
 
+    @staticmethod
+    def pack_padded_sequence(array, lengths):
+        raise NotImplementedError
+
+
+class NumpyBackend(ArrayBackend):
+    @staticmethod
+    def get_backend_name():
+        return 'numpy'
 
-class NumpyConversion(DataConversion):
     @staticmethod
     def to_numpy(array):
         return array
@@ -81,8 +93,20 @@ def ones(*dims, dtype=float):
     def copy(array):
         return array.copy()
 
+    @staticmethod
+    def pack_padded_sequence(array, lengths):
+        shape = array.shape
+
+        new_shape = (shape[0] * shape[1],) + shape[2:]
+        mask = (np.arange(len(array))[:, None] < lengths[None, :]).flatten()
+        return array.reshape(new_shape)[mask]
+
+
+class TorchBackend(ArrayBackend):
+    @staticmethod
+    def get_backend_name():
+        return 'torch'
 
-class TorchConversion(DataConversion):
     @staticmethod
     def to_numpy(array):
        return None if array is None else array.detach().cpu().numpy()
@@ -107,8 +131,20 @@ def ones(*dims, dtype=torch.float32):
     def copy(array):
         return array.clone()
 
+    @staticmethod
+    def pack_padded_sequence(array, lengths):
+        shape = array.shape
+
+        new_shape = (shape[0]*shape[1], ) + shape[2:]
+        mask = (torch.arange(len(array), device=TorchUtils.get_device())[:, None] < lengths[None, :]).flatten()
+        return array.reshape(new_shape)[mask]
+
+
+class ListBackend(ArrayBackend):
+    @staticmethod
+    def get_backend_name():
+        return 'list'
 
-class ListConversion(DataConversion):
     @staticmethod
     def to_numpy(array):
         return np.array(array)
@@ -133,6 +169,10 @@ def ones(*dims, dtype=float):
     def copy(array):
         return array.copy()
 
+    @staticmethod
+    def pack_padded_sequence(array, lengths):
+        return NumpyBackend.pack_padded_sequence(array, lengths)
+
diff --git a/mushroom_rl/core/_impl/list_dataset.py b/mushroom_rl/core/_impl/list_dataset.py
index b2853541..03862dab 100644
--- a/mushroom_rl/core/_impl/list_dataset.py
+++ b/mushroom_rl/core/_impl/list_dataset.py
@@ -101,6 +101,10 @@ def policy_state(self):
     def policy_next_state(self):
         return [step[7] for step in self._dataset]
 
+    @property
+    def is_stateful(self):
+        return self._is_stateful
+
     @property
     def n_episodes(self):
         n_episodes = 0
diff --git a/mushroom_rl/core/_impl/numpy_dataset.py b/mushroom_rl/core/_impl/numpy_dataset.py
index d0de9f6f..98af1b0b 100644
--- a/mushroom_rl/core/_impl/numpy_dataset.py
+++ b/mushroom_rl/core/_impl/numpy_dataset.py
@@ -4,8 +4,8 @@
 
 
 class NumpyDataset(Serializable):
-    def __init__(self, state_type, state_shape, action_type, action_shape, reward_shape, policy_state_shape):
-        flags_len = action_shape[0]
+    def __init__(self, state_type, state_shape, action_type, action_shape, reward_shape, flag_shape,
+                 policy_state_shape):
 
         self._state_type = state_type
         self._action_type = action_type
@@ -14,8 +14,8 @@ def __init__(self, state_type, state_shape, action_type, action_shape, reward_sh
         self._actions = np.empty(action_shape, dtype=self._action_type)
         self._rewards = np.empty(reward_shape, dtype=float)
         self._next_states = np.empty(state_shape, dtype=self._state_type)
-        self._absorbing = np.empty(flags_len, dtype=bool)
-        self._last = np.empty(flags_len, dtype=bool)
+        self._absorbing = np.empty(flag_shape, dtype=bool)
+        self._last = np.empty(flag_shape, dtype=bool)
         self._len = 0
 
         if policy_state_shape is None:
@@ -100,7 +100,7 @@ def append(self, state, action, reward, next_state, absorbing, last, policy_stat
         self._absorbing[i] = absorbing
         self._last[i] = last
 
-        if self._is_stateful:
+        if self.is_stateful:
             self._policy_states[i] = policy_state
             self._policy_next_states[i] = policy_next_state
 
@@ -114,7 +114,7 @@ def clear(self):
         self._absorbing = np.empty_like(self._absorbing)
         self._last = np.empty_like(self._last)
 
-        if self._is_stateful:
+        if self.is_stateful:
             self._policy_states = np.empty_like(self._policy_states)
             self._policy_next_states = np.empty_like(self._policy_next_states)
 
@@ -131,7 +131,7 @@ def get_view(self, index):
         view._last = self.last[index, ...]
         view._len = view._states.shape[0]
 
-        if self._is_stateful:
+        if self.is_stateful:
             view._policy_states = self._policy_states[index, ...]
             view._policy_next_states = self._policy_next_states[index, ...]
 
@@ -153,7 +153,7 @@ def __add__(self, other):
         result._last[len(self)-1] = True
         result._len = len(self) + len(other)
 
-        if self._is_stateful:
+        if self.is_stateful:
             result._policy_states = np.concatenate((self.policy_state, other.policy_state))
             result._policy_next_states = np.concatenate((self.policy_next_state, other.policy_next_state))
 
@@ -192,7 +192,7 @@ def policy_next_state(self):
         return self._policy_next_states[:len(self)]
 
     @property
-    def _is_stateful(self):
+    def is_stateful(self):
         return self._policy_states is not None
 
     @property
diff --git a/mushroom_rl/core/_impl/torch_dataset.py b/mushroom_rl/core/_impl/torch_dataset.py
index fd8d6e68..0d0e732f 100644
--- a/mushroom_rl/core/_impl/torch_dataset.py
+++ b/mushroom_rl/core/_impl/torch_dataset.py
@@ -5,9 +5,8 @@
 
 
 class TorchDataset(Serializable):
-    def __init__(self, state_type, state_shape, action_type, action_shape, reward_shape, policy_state_shape):
-        flags_len = action_shape[0]
-
+    def __init__(self, state_type, state_shape, action_type, action_shape, reward_shape, flag_shape,
+                 policy_state_shape):
         self._state_type = state_type
         self._action_type = action_type
@@ -15,8 +14,8 @@ def __init__(self, state_type, state_shape, action_type, action_shape, reward_sh
         self._actions = torch.empty(*action_shape, dtype=self._action_type, device=TorchUtils.get_device())
         self._rewards = torch.empty(*reward_shape, dtype=torch.float, device=TorchUtils.get_device())
         self._next_states = torch.empty(*state_shape, dtype=self._state_type, device=TorchUtils.get_device())
-        self._absorbing = torch.empty(flags_len, dtype=torch.bool, device=TorchUtils.get_device())
-        self._last = torch.empty(flags_len, dtype=torch.bool, device=TorchUtils.get_device())
+        self._absorbing = torch.empty(flag_shape, dtype=torch.bool, device=TorchUtils.get_device())
+        self._last = torch.empty(flag_shape, dtype=torch.bool, device=TorchUtils.get_device())
         self._len = 0
 
         if policy_state_shape is None:
@@ -101,7 +100,7 @@ def append(self, state, action, reward, next_state, absorbing, last, policy_stat
         self._absorbing[i] = absorbing
         self._last[i] = last
 
-        if self._is_stateful:
+        if self.is_stateful:
             self._policy_states[i] = policy_state
             self._policy_next_states[i] = policy_next_state
 
@@ -115,7 +114,7 @@ def clear(self):
         self._absorbing = torch.empty_like(self._absorbing)
         self._last = torch.empty_like(self._last)
 
-        if self._is_stateful:
+        if self.is_stateful:
             self._policy_states = torch.empty_like(self._policy_states)
             self._policy_next_states = torch.empty_like(self._policy_next_states)
 
@@ -132,7 +131,7 @@ def get_view(self, index):
         view._last = self._last[index, ...]
         view._len = view._states.shape[0]
 
-        if self._is_stateful:
+        if self.is_stateful:
             view._policy_states = self._policy_states[index, ...]
             view._policy_next_states = self._policy_next_states[index, ...]
 
@@ -154,12 +153,13 @@ def __add__(self, other):
         result._last[len(self) - 1] = True
         result._len = len(self) + len(other)
 
-        if self._is_stateful:
+        if self.is_stateful:
             result._policy_states = torch.concatenate((self.policy_state, other.policy_state))
             result._policy_next_states = torch.concatenate((self.policy_next_state, other.policy_next_state))
 
         return result
 
+    @property
     def state(self):
         return self._states[:len(self)]
 
@@ -193,7 +193,7 @@ def policy_next_state(self):
         return self._policy_next_states[:len(self)]
 
     @property
-    def _is_stateful(self):
+    def is_stateful(self):
         return self._policy_states is not None
 
     @property
diff --git a/mushroom_rl/core/_impl/vectorized_core_logic.py b/mushroom_rl/core/_impl/vectorized_core_logic.py
index 809f0a9c..1b5b270c 100644
--- a/mushroom_rl/core/_impl/vectorized_core_logic.py
+++ b/mushroom_rl/core/_impl/vectorized_core_logic.py
@@ -1,17 +1,17 @@
-from .type_conversions import DataConversion
+from .array_backend import ArrayBackend
 from .core_logic import CoreLogic
 
 
 class VectorizedCoreLogic(CoreLogic):
     def __init__(self, backend, n_envs):
-        self._converter = DataConversion.get_converter(backend)
+        self._array_backend = ArrayBackend.get_array_backend(backend)
         self._n_envs = n_envs
-        self._running_envs = self._converter.zeros(n_envs, dtype=bool)
+        self._running_envs = self._array_backend.zeros(n_envs, dtype=bool)
 
         super().__init__()
 
     def get_mask(self, last):
-        mask = self._converter.ones(self._n_envs, dtype=bool)
+        mask = self._array_backend.ones(self._n_envs, dtype=bool)
         terminated_episodes = (last & self._running_envs).sum()
         running_episodes = (~last & self._running_envs).sum()
@@ -29,11 +29,11 @@ def get_mask(self, last):
             missing_episodes_fit = self._n_episodes_per_fit - self._current_episodes_counter - running_episodes
             max_runs = min(missing_episodes_fit, max_runs)
 
-        new_mask = self._converter.ones(terminated_episodes, dtype=bool)
+        new_mask = self._array_backend.ones(terminated_episodes, dtype=bool)
         new_mask[max_runs:] = False
         mask[last] = new_mask
 
-        self._running_envs = self._converter.copy(mask)
+        self._running_envs = self._array_backend.copy(mask)
 
         return mask
@@ -59,12 +59,12 @@ def after_step(self, last):
     def after_fit(self):
         super().after_fit()
         if self._n_episodes_per_fit is not None:
-            self._running_envs = self._converter.zeros(self._n_envs, dtype=bool)
+            self._running_envs = self._array_backend.zeros(self._n_envs, dtype=bool)
 
     def _reset_counters(self):
         super()._reset_counters()
-        self._running_envs = self._converter.zeros(self._n_envs, dtype=bool)
+        self._running_envs = self._array_backend.zeros(self._n_envs, dtype=bool)
 
     @property
     def converter(self):
-        return self._converter
+        return self._array_backend
diff --git a/mushroom_rl/core/agent.py b/mushroom_rl/core/agent.py
index 136edf87..11db34bb 100644
--- a/mushroom_rl/core/agent.py
+++ b/mushroom_rl/core/agent.py
@@ -48,8 +48,8 @@ def __init__(self, mdp_info, policy, is_episodic=False, backend='numpy'):
         self.policy = policy
         self.next_action = None
 
-        self._agent_converter = DataConversion.get_converter(backend)
-        self._env_converter = DataConversion.get_converter(self.mdp_info.backend)
+        self._agent_backend = ArrayBackend.get_array_backend(backend)
+        self._env_backend = ArrayBackend.get_array_backend(self.mdp_info.backend)
 
         self._preprocessors = list()
 
@@ -60,8 +60,8 @@ def __init__(self, mdp_info, policy, is_episodic=False, backend='numpy'):
             next_action='none',
             mdp_info='mushroom',
             _info='mushroom',
-            _agent_converter='primitive',
-            _env_converter='primitive',
+            _agent_backend='primitive',
+            _env_backend='primitive',
             _preprocessors='mushroom',
             _logger='none'
         )
@@ -146,10 +146,10 @@ def preprocessors(self):
         return self._preprocessors
 
     def _convert_to_env_backend(self, array):
-        return self._env_converter.to_backend_array(self._agent_converter, array)
+        return self._env_backend.to_backend_array(self._agent_backend, array)
 
     def _convert_to_agent_backend(self, array):
-        return self._agent_converter.to_backend_array(self._env_converter, array)
+        return self._agent_backend.to_backend_array(self._env_backend, array)
 
     @property
     def info(self):
diff --git a/mushroom_rl/core/dataset.py b/mushroom_rl/core/dataset.py
index 2250e67d..e1182dcd 100644
--- a/mushroom_rl/core/dataset.py
+++ b/mushroom_rl/core/dataset.py
@@ -8,9 +8,11 @@
 
 
 class Dataset(Serializable):
-    def __init__(self, mdp_info, agent_info, n_steps=None, n_episodes=None):
+    def __init__(self, mdp_info, agent_info, n_steps=None, n_episodes=None, n_envs=1):
         assert (n_steps is not None and n_episodes is None) or (n_steps is None and n_episodes is not None)
 
+        self._array_backend = ArrayBackend.get_array_backend(mdp_info.backend)
+
         if n_steps is not None:
             n_samples = n_steps
         else:
@@ -19,12 +21,17 @@ def __init__(self, mdp_info, agent_info, n_steps=None, n_episodes=None):
 
             n_samples = horizon * n_episodes
 
-        state_shape = (n_samples,) + mdp_info.observation_space.shape
-        action_shape = (n_samples,) + mdp_info.action_space.shape
-        reward_shape = (n_samples,)
+        if n_envs == 1:
+            base_shape = (n_samples,)
+        else:
+            base_shape = (n_samples, n_envs)
+
+        state_shape = base_shape + mdp_info.observation_space.shape
+        action_shape = base_shape + mdp_info.action_space.shape
+        reward_shape = base_shape
 
         if agent_info.is_stateful:
-            policy_state_shape = (n_samples,) + agent_info.policy_state_shape
+            policy_state_shape = base_shape + agent_info.policy_state_shape
         else:
             policy_state_shape = None
 
@@ -36,16 +43,14 @@ def __init__(self, mdp_info, agent_info, n_steps=None, n_episodes=None):
         self._theta_list = list()
 
         if mdp_info.backend == 'numpy':
-            self._data = NumpyDataset(state_type, state_shape, action_type, action_shape, reward_shape,
+            self._data = NumpyDataset(state_type, state_shape, action_type, action_shape, reward_shape, base_shape,
                                       policy_state_shape)
         elif mdp_info.backend == 'torch':
-            self._data = TorchDataset(state_type, state_shape, action_type, action_shape, reward_shape,
+            self._data = TorchDataset(state_type, state_shape, action_type, action_shape, reward_shape, base_shape,
                                       policy_state_shape)
         else:
             self._data = ListDataset(policy_state_shape is not None)
 
-        self._converter = DataConversion.get_converter(mdp_info.backend)
-
         self._gamma = mdp_info.gamma
 
         self._add_save_attr(
@@ -53,7 +58,7 @@ def __init__(self, mdp_info, agent_info, n_steps=None, n_episodes=None):
             _episode_info='pickle',
             _theta_list='pickle',
             _data='mushroom',
-            _converter='primitive',
+            _array_backend='primitive',
             _gamma='primitive',
         )
 
@@ -108,13 +113,13 @@ def from_array(cls, states, actions, rewards, next_states, absorbings, lasts,
 
         if backend == 'numpy':
             dataset._data = NumpyDataset.from_array(states, actions, rewards, next_states, absorbings, lasts)
-            dataset._converter = NumpyConversion
+            dataset._array_backend = NumpyBackend
         elif backend == 'torch':
             dataset._data = TorchDataset.from_array(states, actions, rewards, next_states, absorbings, lasts)
-            dataset._converter = TorchConversion
+            dataset._array_backend = TorchBackend
         else:
             dataset._data = ListDataset.from_array(states, actions, rewards, next_states, absorbings, lasts)
-            dataset._converter = ListConversion
+            dataset._array_backend = ListBackend
 
         dataset._add_save_attr(
             _info='pickle',
@@ -279,8 +284,8 @@ def parse(self, to='numpy'):
             A tuple containing the arrays that define the dataset, i.e. state, action, next state, absorbing and last
 
         """
-        return self._converter.convert(self.state, self.action, self.reward, self.next_state,
-                                       self.absorbing, self.last, to=to)
+        return self._array_backend.convert(self.state, self.action, self.reward, self.next_state,
+                                           self.absorbing, self.last, to=to)
 
     def parse_policy_state(self, to='numpy'):
         """
@@ -292,7 +297,7 @@ def parse_policy_state(self, to='numpy'):
             A tuple containing the arrays that define the dataset, i.e. state, action, next state, absorbing and last
 
         """
-        return self._converter.convert(self.policy_state, self.policy_next_state, to=to)
+        return self._array_backend.convert(self.policy_state, self.policy_next_state, to=to)
 
     def select_first_episodes(self, n_episodes):
         """
@@ -424,3 +429,54 @@ def _merge_info(info, other_info):
         for key in info.keys():
             new_info[key] = info[key] + other_info[key]
         return new_info
+
+
+class VectorizedDataset(Dataset):
+    def __init__(self, mdp_info, agent_info, n_envs, n_steps=None, n_episodes=None):
+        super().__init__(mdp_info, agent_info, n_steps, n_episodes, n_envs)
+
+        self._length = self._array_backend.zeros(n_envs, dtype=int)
+
+        self._add_save_attr(
+            _length=mdp_info.backend
+        )
+
+    def append_vectorized(self, step, info, mask):
+        self.append(step, {})  # FIXME!!!
+
+        self._length[mask] += 1
+
+    def clear(self):
+        super().clear()
+
+        self._length = self._array_backend.zeros(len(self._length), dtype=int)
+
+    def flatten(self):
+        if len(self) == 0:
+            return None
+
+        states = self._array_backend.pack_padded_sequence(self._data.state, self._length)
+        actions = self._array_backend.pack_padded_sequence(self._data.action, self._length)
+        rewards = self._array_backend.pack_padded_sequence(self._data.reward, self._length)
+        next_states = self._array_backend.pack_padded_sequence(self._data.next_state, self._length)
+        absorbings = self._array_backend.pack_padded_sequence(self._data.absorbing, self._length)
+
+        last_padded = self._data.last
+        last_padded[self._length-1, :] = True
+        lasts = self._array_backend.pack_padded_sequence(last_padded, self._length)
+
+        policy_state = None
+        policy_next_state = None
+
+        if self._data.is_stateful:
+            policy_state = self._array_backend.pack_padded_sequence(self._data.policy_state, self._length)
+            policy_next_state = self._array_backend.pack_padded_sequence(self._data.policy_next_state, self._length)
+
+        return self.from_array(states, actions, rewards, next_states, absorbings, lasts,
+                               policy_state=policy_state, policy_next_state=policy_next_state,
+                               info=None, episode_info=None, theta_list=None,  # FIXME!!!
+                               gamma=self._gamma, backend=self._array_backend.get_backend_name())
+
+
+
+
diff --git a/mushroom_rl/core/vectorized_core.py b/mushroom_rl/core/vectorized_core.py
index e559b136..6efe2df2 100644
--- a/mushroom_rl/core/vectorized_core.py
+++ b/mushroom_rl/core/vectorized_core.py
@@ -1,4 +1,4 @@
-from mushroom_rl.core.dataset import Dataset
+from mushroom_rl.core.dataset import VectorizedDataset
 from mushroom_rl.utils.record import VideoRecorder
 
 from ._impl import VectorizedCoreLogic
@@ -64,10 +64,10 @@ def learn(self, n_steps=None, n_episodes=None, n_steps_per_fit=None, n_episodes_
         assert (render and record) or (not record), "To record, the render flag must be set to true"
 
         self._core_logic.initialize_learn(n_steps_per_fit, n_episodes_per_fit)
-        datasets = [Dataset(self.env.info, self.agent.info, n_steps_per_fit, n_episodes_per_fit)
-                    for _ in range(self.env.number)]
+        dataset = VectorizedDataset(self.env.info, self.agent.info, self.env.number,
+                                    n_steps_per_fit, n_episodes_per_fit)
 
-        self._run(datasets, n_steps, n_episodes, render, quiet, record)
+        self._run(dataset, n_steps, n_episodes, render, quiet, record)
 
     def evaluate(self, initial_states=None, n_steps=None, n_episodes=None, render=False, quiet=False, record=False):
         """
@@ -93,12 +93,11 @@ def evaluate(self, initial_states=None, n_steps=None, n_episodes=None, render=Fa
         self._core_logic.initialize_evaluate()
 
         n_episodes_dataset = len(initial_states) if initial_states is not None else n_episodes
-        datasets = [Dataset(self.env.info, self.agent.info, n_steps, n_episodes_dataset)
-                    for _ in range(self.env.number)]
+        dataset = VectorizedDataset(self.env.info, self.agent.info, self.env.number, n_steps, n_episodes_dataset)
 
-        return self._run(datasets, n_steps, n_episodes, render, quiet, record, initial_states)
+        return self._run(dataset, n_steps, n_episodes, render, quiet, record, initial_states)
 
-    def _run(self, datasets, n_steps, n_episodes, render, quiet, record, initial_states=None):
+    def _run(self, dataset, n_steps, n_episodes, render, quiet, record, initial_states=None):
         self._core_logic.initialize_run(n_steps, n_episodes, initial_states, quiet)
 
         last = self._core_logic.converter.ones(self.env.number, dtype=bool)
@@ -114,18 +113,17 @@ def _run(self, datasets, n_steps, n_episodes, render, quiet, record, initial_sta
             self.callback_step(samples)
 
             self._core_logic.after_step(samples[5] & mask)
 
-            self._add_to_dataset(mask, datasets, samples, step_infos)
+            dataset.append_vectorized(samples, step_infos, mask)
 
             if self._core_logic.fit_required():
-                fit_dataset = self._aggregate(datasets)
+                fit_dataset = dataset.flatten()
                 self.agent.fit(fit_dataset)
                 self._core_logic.after_fit()
 
                 for c in self.callbacks_fit:
-                    c(datasets)
+                    c(dataset)
 
-                for dataset in datasets:
-                    dataset.clear()
+                dataset.clear()
 
             last = samples[5]
@@ -134,14 +132,7 @@ def _run(self, datasets, n_steps, n_episodes, render, quiet, record, initial_sta
 
         self._end(record)
 
-        return self._aggregate(datasets)
-
-    def _add_to_dataset(self, action_mask, datasets, samples, step_infos):
-        for i in range(self.env.number):
-            if action_mask[i]:
-                sample = (samples[0][i], samples[1][i], samples[2][i], samples[3][i], samples[4][i], samples[5][i])
-                step_info = step_infos[i]
-                datasets[i].append(sample, step_info)
+        return dataset.flatten()
 
     def _step(self, render, record, mask):
         """
@@ -226,22 +217,6 @@ def _preprocess(self, states):
 
         return states
 
-    @staticmethod
-    def _aggregate(datasets):
-        aggregated_dataset = None
-        for dataset in datasets:
-            if len(dataset) > 0:
-                aggregated_dataset = dataset
-                break
-
-        if aggregated_dataset is not None and len(aggregated_dataset) > 0:
-            for dataset in datasets[1:]:
-                aggregated_dataset += dataset
-
-            return aggregated_dataset
-        else:
-            return None
-
     def _build_recorder_class(self, recorder_class=None, fps=None, **kwargs):
         """
         Method to create a video recorder class.
diff --git a/mushroom_rl/features/_implementations/torch_features.py b/mushroom_rl/features/_implementations/torch_features.py
index b00008dd..7129903a 100644
--- a/mushroom_rl/features/_implementations/torch_features.py
+++ b/mushroom_rl/features/_implementations/torch_features.py
@@ -17,7 +17,7 @@ def __call__(self, *args):
         y_list = [self._phi[i].forward(x) for i in range(len(self._phi))]
         y = torch.cat(y_list, 1).squeeze()
 
-        y = y.detach().numpy()
+        y = y.detach().cpu().numpy()
 
         if y.shape[0] == 1:
             return y[0]
diff --git a/mushroom_rl/features/tensors/basis_tensor.py b/mushroom_rl/features/tensors/basis_tensor.py
index 922f50dc..1f72131c 100644
--- a/mushroom_rl/features/tensors/basis_tensor.py
+++ b/mushroom_rl/features/tensors/basis_tensor.py
@@ -33,6 +33,8 @@ def __init__(self, mu, scale, dim=None, normalized=False):
 
         self._normalized = normalized
 
+        super().__init__()
+
     def forward(self, x):
         if self._dim is not None:
             x = torch.index_select(x, 1, self._dim)
diff --git a/mushroom_rl/features/tensors/constant_tensor.py b/mushroom_rl/features/tensors/constant_tensor.py
index b348d8c0..b1e0cfd5 100644
--- a/mushroom_rl/features/tensors/constant_tensor.py
+++ b/mushroom_rl/features/tensors/constant_tensor.py
@@ -1,6 +1,8 @@
 import torch
 import torch.nn as nn
 
+from mushroom_rl.utils.torch import TorchUtils
+
 
 class ConstantTensor(nn.Module):
     """
@@ -9,7 +11,7 @@ class ConstantTensor(nn.Module):
 
     """
     def forward(self, x):
-        return torch.ones(x.shape[0], 1)
+        return torch.ones(x.shape[0], 1).to(TorchUtils.get_device())
 
     @property
     def size(self):
diff --git a/tests/core/test_serialization.py b/tests/core/test_serialization.py
index 8daaece1..e669b415 100644
--- a/tests/core/test_serialization.py
+++ b/tests/core/test_serialization.py
@@ -77,6 +77,8 @@ def test_serialization_cpu_cuda(tmpdir):
 
         assert a == b
 
+    TorchUtils.set_default_device('cpu')
+
 
 
diff --git a/tests/core/test_vectorized_envs.py b/tests/core/test_vectorized_envs.py
index 0b956616..fe17167f 100644
--- a/tests/core/test_vectorized_envs.py
+++ b/tests/core/test_vectorized_envs.py
@@ -55,7 +55,12 @@ def __init__(self, backend):
         super().__init__(mdp_info, n_envs)
 
     def reset_all(self, env_mask, state=None):
-        self._state[env_mask] = torch.randint(size=(env_mask.sum(), self._state.shape[1]), low=2, high=200).float().to(TorchUtils.get_device())
+        if self.info.backend == 'torch':
+            self._state[env_mask] = torch.randint(size=(env_mask.sum(), self._state.shape[1]),
+                                                  low=2, high=200).float().to(TorchUtils.get_device())
+        elif self.info.backend == 'numpy':
+            self._state[env_mask] = np.random.randint(size=(env_mask.sum(), self._state.shape[1]),
+                                                      low=2, high=200).astype(float)
         return self._state, [{}]*self._n_envs
 
     def step_all(self, env_mask, action):
@@ -101,4 +106,5 @@ def test_vectorized_env_():
         TorchUtils.set_default_device('cuda')
         run_exp(env_backend='torch', agent_backend='torch')
         run_exp(env_backend='torch', agent_backend='numpy')
+        TorchUtils.set_default_device('cpu')
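Usage sketch (illustrative only, not part of the patch): the pack_padded_sequence helpers added above flatten a padded, time-major (horizon, n_envs, ...) buffer into a single batch of valid transitions, which is what VectorizedDataset.flatten() relies on before calling Dataset.from_array. The values below are hypothetical; only the masking logic mirrors the NumpyBackend implementation in the diff.

import numpy as np

horizon, n_envs = 4, 2
states = np.arange(horizon * n_envs).reshape(horizon, n_envs, 1)   # padded (T, n_envs, obs_dim) buffer
lengths = np.array([3, 2])                                         # per-environment episode lengths

# Same masking as NumpyBackend.pack_padded_sequence: keep step t of env e only if t < lengths[e]
mask = (np.arange(horizon)[:, None] < lengths[None, :]).flatten()
flat = states.reshape(horizon * n_envs, 1)[mask]

print(flat.squeeze())   # -> [0 1 2 3 4], the five valid transitions with the padding removed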