forked from RL-VIG/LibFewShot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_trainer_test.py
116 lines (84 loc) · 4.08 KB
/
run_trainer_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# -*- coding: utf-8 -*-
"""Batch driver for LibFewShot experiments.

For every (model, shots, backbone, trial) combination this script expects a
config file ``config/<model>_<shots>_<backbone>_<trial>.yaml``.  It trains the
model, swaps the dataset's test split for a second held-out split
(``test_2.csv``), runs a second test pass, restores the splits, and appends one
CSV row per combination to ``final_result.csv``.
"""
import sys
import time

sys.dont_write_bytecode = True  # keep the source tree free of .pyc files

import os

import torch  # NOTE(review): not referenced below — presumably imported for CUDA side effects; confirm before removing

from core import Test, Trainer
from core.config import Config

if __name__ == "__main__":
    model_collection = [
        "Baseline",
        "Baseline++",
        "RFS",
        "SKD",
        "MAML",
        "VERSA",
        "R2D2",
        "LEO",
        "MTL_meta",
        "ANIL",
        "BOIL",
    ]
    shots_collection = [1, 5, 10]
    backbone_collection = ["Conv64F", "resnet12", "resnet18", "Conv32F"]
    trial_collection = [1, 2, 3]

    # Column order must stay in sync with what Trainer/Test append to the file.
    header_columns = [
        "Model",
        "Number of Shots",
        "Backbone",
        "Trial Number",
        "Train Accuracy",
        "Best Train Accuracy",
        "Test 1 Accuracy",
        "Test 1 Best Accuracy",
        "Validation Accuracy",
        "Best Validation Accuracy",
        "Test 2 Final Accuracy",
        "Test 2 Best Accuracy",
    ]

    total_count = (
        len(shots_collection)
        * len(model_collection)
        * len(backbone_collection)
        * len(trial_collection)
    )
    total_count_index = 0

    data_dir = "data/consolidated_seeds_dataset"

    # Opening with "w" truncates any previous results, so the original
    # truncate-then-reopen-for-append sequence is unnecessary; the context
    # manager also guarantees the file is closed even if a run crashes.
    with open("final_result.csv", "w") as f:
        f.write(",".join(header_columns) + "\n")
        for num_shots in shots_collection:
            for model in model_collection:
                for backbone in backbone_collection:
                    for trial in trial_collection:
                        name = f"{model}_{num_shots}_{backbone}_{trial}"
                        print_name = (
                            f"Code is done with: {name} "
                            f"with progress: {total_count_index}/{total_count}"
                        )
                        file_name = name + ".yaml"
                        # Start this combination's CSV row; Trainer/Test write
                        # the accuracy columns into the same handle.
                        f.write(f"{model},{num_shots},{backbone},{trial},")

                        # --- training pass ---------------------------------
                        config = Config("config/" + file_name).get_config_dict()
                        rank = 0  # single GPU or CPU
                        trainer = Trainer(rank, config, name, f, print_name)
                        trainer.train_loop(rank)

                        # --- swap in the second test split -----------------
                        os.rename(
                            os.path.join(data_dir, "test.csv"),
                            os.path.join(data_dir, "test_temp.csv"),
                        )
                        time.sleep(1)  # brief pause between renames (original behavior)
                        os.rename(
                            os.path.join(data_dir, "test_2.csv"),
                            os.path.join(data_dir, "test.csv"),
                        )
                        try:
                            path = "./results/" + name
                            var_dict = {
                                "test_epoch": 5,
                                "n_gpu": 1,
                                "test_episode": 100,
                                "episode_size": 1,
                            }
                            config = Config(
                                os.path.join(path, "config.yaml"), var_dict
                            ).get_config_dict()
                            test = Test(0, config, f, print_name, path)
                            test.test_loop()
                            f.write("\n")
                        finally:
                            # Always restore the original split layout, even if
                            # the test pass raises — otherwise a crash leaves the
                            # dataset directory in a swapped state and poisons
                            # every subsequent combination.
                            os.rename(
                                os.path.join(data_dir, "test.csv"),
                                os.path.join(data_dir, "test_2.csv"),
                            )
                            time.sleep(1)
                            os.rename(
                                os.path.join(data_dir, "test_temp.csv"),
                                os.path.join(data_dir, "test.csv"),
                            )

                        # update progress counter shown in the next print_name
                        total_count_index += 1
# # -*- coding: utf-8 -*-
# import sys
# sys.dont_write_bytecode = True
# import torch
# import os
# from core.config import Config
# from core import Trainer
# def main(rank, config):
# trainer = Trainer(rank, config)
# trainer.train_loop(rank)
# if __name__ == "__main__":
# config = Config("./config/proto.yaml").get_config_dict()
# if config["n_gpu"] > 1:
# os.environ["CUDA_VISIBLE_DEVICES"] = config["device_ids"]
# torch.multiprocessing.spawn(main, nprocs=config["n_gpu"], args=(config,))
# else:
# main(0, config)