Jittor implementation of Vision Transformer. The PyTorch code is from lucidrains and Ross Wightman.
Lucidrains' code is a simple way to reach SOTA results in image classification with only a single Transformer encoder, in PyTorch, and Ross Wightman's code can load the pretrained Vision Transformer models.
To distinguish between them, we call them vit_v1 and vit respectively.
ViT v1 (lucidrains' code)
Framework | Test | Train |
---|---|---|
PyTorch | 7s | 29s |
Jittor | 6s | 24s |
Speedup (PyTorch / Jittor) | 1.17× | 1.21× |
ViT (Ross Wightman's code)
Framework | Test | Train |
---|---|---|
PyTorch | 1.007s | 0.388s |
Jittor | 0.988s | 0.325s |
Speedup (PyTorch / Jittor) | 1.02× | 1.19× |
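A minimal sketch of how a single step can be timed in Jittor is shown below (the `model` and `data` names are assumptions; the explicit sync matters because Jittor executes operators asynchronously):

```python
import time
import jittor as jt

# Time one forward pass; jt.sync_all(True) forces pending ops to finish
# so the elapsed time reflects actual compute, not just op launching.
start = time.time()
output = model(data)
jt.sync_all(True)
print(f"forward time: {time.time() - start:.3f}s")
```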
Dataset: Cat & Dog
Model:
```python
from models.vit_v1 import ViT

model = ViT(
    dim=128,
    image_size=224,
    patch_size=32,
    num_classes=2,
    depth=12,
    heads=8,
    mlp_dim=128
)
```
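As a quick sanity check, a dummy batch can be pushed through the model; the shapes below simply follow the configuration above (this snippet is a sketch, not part of the original example):

```python
import jittor as jt

x = jt.random((4, 3, 224, 224))  # dummy batch of 4 RGB images at 224x224
logits = model(x)                # expected shape: (4, 2) for the 2 classes
print(logits.shape)
```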
For CUDA usage, enable the flag:

```python
jt.flags.use_cuda = 1
```
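If the same script should also run on machines without a GPU, one option (an assumption, not from the original) is to guard the flag:

```python
import jittor as jt

# Use the GPU only when one is actually available.
jt.flags.use_cuda = 1 if jt.has_cuda else 0
```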
Dataloader:
```python
import jittor.transform as transforms
from jittor.dataset import Dataset

## Image Augmentation
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomCropAndResize(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)
```
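For the validation split, the random augmentations are usually dropped; a possible counterpart (an assumption, not taken from the original code) is:

```python
# Deterministic preprocessing for validation images.
val_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
```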
```python
## Load Datasets
from PIL import Image

class CatsDogsDataset(Dataset):
    def __init__(self, file_list, transform=None, batch_size=1, shuffle=False, num_workers=0):
        super(CatsDogsDataset, self).__init__(batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
        self.file_list = file_list
        self.transform = transform
        # Jittor datasets report their size through total_len
        self.total_len = len(self.file_list)

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        img = Image.open(img_path)
        img_transformed = self.transform(img)
        # The label is parsed from the filename, e.g. "dog.123.jpg" -> 1, "cat.42.jpg" -> 0
        label = img_path.split("/")[-1].split(".")[0]
        label = 1 if label == "dog" else 0
        return img_transformed, label
```
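With the dataset class in place, train and validation loaders can be built from a directory of `cat.*.jpg` / `dog.*.jpg` files; the path and the 80/20 split below are assumptions for illustration:

```python
import glob
import random

# Hypothetical data directory; point this at wherever the images live.
file_list = glob.glob('./data/train/*.jpg')
random.shuffle(file_list)

split = int(0.8 * len(file_list))
train_list, valid_list = file_list[:split], file_list[split:]

train_data = CatsDogsDataset(train_list, transform=transform,
                             batch_size=16, shuffle=True, num_workers=4)
valid_data = CatsDogsDataset(valid_list, transform=val_transform,
                             batch_size=16, shuffle=False, num_workers=4)
```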
Train & Test
### Training

```python
import jittor as jt
from jittor import nn
import jittor.optim as optim
from tqdm import tqdm

# lr, epochs, train_data and valid_data are assumed to be defined above

# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)

for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0
    for data, label in tqdm(train_data):
        output = model(data)
        loss = criterion(output, label)
        # optimizer.step(loss) runs backward and the parameter update in one call
        optimizer.step(loss)
        acc = (output.argmax(dim=1)[0] == label).float().mean()
        epoch_accuracy += acc / len(train_data)
        epoch_loss += loss / len(train_data)

    with jt.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in tqdm(valid_data):
            val_output = model(data)
            val_loss = criterion(val_output, label)
            acc = (val_output.argmax(dim=1)[0] == label).float().mean()
            epoch_val_accuracy += acc / len(valid_data)
            epoch_val_loss += val_loss / len(valid_data)

    jt.sync_all(True)
    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss.item():.4f} - acc: {epoch_accuracy.item():.4f} - val_loss : {epoch_val_loss.item():.4f} - val_acc: {epoch_val_accuracy.item():.4f}\n"
    )
```
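After training, the weights can be saved and reloaded with Jittor's built-in module serialization (the filename is an assumption):

```python
# Persist and later restore the trained weights.
model.save('vit_catdog.pkl')
model.load('vit_catdog.pkl')
```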
Dataset: ImageNet
Model Names: vit_huge_patch32_384, vit_huge_patch16_224, vit_large_patch32_384, vit_large_patch16_384, vit_large_patch16_224, vit_base_patch32_384, vit_base_patch16_384, vit_base_patch16_224, vit_small_patch16_224
Train & Validate: Use train.py to train your ViT and validate.py to validate it.
Use the following to set the configs:
```python
jt.flags.use_cuda = 1
model_name = 'vit_base_patch16_224'
lr = 0.001
train_dir = '/data/imagenet/train'
eval_dir = '/data/imagenet/val'
batch_size = 32
input_size = 224
num_workers = 4
hflip = 0.5
ratio = (0.75, 1.3333333333333333)
scale = (0.08, 1.0)
train_interpolation = 'random'
num_epochs = 8
```
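These names mirror timm-style arguments; as a rough illustration they could map onto a `jittor.transform` training pipeline like the sketch below (an assumption about how `train.py` consumes them, not a copy of its code):

```python
import jittor.transform as transforms

# Rough mapping of the config values above onto data augmentation.
train_transform = transforms.Compose([
    transforms.RandomCropAndResize(input_size, scale=scale, ratio=ratio),
    transforms.RandomHorizontalFlip(),  # default flip probability matches hflip = 0.5
    transforms.ToTensor(),
])
```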
You can use this to create your model and dataset:

```python
model = create_model('vit_base_patch16_224', pretrained=True, num_classes=1000)
criterion = nn.CrossEntropyLoss()
dataset = create_val_dataset(root='/data/imagenet', batch_size=batch_size, num_workers=4, img_size=224)
```
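A minimal validation loop over a dataset built this way might look like the following (a sketch assuming the dataset yields `(images, labels)` batches as in the Cat & Dog example above):

```python
model.eval()
epoch_val_accuracy = 0
with jt.no_grad():
    for images, labels in dataset:
        output = model(images)
        acc = (output.argmax(dim=1)[0] == labels).float().mean()
        epoch_val_accuracy += acc / len(dataset)
jt.sync_all(True)
print(f"val acc: {epoch_val_accuracy.item():.4f}")
```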