-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
43 lines (33 loc) · 1.19 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import argparse
import logging
import os
import oneflow as flow
from function import Trainer
from utils.utils_logging import init_logging
from utils.utils_config import get_config
def main(args):
cfg = get_config(args.config)
cfg.graph = args.graph
rank = flow.env.get_rank()
world_size = flow.env.get_world_size()
placement = flow.env.all_device_placement("cuda")
os.makedirs(cfg.output, exist_ok=True)
log_root = logging.getLogger()
init_logging(log_root, rank, cfg.output)
# root dir of loading checkpoint
load_path = None
for key, value in cfg.items():
num_space = 25 - len(key)
logging.info(": " + key + " " * num_space + str(value))
trainer = Trainer(cfg, placement, load_path, world_size, rank)
trainer()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="OneFlow ArcFace Training")
parser.add_argument("config", type=str, help="py config file")
parser.add_argument(
"--graph",
action="store_true",
help="Run model in graph mode,else run model in ddp mode.",
)
parser.add_argument("--local_rank", type=int, default=0, help="local_rank")
main(parser.parse_args())