Hello,
I am creating a ViT-Base model with patch size 16, but I get an error when loading the checkpoint. Here is the vision transformer class I am using: https://github.com/facebookresearch/dino/blob/main/vision_transformer.py

size mismatch for pos_embed: copying a param with shape torch.Size([1, 198, 768]) from checkpoint, the shape in current model is torch.Size([1, 197, 768]).

Could you tell me what the difference could be?
Thanks,
Rohan
The difference lies in the pretext token. Besides the class token, we append an extra token to the sequence, resulting in 198 (196 + 1 + 1) tokens instead of 197 (196 + 1). We also release the modified model in vits.py, which you can use to create the model.
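For reference, a minimal sketch of where the two shapes come from (assuming a 224x224 input, the standard setting for ViT-B/16):

img_size, patch_size = 224, 16
num_patches = (img_size // patch_size) ** 2   # 14 * 14 = 196
# DINO ViT:        196 patch tokens + 1 class token                    -> pos_embed [1, 197, 768]
# this checkpoint: 196 patch tokens + 1 class token + 1 pretext token  -> pos_embed [1, 198, 768]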
Sure, you could also use the register token in recent ViT models, so that pos_embed becomes a [1, 198, 768] tensor with one reg_token. However, we did not try this, so we are not sure whether it would affect performance. We still recommend using the following code to initialize the model.
import torch
import torch.nn as nn

from vits import VisionTransformerMoCo

# init the model (pretext token enabled, global average pooling)
model = VisionTransformerMoCo(pretext_token=True, global_pool='avg')
# init the fc layer
model.head = nn.Linear(768, args.num_classes)
# load checkpoint
checkpoint = torch.load(your_checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint, strict=False)
# Your own tasks
# x_feats, x_output = model(x) where x_feats: [B, 198, 768] and x_output: [B, 768]
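Alternatively, if you prefer to keep the DINO VisionTransformer with 197 tokens, one possible workaround is to drop the extra token's entry from the checkpoint's pos_embed before loading. This is only an untested sketch: it assumes the pretext token sits at index 1, right after the class token, and since the pretext token was used during pretraining, removing it may affect downstream performance.

import torch

ckpt = torch.load(your_checkpoint_path, map_location="cpu")
pe = ckpt["pos_embed"]                                         # [1, 198, 768]
# assumption: index 0 is the class token and index 1 is the pretext token;
# keep the class token and the 196 patch positions -> [1, 197, 768]
ckpt["pos_embed"] = torch.cat([pe[:, :1], pe[:, 2:]], dim=1)
dino_model.load_state_dict(ckpt, strict=False)                 # dino_model: your DINO ViT-B/16 instance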