-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathowkin_phikon.py
51 lines (41 loc) · 1.76 KB
/
owkin_phikon.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from torch import nn
from transformers import ViTModel, AutoModel
class OwkinPhikonFeatureExtractor(nn.Module):
    """
    Feature extractor that returns the CLS-token embedding of an Owkin Phikon ViT.

    Args:
        version (str): The version of the model to use. Options are "v1" and "v2".

    Raises:
        ValueError: If an invalid version is provided.

    Version Details:
        - v1 (owkin/phikon), loaded with ``ViTModel`` without a pooling layer:
            - outputs
                - last_hidden_state [batch_size, 197, 768]
                    - batch_size images
                    - 197 tokens = 1 cls token + 14*14 patch tokens
                    - 768 features
        - v2 (owkin/phikon-v2), loaded with ``AutoModel``:
            - outputs
                - pooler_output [batch_size, 1024] - the same as cls token at index 0
                - last_hidden_state [batch_size, 197, 1024]
                    - batch_size images
                    - 197 tokens = 1 cls token + 14*14 patch tokens
                    - 1024 features

    The token counts above presumably assume 224x224 input images (TODO confirm);
    ``forward`` itself only relies on the CLS token being at sequence index 0.
    """

    # Version identifiers accepted by ``__init__``.
    SUPPORTED_VERSIONS = ("v1", "v2")

    def __init__(self, version: str):
        super().__init__()
        self.version = version
        if version == "v1":
            # forward() never reads pooler_output, so skip building the pooler.
            self.model = ViTModel.from_pretrained(
                "owkin/phikon", add_pooling_layer=False)
        elif version == "v2":
            self.model = AutoModel.from_pretrained("owkin/phikon-v2")
        else:
            # Fail before any checkpoint download, and tell the caller what is valid.
            raise ValueError(
                f"Invalid version: {version}. "
                f"Expected one of: {', '.join(self.SUPPORTED_VERSIONS)}."
            )

    def forward(self, x):
        """
        Embed a batch of images and return their CLS-token features.

        Args:
            x: Pixel tensor passed to the backbone as ``pixel_values``
                (presumably [batch_size, 3, H, W], already preprocessed — confirm
                against the caller's image processor).

        Returns:
            Tensor of shape [batch_size, hidden_size] taken from
            ``last_hidden_state[:, 0, :]`` (768 features for v1, 1024 for v2
            per the class docstring).
        """
        outputs = self.model(pixel_values=x)
        # Index 0 of the token axis is the CLS token for both versions.
        return outputs.last_hidden_state[:, 0, :]