-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathtrain_gen.yml
45 lines (37 loc) · 1.88 KB
/
train_gen.yml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
environment:
  # Base image for every job: PyTorch 1.5 devel build with CUDA 10.1 / cuDNN 7.
  image: 'pytorch/pytorch:1.5-cuda10.1-cudnn7-devel'
  # Commands that prepare the image before a job's own commands run.
  # NOTE(review): the `cd apex` / `cd ..` pair below only works if these steps
  # share one shell session -- confirm against the launcher's behavior.
  setup:
    # Python dependencies; only transformers and pytorch-lightning are pinned.
    - 'pip install transformers==2.11.0'
    - 'pip install spacy'
    - 'pip install tqdm'
    - 'pip install tensorboard'
    - 'pip install wandb'
    - 'pip install pytorch-lightning==0.9.0'
    - 'pip install rouge-score'
    # git must be installed before the apex repository can be cloned.
    - 'sudo apt-get update'
    - 'sudo apt-get -y install git'
    # Build NVIDIA apex from source with its C++ and CUDA extensions enabled.
    - 'git clone https://github.com/NVIDIA/apex'
    - 'cd apex'
    - 'pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./'
    - 'cd ..'
# Each job is a name plus a list of shell commands; every job first enters the
# `gar` directory and then launches a train/test script.
# Long commands use folded scalars (`>-`): the line breaks below are joined
# with single spaces, so each command is still passed as one line.
# NOTE(review): $$OUTPUT_DIR and $$DATA_DIR are presumably placeholders
# substituted by the job launcher -- confirm.
jobs:
  # Train GAR with 'answer' as the generation target
  # (see conf.py for other options).
  - name: generator_train_nq_A
    command:
      - cd gar
      - >-
        GEN_TARGET='answer' python train_generator.py
        --remark generator_train_nq_A
        --train_batch_size 128
        --eval_batch_size 256
        --ckpt_metric val-ROUGE-1

  # Train the generative reader on the trivia multi-input data;
  # checkpoints are selected by validation exact match (val-EM).
  - name: train_multi_inputs_trivia
    command:
      - cd gar
      - >-
        python train_generator.py
        --remark train_multi_inputs_trivia
        --output_dir $$OUTPUT_DIR
        --data_dir $$DATA_DIR/trivia-multi-inputs
        --train_batch_size 6
        --eval_batch_size 12
        --max_source_length 1024
        --max_target_length 10
        --learning_rate 5e-6
        --ckpt_metric val-EM

  # Test the generative reader on the nq-multi-inputs-merge4-CV test split.
  # NOTE(review): `checkpoint.ckpt` looks like a placeholder path -- point it
  # at the checkpoint to evaluate.
  - name: test_multi_inputs
    command:
      - cd gar
      - >-
        python test_generator.py
        --input_path $$DATA_DIR/nq-multi-inputs-merge4-CV/test.source
        --output_path $$DATA_DIR/nq-multi-inputs-merge4-CV/gen-test-samples.txt
        --bs 32
        --model_ckpt checkpoint.ckpt
        --max_source_length 1024
        --max_target_length 10
        --remark test_multi_inputs

  # Train a closed-book generative reader.
  # `--ckpt_metric val-EM` was previously passed twice; it is given once here.
  - name: train_closedBook
    command:
      - cd gar
      - >-
        python train_generator.py
        --remark train_closedBook
        --output_dir $$OUTPUT_DIR
        --data_dir $$DATA_DIR/closedBook
        --train_batch_size 512
        --eval_batch_size 1024
        --learning_rate 1e-4
        --ckpt_metric val-EM
        --max_source_length 20
        --max_target_length 10
        --num_train_epochs 50