-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path train_authorship_set_configs.py
80 lines (61 loc) · 2.95 KB
/
train_authorship_set_configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
from sacred import Experiment
from replica_learn.utils import TrainingParams, ModelParams
# Sacred experiment that the @ex.config / @ex.named_config scopes in this
# file are registered on. No name is passed, so sacred chooses one itself
# (presumably derived from the main script) -- NOTE(review): confirm that is
# the intended experiment name.
ex = Experiment()
# Directory holding pretrained network checkpoints (the ``resnet_50`` named
# config joins 'resnet_v1_50.ckpt' onto it). Defaults to the cluster NAS
# location; set the PRETRAINED_DIR environment variable to point elsewhere
# when running on another machine. An empty value is treated as unset so it
# cannot silently turn the checkpoint paths into relative ones.
PRETRAINED_DIR = (
    os.environ.get('PRETRAINED_DIR')
    or '/mnt/cluster-nas/benoit/pretrained_nets'
)
@ex.config
def my_config():
    """Default configuration: full-split dataset CSVs and default parameter dicts.

    NOTE: sacred turns every variable assigned in this function into a config
    entry, so do not add helper variables here.
    """
    # Presumably the run's output/working directory -- TODO confirm.
    training_dir = '/scratch/benoit/authorship'
    # Dataset CSVs for the full split; the rijks_* named configs replace
    # these four entries with subset-specific files.
    author_csv = '/scratch/benoit/datasets/rijkschallenge/author_names.csv'
    training_csv = '/scratch/benoit/datasets/rijkschallenge/train.csv'
    validation_csv = '/scratch/benoit/datasets/rijkschallenge/validation.csv'
    testing_csv = '/scratch/benoit/datasets/rijkschallenge/test.csv'
    # Named configs (embeddings, resnet_50) assign only a few model_params
    # keys; presumably sacred merges those into this default dict -- verify.
    model_params = ModelParams().to_dict() # Model parameters
    training_params = TrainingParams().to_dict() # Training parameters
@ex.named_config
def embeddings():
    """Named config: override model_params 'class_embedding_dim' and 'fc_units'."""
    model_params = dict(
        class_embedding_dim=128,
        fc_units=1024,
    )
@ex.named_config
def resnet_50():
    """Named config: 'resnet50' base model with the ResNet v1 50 checkpoint from PRETRAINED_DIR."""
    model_params = dict(
        base_model='resnet50',
        pretrained_file=os.path.join(PRETRAINED_DIR, 'resnet_v1_50.ckpt'),
        pretrained_name_scope='resnet_v1_50',
        # NOTE(review): learning_rate is set in model_params, not
        # training_params -- confirm ModelParams is what consumes it.
        learning_rate=5e-6,
        blocks=4,
    )
@ex.named_config
def rijks_374_u():
    """Named config: use the '_374_u' variant of the four dataset CSVs.

    The meaning of '374' and of the '_u' suffix is not visible in this file
    -- TODO document (374 is presumably the number of authors).
    """
    author_csv = '/scratch/benoit/datasets/rijkschallenge/author_names_374_u.csv'
    training_csv = '/scratch/benoit/datasets/rijkschallenge/train_374_u.csv'
    validation_csv = '/scratch/benoit/datasets/rijkschallenge/validation_374_u.csv'
    testing_csv = '/scratch/benoit/datasets/rijkschallenge/test_374_u.csv'
@ex.named_config
def rijks_374():
    """Named config: use the '_374' variant of the four dataset CSVs.

    374 is presumably the number of authors in the subset -- TODO confirm.
    """
    author_csv = '/scratch/benoit/datasets/rijkschallenge/author_names_374.csv'
    training_csv = '/scratch/benoit/datasets/rijkschallenge/train_374.csv'
    validation_csv = '/scratch/benoit/datasets/rijkschallenge/validation_374.csv'
    testing_csv = '/scratch/benoit/datasets/rijkschallenge/test_374.csv'
@ex.named_config
def rijks_100():
    """Named config: use the '_100' dataset CSVs plus precomputed HDF5 index files.

    NOTE(review): this is the only config that sets training_index_file and
    validation_index_file; the default config and the other rijks_* configs
    do not define them, so their consumer must cope with their absence --
    confirm.
    """
    author_csv = '/scratch/benoit/datasets/rijkschallenge/author_names_100.csv'
    training_csv = '/scratch/benoit/datasets/rijkschallenge/train_100.csv'
    validation_csv = '/scratch/benoit/datasets/rijkschallenge/validation_100.csv'
    testing_csv = '/scratch/benoit/datasets/rijkschallenge/test_100.csv'
    training_index_file = '/scratch/benoit/datasets/rijkschallenge/index_train_100.hdf5'
    validation_index_file = '/scratch/benoit/datasets/rijkschallenge/index_validation_100.hdf5'
@ex.named_config
def rijks_200():
    """Named config: use the '_200' variant of the four dataset CSVs.

    200 is presumably the number of authors in the subset -- TODO confirm.
    """
    author_csv = '/scratch/benoit/datasets/rijkschallenge/author_names_200.csv'
    training_csv = '/scratch/benoit/datasets/rijkschallenge/train_200.csv'
    validation_csv = '/scratch/benoit/datasets/rijkschallenge/validation_200.csv'
    testing_csv = '/scratch/benoit/datasets/rijkschallenge/test_200.csv'
@ex.named_config
def rijks_300():
    """Named config: use the '_300' variant of the four dataset CSVs.

    300 is presumably the number of authors in the subset -- TODO confirm.
    """
    author_csv = '/scratch/benoit/datasets/rijkschallenge/author_names_300.csv'
    training_csv = '/scratch/benoit/datasets/rijkschallenge/train_300.csv'
    validation_csv = '/scratch/benoit/datasets/rijkschallenge/validation_300.csv'
    testing_csv = '/scratch/benoit/datasets/rijkschallenge/test_300.csv'