lora.py
import torch
import torch.nn.functional as F

import clip  # used below for clip.tokenize; previously only reachable via the wildcard import
from tqdm import tqdm

from utils import *
from loralib.utils import mark_only_lora_as_trainable, apply_lora, get_lora_parameters, lora_state_dict, save_lora, load_lora
from loralib import layers as lora_layers
def evaluate_lora(args, clip_model, loader, dataset):
    clip_model.eval()

    # Encode one prompt per class once; the normalized text embeddings act as
    # a fixed classifier for this evaluation pass.
    with torch.no_grad():
        template = dataset.template[0]
        texts = [template.format(classname.replace('_', ' ')) for classname in dataset.classnames]
        with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
            texts = clip.tokenize(texts).cuda()
            class_embeddings = clip_model.encode_text(texts)
            text_features = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)

    # Accumulate per-batch accuracies weighted by batch size.
    acc = 0.
    tot_samples = 0
    with torch.no_grad():
        for i, (images, target) in enumerate(loader):
            images, target = images.cuda(), target.cuda()
            with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
                image_features = clip_model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            cosine_similarity = image_features @ text_features.t()
            acc += cls_acc(cosine_similarity, target) * len(cosine_similarity)
            tot_samples += len(cosine_similarity)
    acc /= tot_samples

    return acc
def run_lora(args, clip_model, logit_scale, dataset, train_loader, val_loader, test_loader):

    VALIDATION = False

    # Textual features
    print("\nGetting textual features as CLIP's classifier.")
    textual_features = clip_classifier(dataset.classnames, dataset.template, clip_model)

    # Pre-load val features
    print("\nLoading visual features and labels from val set.")
    val_features, val_labels = pre_load_features(clip_model, val_loader)

    # Pre-load test features
    print("\nLoading visual features and labels from test set.")
    test_features, test_labels = pre_load_features(clip_model, test_loader)
    test_features = test_features.cuda()
    test_labels = test_labels.cuda()

    # Zero-shot CLIP
    clip_logits = logit_scale * test_features @ textual_features
    zs_acc = cls_acc(clip_logits, test_labels)
    print("\n**** Zero-shot CLIP's test accuracy: {:.2f}. ****\n".format(zs_acc))
    test_features = test_features.cpu()
    test_labels = test_labels.cpu()

    list_lora_layers = apply_lora(args, clip_model)
    clip_model = clip_model.cuda()

    if args.eval_only:
        load_lora(args, list_lora_layers)
        acc_test = evaluate_lora(args, clip_model, test_loader, dataset)
        print("**** Test accuracy: {:.2f}. ****\n".format(acc_test))
        return
    # Freeze everything except the LoRA parameters.
    mark_only_lora_as_trainable(clip_model)
    total_iters = args.n_iters * args.shots

    optimizer = torch.optim.AdamW(get_lora_parameters(clip_model), weight_decay=1e-2, betas=(0.9, 0.999), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_iters, eta_min=1e-6)

    # Training LoRA with fp16 autocast and gradient scaling
    scaler = torch.cuda.amp.GradScaler()
    count_iters = 0
    while count_iters < total_iters:
        clip_model.train()
        acc_train = 0
        tot_samples = 0
        loss_epoch = 0.

        # With a frozen text encoder, the text classifier is fixed and can be
        # reused across the whole epoch.
        if args.encoder == 'vision':
            text_features = textual_features.t().half()

        for i, (images, target) in enumerate(tqdm(train_loader)):
            template = dataset.template[0]
            texts = [template.format(classname.replace('_', ' ')) for classname in dataset.classnames]
            images, target = images.cuda(), target.cuda()

            # Re-encode the class prompts only when the text encoder is being tuned.
            if args.encoder == 'text' or args.encoder == 'both':
                with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
                    texts = clip.tokenize(texts).cuda()
                    class_embeddings = clip_model.encode_text(texts)
                    text_features = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)

            # Image features need gradients only when the vision encoder is tuned.
            if args.encoder == 'vision' or args.encoder == 'both':
                with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
                    image_features = clip_model.encode_image(images)
            else:
                with torch.no_grad():
                    with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
                        image_features = clip_model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            cosine_similarity = logit_scale * image_features @ text_features.t()
            loss = F.cross_entropy(cosine_similarity, target)
            acc_train += cls_acc(cosine_similarity, target) * target.shape[0]
            loss_epoch += loss.item() * target.shape[0]
            tot_samples += target.shape[0]

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            count_iters += 1
            if count_iters == total_iters:
                break

        if count_iters < total_iters:
            acc_train /= tot_samples
            loss_epoch /= tot_samples
            current_lr = scheduler.get_last_lr()[0]
            print('LR: {:.6f}, Acc: {:.4f}, Loss: {:.4f}'.format(current_lr, acc_train, loss_epoch))

        # Eval
        if VALIDATION:
            clip_model.eval()
            acc_val = evaluate_lora(args, clip_model, val_loader, dataset)
            print("**** Val accuracy: {:.2f}. ****\n".format(acc_val))

    acc_test = evaluate_lora(args, clip_model, test_loader, dataset)
    print("**** Final test accuracy: {:.2f}. ****\n".format(acc_test))

    if args.save_path is not None:
        save_lora(args, list_lora_layers)
    return
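
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). The repository's
# real entry point builds `args`, the CLIP model, and the data loaders
# elsewhere; the flags below are hypothetical stand-ins that mirror only the
# attributes read in this file (encoder, n_iters, shots, lr, eval_only,
# save_path). Note that `apply_lora` may read further LoRA-specific attributes
# defined by the repository's own argument parser.
#
#   import argparse
#   import clip
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--encoder', default='both', choices=['text', 'vision', 'both'])
#   parser.add_argument('--n_iters', type=int, default=500)
#   parser.add_argument('--shots', type=int, default=16)
#   parser.add_argument('--lr', type=float, default=2e-4)
#   parser.add_argument('--eval_only', action='store_true')
#   parser.add_argument('--save_path', default=None)
#   args = parser.parse_args()
#
#   clip_model, _ = clip.load('ViT-B/16')               # OpenAI CLIP backbone
#   logit_scale = clip_model.logit_scale.exp().detach() # temperature applied to the logits
#   dataset, train_loader, val_loader, test_loader = ...  # dataset-specific builders (elided)
#   run_lora(args, clip_model, logit_scale, dataset,
#            train_loader, val_loader, test_loader)
# ---------------------------------------------------------------------------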