`remove_pt_head` from `mace.tools.scripts_utils` returns a `float32` model even if the input model is `float64`. The expected behavior is to retain the precision of the input model.
Running:

```python
import torch, mace
from mace.tools.scripts_utils import remove_pt_head
print(f"torch {torch.__version__}, mace {mace.__version__}")
model = torch.load("MACE.model")
print(f"Input model type: {next(model.parameters()).dtype}")
model_single = remove_pt_head(model, None)
print(f"Output model type: {next(model_single.parameters()).dtype}")
```
produces:

```
torch 2.5.1+cpu, mace 0.3.9
Input model type: torch.float64
Output model type: torch.float32
```
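The repro checks only the first parameter. As a rough sketch (not part of MACE), the helper below collects the dtypes of all floating-point parameters and buffers, which makes a partially converted model easier to spot. The name `floating_dtypes` and the toy model are assumptions for illustration only.

```python
import torch
from torch import nn


def floating_dtypes(model: nn.Module) -> set:
    """Collect the dtypes of all floating-point parameters and buffers."""
    tensors = list(model.parameters()) + list(model.buffers())
    # Integer buffers (e.g. BatchNorm's num_batches_tracked) are skipped.
    return {t.dtype for t in tensors if t.is_floating_point()}


# Toy model standing in for a loaded MACE model (illustrative assumption).
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4)).double()
print(floating_dtypes(model))  # {torch.float64}
```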
I've checked the script, and it doesn't explicitly downcast the model. Instead, it returns a model in `torch.get_default_dtype()`. If you call `torch.set_default_dtype(torch.float64)` first, the output model should be in `float64`, i.e.:
```python
import torch, mace
torch.set_default_dtype(torch.float64)
from mace.tools.scripts_utils import remove_pt_head
print(f"torch {torch.__version__}, mace {mace.__version__}")
model = torch.load("MACE.model")
print(f"Input model type: {next(model.parameters()).dtype}")
model_single = remove_pt_head(model, None)
print(f"Output model type: {next(model_single.parameters()).dtype}")
```
Otherwise, you can just call `.double()` on the returned model to convert it to `float64`.
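A minimal sketch of why this happens and how to scope the workaround, assuming (per the explanation above) that the function builds a fresh model whose parameters are allocated in the default dtype. `rebuild_like` is a hypothetical stand-in, not MACE code, and `default_dtype` / `call_preserving_dtype` are illustrative helpers. The context manager restores the previous default afterwards, so the global setting does not leak into later code.

```python
import contextlib

import torch
from torch import nn


@contextlib.contextmanager
def default_dtype(dtype: torch.dtype):
    """Temporarily set torch's default floating-point dtype, then restore it."""
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)


def rebuild_like(model: nn.Linear) -> nn.Linear:
    """Hypothetical stand-in for a function that rebuilds a model from scratch.

    New parameters are allocated in torch.get_default_dtype(); load_state_dict
    then copies the old values into them, casting to that dtype.
    """
    new_model = nn.Linear(model.in_features, model.out_features)
    new_model.load_state_dict(model.state_dict())
    return new_model


def call_preserving_dtype(fn, model: nn.Module, *args, **kwargs):
    """Run fn with the input model's parameter dtype as the default dtype."""
    dtype = next(model.parameters()).dtype
    with default_dtype(dtype):
        return fn(model, *args, **kwargs)


model64 = nn.Linear(3, 2).double()
# Assuming the default dtype was not changed earlier in the session:
print(rebuild_like(model64).weight.dtype)                         # torch.float32
print(call_preserving_dtype(rebuild_like, model64).weight.dtype)  # torch.float64
```

With MACE, the same pattern would presumably be `call_preserving_dtype(remove_pt_head, model, None)`; that call is untested here.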
Using commit fece538