Size mismatch while extracting features #11

Open
rbareja25 opened this issue Jul 26, 2024 · 3 comments
Comments

@rbareja25

Hello,

I am creating a ViT-Base model with patch size 16, using the vision transformer class from https://github.com/facebookresearch/dino/blob/main/vision_transformer.py, but I get an error when loading the checkpoint:
size mismatch for pos_embed: copying a param with shape torch.Size([1, 198, 768]) from checkpoint, the shape in current model is torch.Size([1, 197, 768]).

Could you tell me what causes the difference?

Thanks,
Rohan

@hsymm
Collaborator

hsymm commented Jul 26, 2024

The difference lies in the pretext token. Besides the class token, we append a new token to the sequence, resulting in 198 (196+1+1) tokens instead of 197 (196+1). We also release the modified model in vits.py, which might help you create the model.
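To make the arithmetic concrete, here is a minimal sketch of how the sequence length comes about. It assumes a 224x224 input (not stated above, but implied by 196 = 14 x 14 patches at patch size 16); the helper name `num_tokens` is hypothetical and not part of vits.py.

```python
def num_tokens(img_size: int = 224, patch_size: int = 16, extra_tokens: int = 1) -> int:
    """Sequence length of a ViT: patch tokens plus extra (class/pretext) tokens.

    img_size=224 is an assumption; the thread only states patch size 16.
    """
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    patches_per_side = img_size // patch_size  # 224 // 16 = 14
    return patches_per_side ** 2 + extra_tokens  # 14 * 14 = 196 patch tokens


standard = num_tokens(extra_tokens=1)      # class token only
with_pretext = num_tokens(extra_tokens=2)  # class token + pretext token
print(standard, with_pretext)  # 197 198
```

The second dimension of pos_embed has to match this count, which is why a model built without the pretext token expects 197 while the checkpoint provides 198.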

@rbareja25
Author

I tried that, but it would require modifying a lot of my code. Would modifying the number of tokens be an option, or would that change the results drastically?

@hsymm
Collaborator

hsymm commented Jul 26, 2024

Sure, you might use the register_token in recent ViT models, so that pos_embed would take a [1, 198, 768] tensor with 1 reg_token. However, we did not try this, so we are not sure whether it would affect the performance. We still recommend using the following code to initialize the model.

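import torch
import torch.nn as nn
from vits import VisionTransformerMoCo

# init the model with the pretext token enabled
model = VisionTransformerMoCo(pretext_token=True, global_pool='avg')
# init the fc layer (args.num_classes: number of classes in your own task)
model.head = nn.Linear(768, args.num_classes)
# load checkpoint (your_checkpoint_path: path to the downloaded weights)
checkpoint = torch.load(your_checkpoint_path, map_location="cpu")
# strict=False tolerates missing or unexpected keys, such as the new head
model.load_state_dict(checkpoint, strict=False)
# Your own tasks
# x_feats, x_output = model(x) where x_feats: [B, 198, 768] and x_output: [B, 768]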
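
Note that strict=False only relaxes missing and unexpected keys; PyTorch still raises an error when a key exists in both the model and the checkpoint but the shapes differ. A minimal sketch for listing such mismatches before loading is given below. The helper name `shape_mismatches` is hypothetical, and the example shapes are taken from the error message in this issue.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def shape_mismatches(
    model_shapes: Dict[str, Shape], ckpt_shapes: Dict[str, Shape]
) -> List[Tuple[str, Shape, Shape]]:
    """Return (name, checkpoint_shape, model_shape) for keys present in both
    mappings whose shapes differ."""
    return [
        (name, ckpt_shapes[name], shape)
        for name, shape in model_shapes.items()
        if name in ckpt_shapes and ckpt_shapes[name] != shape
    ]


# With PyTorch, the two mappings could be built as:
#   model_shapes = {k: tuple(v.shape) for k, v in model.state_dict().items()}
#   ckpt_shapes = {k: tuple(v.shape) for k, v in checkpoint.items()}
model_shapes = {"pos_embed": (1, 197, 768), "cls_token": (1, 1, 768)}
ckpt_shapes = {"pos_embed": (1, 198, 768), "cls_token": (1, 1, 768)}
print(shape_mismatches(model_shapes, ckpt_shapes))
# [('pos_embed', (1, 198, 768), (1, 197, 768))]
```

An empty list means every shared parameter has the same shape in the model and the checkpoint, so only missing or unexpected keys remain for strict=False to skip.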
