diff --git a/mushroom_rl/core/_impl/type_conversions.py b/mushroom_rl/core/_impl/type_conversions.py index 01c1609f..02a61527 100644 --- a/mushroom_rl/core/_impl/type_conversions.py +++ b/mushroom_rl/core/_impl/type_conversions.py @@ -45,11 +45,11 @@ def to_backend_array(cls, array): raise NotImplementedError @staticmethod - def zeros(*dims): + def zeros(*dims, dtype): raise NotImplementedError @staticmethod - def ones(*dims): + def ones(*dims, dtype): raise NotImplementedError @@ -67,12 +67,12 @@ def to_backend_array(cls, array): return cls.to_numpy(array) @staticmethod - def zeros(*dims): - return np.zeros(dims) + def zeros(*dims, dtype=float): + return np.zeros(dims, dtype=dtype) @staticmethod - def ones(*dims): - return np.ones(dims) + def ones(*dims, dtype=float): + return np.ones(dims, dtype=dtype) class TorchConversion(DataConversion): @@ -89,12 +89,12 @@ def to_backend_array(cls, array): return cls.to_torch(array) @staticmethod - def zeros(*dims): - return torch.zeros(*dims, device=TorchUtils.get_device()) + def zeros(*dims, dtype=torch.float32): + return torch.zeros(*dims, dtype=dtype, device=TorchUtils.get_device()) @staticmethod - def ones(*dims): - return torch.ones(*dims, device=TorchUtils.get_device()) + def ones(*dims, dtype=torch.float32): + return torch.ones(*dims, dtype=dtype, device=TorchUtils.get_device()) class ListConversion(DataConversion): @@ -111,12 +111,12 @@ def to_backend_array(cls, array): return cls.to_numpy(array) @staticmethod - def zeros(*dims): - return np.zeros(dims) + def zeros(*dims, dtype=float): + return np.zeros(dims, dtype=float) @staticmethod - def ones(*dims): - return np.ones(dims) + def ones(*dims, dtype=float): + return np.ones(dims, dtype=float)