forked from RL-VIG/LibFewShot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path tpu_run_2.py
84 lines (52 loc) · 2.63 KB
/
tpu_run_2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import sys
import time
sys.dont_write_bytecode = True
from core.config import Config
from core import Trainer
import os
import torch
from core import Test
if __name__ == "__main__":
nameCollection=["Baseline++_1_resnet18_1","Baseline++_1_Conv32F_1","RFS_1_Conv64F_1"]
modelCollection=["Baseline++","Baseline++","RFS"]
numShotsCollection=["1","1","1"]
backboneCollection=["resnet18","Conv32F","Conv64F"]
# reset final results
f = open("final_result.csv", "w")
f.write("")
f.close()
f = open("final_result.csv", "a")
f.write("Model" + "," + "Number of Shots" + "," + "Backbone" + ","+"Train Accuracy" + "," + "Best Train Accuracy" + "," + "Test 1 Accuracy" + ","+"Test 1 Best Accuracy" + ","+"Validation Accuracy"+"," + "Best Validation Accuracy" +","+"Test 2 Final Accuracy" + "," + "Test 2 Best Accuracy\n")
total_count = len(nameCollection)
total_count_index = 0
for name in nameCollection:
printName = "Code is done with: " + name + " with progress: "+ str(total_count_index) + "/" + str(total_count)
# name = "test_run"
fileName = name + ".yaml"
f.write(modelCollection[total_count_index] + "," + str(numShotsCollection[total_count_index]) + "," + backboneCollection[total_count_index] )
config = Config("config/"+fileName).get_config_dict()
rank = 0 # Set the rank to 0 for single GPU or CPU
trainer = Trainer(rank, config, name, f,printName) # Pass both rank and config arguments
trainer.train_loop(rank)
# swap out the test for a different test folder
os.rename("data/consolidated_seeds_dataset/test.csv", "data/consolidated_seeds_dataset/test_temp.csv")
time.sleep(1)
os.rename("data/consolidated_seeds_dataset/test_2.csv", "data/consolidated_seeds_dataset/test.csv")
PATH = "./results/"+name
VAR_DICT = {
"test_epoch": 5,
"n_gpu": 1,
"test_episode": 100,
"episode_size": 1,
}
config = Config(os.path.join(PATH, "config.yaml"), VAR_DICT).get_config_dict()
test = Test(0, config, f, printName, PATH)
test.test_loop()
f.write("\n")
os.rename("data/consolidated_seeds_dataset/test.csv", "data/consolidated_seeds_dataset/test_2.csv")
time.sleep(1)
os.rename("data/consolidated_seeds_dataset/test_temp.csv", "data/consolidated_seeds_dataset/test.csv")
# update progress bar
# update progress bar
total_count_index = total_count_index + 1
f.close()