Skip to content

Commit

Permalink
Fixed type conversion with dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
boris-il-forte committed Nov 24, 2023
1 parent 036381f commit 1257ad0
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions mushroom_rl/core/_impl/type_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def to_backend_array(cls, array):
raise NotImplementedError

@staticmethod
def zeros(*dims):
def zeros(*dims, dtype):
raise NotImplementedError

@staticmethod
def ones(*dims):
def ones(*dims, dtype):
raise NotImplementedError


Expand All @@ -67,12 +67,12 @@ def to_backend_array(cls, array):
return cls.to_numpy(array)

@staticmethod
def zeros(*dims):
return np.zeros(dims)
def zeros(*dims, dtype=float):
return np.zeros(dims, dtype=dtype)

@staticmethod
def ones(*dims):
return np.ones(dims)
def ones(*dims, dtype=float):
return np.ones(dims, dtype=dtype)


class TorchConversion(DataConversion):
Expand All @@ -89,12 +89,12 @@ def to_backend_array(cls, array):
return cls.to_torch(array)

@staticmethod
def zeros(*dims):
return torch.zeros(*dims, device=TorchUtils.get_device())
def zeros(*dims, dtype=torch.float32):
return torch.zeros(*dims, dtype=dtype, device=TorchUtils.get_device())

@staticmethod
def ones(*dims):
return torch.ones(*dims, device=TorchUtils.get_device())
def ones(*dims, dtype=torch.float32):
return torch.ones(*dims, dtype=dtype, device=TorchUtils.get_device())


class ListConversion(DataConversion):
Expand All @@ -111,12 +111,12 @@ def to_backend_array(cls, array):
return cls.to_numpy(array)

@staticmethod
def zeros(*dims):
return np.zeros(dims)
def zeros(*dims, dtype=float):
return np.zeros(dims, dtype=float)

@staticmethod
def ones(*dims):
return np.ones(dims)
def ones(*dims, dtype=float):
return np.ones(dims, dtype=float)



Expand Down

0 comments on commit 1257ad0

Please sign in to comment.