-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathrun_miniimagenet.py
executable file
·49 lines (38 loc) · 1.73 KB
/
run_miniimagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""
Train a model on miniImageNet.
"""
import random
import torch
from args import argument_parser, train_kwargs
from eval import do_evaluation
from image_loader import read_dataset
from train import train
from util import load_checkpoint
# Directory (relative to the working directory) that main() passes to
# read_dataset() to load the miniImageNet splits.
DATA_DIR = 'data/miniimagenet'
def main():
    """
    Load miniImageNet, train a model on it, then evaluate the checkpoint.

    Steps:
      1. Parse command-line arguments and print them.
      2. Seed Python's ``random`` and torch with ``args.seed``.
      3. Select CUDA device 0 when CUDA is available.
      4. Read the train/val/test splits from ``DATA_DIR``.
      5. Call ``train``: from scratch when ``args.pretrained`` is false,
         otherwise seeded with the model and optimizer state loaded from
         ``args.checkpoint``.
      6. Load ``args.checkpoint`` and pass its model/optimizer state to
         ``do_evaluation`` together with all three splits.
    """
    args = argument_parser().parse_args()
    print(args)

    random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Select the first GPU only when one exists; calling set_device(0)
    # unconditionally fails on hosts without CUDA.
    if torch.cuda.is_available():
        torch.cuda.set_device(0)

    train_set, val_set, test_set = read_dataset(DATA_DIR)

    # NOTE(review): train() is given test_set (not val_set) as its second
    # dataset -- confirm this is intended.
    if args.pretrained:
        print('Restoring from checkpoint...')
        # load_checkpoint returns five values; only the model and optimizer
        # state are needed to continue training.
        model_state, op_state, _, _, _ = load_checkpoint(args.checkpoint)
        train(train_set, test_set, args.checkpoint, model_state, op_state,
              **train_kwargs(args))
    else:
        print('Training...')
        train(train_set, test_set, args.checkpoint, **train_kwargs(args))

    # Evaluation runs after either branch and always reloads the state from
    # disk -- presumably train() writes its result to args.checkpoint;
    # confirm in train.py.
    print('\nEvaluating...')
    model_state, op_state, _, _, _ = load_checkpoint(args.checkpoint)
    do_evaluation(model_state, op_state, args.checkpoint,
                  val_set, test_set, train_set)
# Run the training/evaluation pipeline only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
"""
Example invocations.

Arguments for continuing training from an existing checkpoint (--pretrained):
--shots 5 --inner-batch 10 --inner-iters 8 --meta-step 1 --meta-step-final 1 --meta-batch 5 --meta-iters 100000 --eval-batch 15 --eval-iters 50
--learning-rate 0.001 --train-shots 16 --checkpoint ckpt_m55 --cuda --pin_memory --pretrained

Full command for training from scratch:
python -u run_miniimagenet.py --shots 5 --inner-batch 10 --inner-iters 8 --meta-step 1 --meta-step-final 0 --meta-batch 5 --meta-iters 100000 --eval-batch 15 --eval-iters 50 --learning-rate 0.001 --train-shots 16 --checkpoint ckpt_m55 --cuda --pin_memory --transductive --foml
"""