-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtune.py
220 lines (196 loc) · 7.78 KB
/
tune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import shutil
import tempfile
import time
import warnings
from pathlib import Path
from pprint import pprint
from typing import Any, Literal
import delu
import optuna
import optuna.samplers
import optuna.trial
from loguru import logger
from typing_extensions import NotRequired, TypedDict
import lib
from lib.util import KWArgs
def _suggest(trial: optuna.trial.Trial, distribution: str, label: str, *args):
return getattr(trial, f'suggest_{distribution}')(label, *args)
def sample_config(
trial: optuna.trial.Trial,
space: bool | int | float | str | bytes | list | dict,
label_parts: list,
) -> Any:
if isinstance(space, bool | int | float | str | bytes):
return space
elif isinstance(space, dict):
if '_tune_' in space or '-tune-' in space:
distribution = space['_tune_'] if '_tune_' in space else space['-tune-']
# for complex cases, for example:
# [model]
# _tune_ = "complex-custom-distribution"
# <any keys and values for the distribution>
if distribution == 'int-power-of-two':
raise NotImplementedError()
else:
raise ValueError(f'Unknown distibution: "{distribution}"')
else:
return {
key: sample_config(trial, subspace, [*label_parts, key])
for key, subspace in space.items()
}
elif isinstance(space, list):
if not space:
return space
elif space[0] not in ['_tune_', '-tune-']:
return [
sample_config(trial, subspace, [*label_parts, i])
for i, subspace in enumerate(space)
]
else:
# space = ["_tune_"/"-tune-", distribution, distribution_arg_0, distribution_1, ...] # noqa: E501
_, distribution, *args = space
label = '.'.join(map(str, label_parts))
if distribution.startswith('?'):
default, args_ = args[0], args[1:]
if trial.suggest_categorical('?' + label, [False, True]):
return _suggest(trial, distribution.lstrip('?'), label, *args_)
else:
return default
elif distribution == "int-power-of-two":
power = trial.suggest_int(label, args[0], args[1], step=1)
return 2**power
elif distribution == '$list':
size, item_distribution, *item_args = args
return [
_suggest(trial, item_distribution, label + f'.{i}', *item_args)
for i in range(size)
]
else:
return _suggest(trial, distribution, label, *args)
class Config(TypedDict):
seed: int
function: str
space: dict[str, Any]
n_trials: NotRequired[int]
timeout: NotRequired[int]
sampler: NotRequired[KWArgs]
sampler_type: NotRequired[str]
# Using the following option can lead to high disk space consumption.
# The supposed usage is to set it to True or 'only-reports' for
# more expensive long-running tasks, but not for those numerous
# tabular tasks presented in this repository.
save_trials: NotRequired[bool | Literal['only-reports']]
def main(
config: Config, output: str | Path, *, force: bool = False, continue_: bool = False
) -> None | lib.JSONDict:
assert set(config) >= Config.__required_keys__
assert set(config) <= Config.__required_keys__ | Config.__optional_keys__
assert 'seed' not in config.get('sampler', {})
if not lib.start(output, force=force, continue_=continue_):
return None
lib.show_config(config) # type: ignore[code]
output = Path(output)
delu.random.seed(config['seed'])
report = lib.create_report(config) # type: ignore[code]
function: lib.MainFunction = lib.import_(config['function'])
n_trials = config.get('n_trials')
timeout = config.get('timeout')
if lib.get_checkpoint_path(output).exists():
del report
checkpoint = lib.load_checkpoint(output)
report, study, trial_reports, timer = (
checkpoint['report'],
checkpoint['study'],
checkpoint['trial_reports'],
checkpoint['timer'],
)
delu.random.set_state(checkpoint['random_state'])
if n_trials is not None:
n_trials -= len(study.trials)
if timeout is not None:
timeout -= timer.elapsed()
report.setdefault('continuations', []).append(len(study.trials))
print(
f'Resuming from checkpoint ({len(study.trials)} completed,'
f' {n_trials or "inf"} remaining)'
)
time.sleep(1.0)
else:
study = optuna.create_study(
direction='maximize',
sampler=getattr(optuna.samplers, config.get('sampler_type', 'TPESampler'))(
**config.get('sampler', {}), seed=config['seed']
),
)
trial_reports = []
timer = delu.tools.Timer()
def objective(trial: optuna.trial.Trial) -> float:
trial_config = sample_config(trial, config['space'], [])
kwargs = {}
save_trials = config.get('save_trials', False)
if save_trials:
trial_output = output / 'trials' / str(trial.number)
trial_output.parent.mkdir(exist_ok=True)
if trial_output.exists():
# Resuming the latest unfinished trial is not supported.
logger.warning(
'Removing the latest unfinished trial'
f' {trial_output.relative_to(lib.PROJECT_DIR)}'
)
shutil.rmtree(trial_output, True)
report = function(trial_config, trial_output, **kwargs)
if isinstance(save_trials, str) and save_trials == 'only-reports':
for path in trial_output.iterdir():
if path.is_dir():
shutil.rmtree(path, True)
elif path.name not in ('DONE', 'report.json'):
path.unlink()
else:
with tempfile.TemporaryDirectory(suffix=f'_trial_{trial.number}') as tmp:
report = function(trial_config, Path(tmp) / 'output', **kwargs)
assert report is not None
report['trial_id'] = trial.number
report['tuning_time'] = str(timer)
trial_reports.append(report)
delu.cuda.free_memory()
return report['metrics']['val']['score']
def callback(*_, **__):
report['best'] = trial_reports[study.best_trial.number]
report['time'] = str(timer)
report['n_completed_trials'] = len(trial_reports)
lib.dump_checkpoint(
output,
{
'report': report,
'study': study,
'trial_reports': trial_reports,
'timer': timer,
'random_state': delu.random.get_state(),
},
)
lib.dump_summary(output, lib.summarize(report))
lib.dump_report(output, report)
lib.backup_output(output)
print()
print(f'Time: {timer}')
print('The current best result:')
pprint(lib.summarize(report['best']), width=60, sort_dicts=False)
print()
timer.run()
# Ignore the progress bar warning.
warnings.filterwarnings('ignore', category=optuna.exceptions.ExperimentalWarning)
# Ignore the warnings about the deprecated suggest_* methods
warnings.filterwarnings('ignore', category=FutureWarning)
study.optimize(
objective,
n_trials=n_trials,
timeout=timeout,
callbacks=[callback],
show_progress_bar=True,
)
lib.dump_summary(output, lib.summarize(report))
lib.finish(output, report)
return report
if __name__ == '__main__':
lib.configure_libraries()
lib.run_MainFunction_cli(main)