
remove_pt_head downconverts to float32 #760

Closed
mwswift-nrl opened this issue Dec 20, 2024 · 3 comments

Comments

@mwswift-nrl

remove_pt_head from mace.tools.scripts_utils returns a float32 model even when the input model is float64. The expected behavior is for the output model to keep the precision (dtype) of the input model.

Running

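```python
import torch
import mace
from mace.tools.scripts_utils import remove_pt_head

print(f"torch {torch.__version__}, mace {mace.__version__}")
# weights_only=False is needed to unpickle a full model on torch >= 2.6,
# where the default changed to True; torch 2.5.1 also accepts it.
model = torch.load("MACE.model", weights_only=False)
print(f"Input model type: {next(model.parameters()).dtype}")
model_single = remove_pt_head(model, None)
print(f"Output model type: {next(model_single.parameters()).dtype}")
```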

Produces

```
torch 2.5.1+cpu, mace 0.3.9
Input model type: torch.float64
Output model type: torch.float32
```

Using commit fece538
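The reproduction above checks only the dtype of the first parameter. A small helper that collects the dtype of every floating-point parameter and buffer would also catch a partially converted model. This is a sketch: collect_dtypes is a hypothetical name, and the toy nn.Sequential stands in for a loaded MACE model so the snippet runs without MACE.model.

```python
import torch
from torch import nn


def collect_dtypes(model: nn.Module) -> set:
    """Return the set of floating-point dtypes used by parameters and buffers."""
    tensors = list(model.parameters()) + list(model.buffers())
    return {t.dtype for t in tensors if t.is_floating_point()}


# Toy stand-in for a loaded model (hypothetical; not a MACE model).
model = nn.Sequential(nn.Linear(4, 8), nn.SiLU(), nn.Linear(8, 1)).double()
print(collect_dtypes(model))  # {torch.float64}

# Converting only the last layer gives mixed precision, which a check of
# next(model.parameters()) alone would not reveal (it still reports float64).
model[2].float()
print(collect_dtypes(model))
```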

@RokasEl
Collaborator

RokasEl commented Dec 21, 2024

I've checked the script, and it doesn't explicitly downcast the model. Instead, it returns a model in the dtype given by torch.get_default_dtype(), which is float32 unless changed. If you call torch.set_default_dtype(torch.float64) before remove_pt_head, the output model should be in float64.

I.e.:

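```python
import torch
import mace

# Set the default dtype before calling remove_pt_head: the returned model
# is built in torch.get_default_dtype() (float32 unless changed).
torch.set_default_dtype(torch.float64)
from mace.tools.scripts_utils import remove_pt_head

print(f"torch {torch.__version__}, mace {mace.__version__}")
# weights_only=False is needed to unpickle a full model on torch >= 2.6.
model = torch.load("MACE.model", weights_only=False)
print(f"Input model type: {next(model.parameters()).dtype}")
model_single = remove_pt_head(model, None)
print(f"Output model type: {next(model_single.parameters()).dtype}")
```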

Otherwise, you can just do model.double() to convert to float64.
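The explanation above can be illustrated with a toy model. The sketch assumes, as described in the comment, that the output is built in the default dtype; rebuild_without_head is a hypothetical stand-in, not the code of remove_pt_head. It also shows a caveat of the .double() route: once the weights have been copied into float32 tensors, upcasting restores the dtype but cannot recover the digits float32 dropped, whereas setting the default dtype first keeps them exact.

```python
import torch
from torch import nn


def rebuild_without_head(model: nn.Sequential) -> nn.Sequential:
    """Hypothetical stand-in for a head-removal step (not MACE's code):
    keep only the first layer, rebuilt as a fresh module."""
    first = model[0]
    # A freshly constructed layer uses torch.get_default_dtype(),
    # whatever the dtype of `model` is.
    new_model = nn.Sequential(nn.Linear(first.in_features, first.out_features))
    # load_state_dict copies the values into the new tensors,
    # casting them to the new layer's dtype.
    new_model[0].load_state_dict(first.state_dict())
    return new_model


model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 1)).double()
with torch.no_grad():
    # 1/3 has no exact float32 representation, so rounding is detectable.
    model[0].weight.fill_(1.0 / 3.0)

# 1) Default dtype left at float32: the rebuilt model is float32.
down = rebuild_without_head(model)
print(down[0].weight.dtype)  # torch.float32

# 2) Upcasting afterwards restores the dtype but not the lost digits.
down.double()  # converts in place
print(down[0].weight.dtype)  # torch.float64
print(torch.equal(down[0].weight, model[0].weight))  # False

# 3) Setting the default dtype first keeps the weights exact.
previous = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)
try:
    kept = rebuild_without_head(model)
finally:
    torch.set_default_dtype(previous)
print(kept[0].weight.dtype)  # torch.float64
print(torch.equal(kept[0].weight, model[0].weight))  # True
```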

@mwswift-nrl
Author

OK, this makes sense, thanks! This patch fixes the CLI mace_select_head so that it retains the dtype of the selected model:
0001-mace_select_head-retains-dtype-of-the-selected-model.patch
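The attached patch's contents are not reproduced here. As a sketch of one generic approach (an assumption, not necessarily what the patch does), a script can run the head-removal step with the default dtype temporarily set to the input model's dtype and restore the previous default afterwards. default_dtype, model_dtype, preserve_dtype and fresh_copy are hypothetical names; in a MACE script the callable might be lambda m: remove_pt_head(m, None). The sketch assumes all floating-point parameters share one dtype.

```python
import contextlib
from typing import Callable, Iterator

import torch
from torch import nn


@contextlib.contextmanager
def default_dtype(dtype: torch.dtype) -> Iterator[None]:
    """Temporarily set torch's default dtype, restoring the old one on exit."""
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)


def model_dtype(model: nn.Module) -> torch.dtype:
    """Return the single floating-point dtype of the model's parameters."""
    dtypes = {p.dtype for p in model.parameters() if p.is_floating_point()}
    if len(dtypes) != 1:
        raise ValueError(f"expected one floating-point dtype, found {dtypes}")
    return dtypes.pop()


def preserve_dtype(fn: Callable[[nn.Module], nn.Module], model: nn.Module) -> nn.Module:
    """Call fn(model) with the default dtype set to the model's own dtype."""
    with default_dtype(model_dtype(model)):
        return fn(model)


# Toy transformation that builds a fresh layer (hypothetical).
def fresh_copy(model: nn.Module) -> nn.Module:
    layer = nn.Linear(2, 2)  # created in the current default dtype
    layer.load_state_dict(model.state_dict())
    return layer


source = nn.Linear(2, 2).double()
result = preserve_dtype(fresh_copy, source)
print(result.weight.dtype)  # torch.float64
print(torch.get_default_dtype())  # torch.float32 (the previous default)
```

Restoring the previous default in a finally clause keeps the change from leaking into code that runs afterwards, which a global torch.set_default_dtype call at the top of a script does not do.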

@ilyes319
Contributor

Merged into dev; closing for now.
