Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hf integration #1

Draft
wants to merge 21 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added tools/huggingface/__init__.py
Empty file.
322 changes: 322 additions & 0 deletions tools/huggingface/cfg-baseexample.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
# Testing with 13 languages (25 tasks: 12 es->X translation directions plus 13 autoencoding tasks)
# train_([a-z][a-z])-*\1 regex for matching only autoencoding tasks (and edit with multi-caret - thanks pycharm!)
#save_data: /scratch/project_2005099/data/unpc-all-directions-vocabs
src_vocab:
"es": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"cni": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"aym": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"bzd": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"gn": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"en": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"oto": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"nah": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"quy": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"tar": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"shp": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"hch": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"czn": /scratch/project_462000088/members/jrvc/data/all_32k.vocab

tgt_vocab:
"es": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"cni": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"aym": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"bzd": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"gn": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"en": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"oto": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"nah": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"quy": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"tar": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"shp": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"hch": /scratch/project_462000088/members/jrvc/data/all_32k.vocab
"czn": /scratch/project_462000088/members/jrvc/data/all_32k.vocab


# Tokenization options
src_subword_type: sentencepiece
src_subword_model: /scratch/project_462000088/members/jrvc/data/all_32k.model
tgt_subword_type: sentencepiece
tgt_subword_model: /scratch/project_462000088/members/jrvc/data/all_32k.model

# Number of candidates for SentencePiece sampling
subword_nbest: 20
# Smoothing parameter for SentencePiece sampling
subword_alpha: 0.1
# Specific arguments for pyonmttok
src_onmttok_kwargs: "{'mode': 'none', 'spacer_annotate': True}"
tgt_onmttok_kwargs: "{'mode': 'none', 'spacer_annotate': True}"


overwrite: False

data:
## GPU 0: es-cni, cni, es-aym, aym, es-tar, tar
train_es-cni:
src_tgt: es-cni
node_gpu: "0:0"
enc_sharing_group: ["es", "all"]
dec_sharing_group: ["cni", "all", "cni"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-cni.es
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-cni.cni
transforms: [onmt_tokenize, filtertoolong]
train_cni-cni:
src_tgt: cni-cni
node_gpu: "0:0"
enc_sharing_group: ["cni", "all"]
dec_sharing_group: ["cni", "all", "cni"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-cni.cni
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-cni.cni
transforms: [onmt_tokenize, filtertoolong, bart]
train_es-aym:
src_tgt: es-aym
node_gpu: "0:0"
enc_sharing_group: ["es", "all"]
dec_sharing_group: ["aym", "all", "aym"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-aym.es
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-aym.aym
transforms: [onmt_tokenize, filtertoolong]
train_aym-aym:
src_tgt: aym-aym
node_gpu: "0:0"
enc_sharing_group: ["aym", "all"]
dec_sharing_group: ["aym", "all", "aym"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-aym.aym
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-aym.aym
transforms: [onmt_tokenize, filtertoolong, bart]
train_es-tar:
src_tgt: es-tar
node_gpu: "0:0"
enc_sharing_group: ["es", "all"]
dec_sharing_group: ["tar", "all", "tar"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-tar.es
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-tar.tar
transforms: [onmt_tokenize, filtertoolong]
train_tar-tar:
src_tgt: tar-tar
node_gpu: "0:0"
enc_sharing_group: ["tar", "all"]
dec_sharing_group: ["tar", "all", "tar"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-tar.tar
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-tar.tar
transforms: [onmt_tokenize, filtertoolong, bart]


## GPU 1: es-bzd, bzd, es-gn, gn, es-shp, shp
train_es-bzd:
src_tgt: es-bzd
node_gpu: "0:1"
enc_sharing_group: ["es", "all"]
dec_sharing_group: ["bzd", "all", "bzd"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-bzd.es
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-bzd.bzd
transforms: [onmt_tokenize, filtertoolong]
train_bzd-bzd:
src_tgt: bzd-bzd
node_gpu: "0:1"
enc_sharing_group: ["bzd", "all"]
dec_sharing_group: ["bzd", "all", "bzd"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-bzd.bzd
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-bzd.bzd
transforms: [onmt_tokenize, filtertoolong, bart]
train_es-gn:
src_tgt: es-gn
node_gpu: "0:1"
enc_sharing_group: ["es", "all"]
dec_sharing_group: ["gn", "all", "gn"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-gn.es
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-gn.gn
transforms: [onmt_tokenize, filtertoolong]
train_gn-gn:
src_tgt: gn-gn
node_gpu: "0:1"
enc_sharing_group: ["gn", "all"]
dec_sharing_group: ["gn", "all", "gn"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-gn.gn
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-gn.gn
transforms: [onmt_tokenize, filtertoolong, bart]
train_es-shp:
src_tgt: es-shp
node_gpu: "0:1"
enc_sharing_group: ["es", "all"]
dec_sharing_group: ["shp", "all", "shp"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-shp.es
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-shp.shp
transforms: [onmt_tokenize, filtertoolong]
train_shp-shp:
src_tgt: shp-shp
node_gpu: "0:1"
enc_sharing_group: ["shp", "all"]
dec_sharing_group: ["shp", "all", "shp"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-shp.shp
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-shp.shp
transforms: [onmt_tokenize, filtertoolong, bart]

## GPU 2: es-en, en, es-oto, oto, es-hch, hch
train_es-en:
src_tgt: es-en
node_gpu: "0:2"
enc_sharing_group: ["es", "all"]
dec_sharing_group: ["en", "all", "en"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-en.es
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-en.en
transforms: [onmt_tokenize, filtertoolong]
train_en-en:
src_tgt: en-en
node_gpu: "0:2"
enc_sharing_group: ["en", "all"]
dec_sharing_group: ["en", "all", "en"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-en.en
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-en.en
transforms: [onmt_tokenize, filtertoolong, bart]
train_es-oto:
src_tgt: es-oto
node_gpu: "0:2"
enc_sharing_group: ["es", "all"]
dec_sharing_group: ["oto", "all", "oto"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-oto.es
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-oto.oto
transforms: [onmt_tokenize, filtertoolong]
train_oto-oto:
src_tgt: oto-oto
node_gpu: "0:2"
enc_sharing_group: ["oto", "all"]
dec_sharing_group: ["oto", "all", "oto"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-oto.oto
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-oto.oto
transforms: [onmt_tokenize, filtertoolong, bart]
train_es-hch:
src_tgt: es-hch
node_gpu: "0:2"
enc_sharing_group: ["es", "all"]
dec_sharing_group: ["hch", "all", "hch"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-hch.es
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-hch.hch
transforms: [onmt_tokenize, filtertoolong]
train_hch-hch:
src_tgt: hch-hch
node_gpu: "0:2"
enc_sharing_group: ["hch", "all"]
dec_sharing_group: ["hch", "all", "hch"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-hch.hch
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-hch.hch
transforms: [onmt_tokenize, filtertoolong, bart]

## GPU 3: es-nah, nah, es-quy, quy, es-czn, czn, es
train_es-nah:
src_tgt: es-nah
node_gpu: "0:3"
enc_sharing_group: ["es", "all"]
dec_sharing_group: ["nah", "all", "nah"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-nah.es
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-nah.nah
transforms: [onmt_tokenize, filtertoolong]
train_nah-nah:
src_tgt: nah-nah
node_gpu: "0:3"
enc_sharing_group: ["nah", "all"]
dec_sharing_group: ["nah", "all", "nah"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-nah.nah
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-nah.nah
transforms: [onmt_tokenize, filtertoolong, bart]
train_es-quy:
src_tgt: es-quy
node_gpu: "0:3"
enc_sharing_group: ["es", "all"]
dec_sharing_group: ["quy", "all", "quy"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-quy.es
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-quy.quy
transforms: [onmt_tokenize, filtertoolong]
train_quy-quy:
src_tgt: quy-quy
node_gpu: "0:3"
enc_sharing_group: ["quy", "all"]
dec_sharing_group: ["quy", "all", "quy"]
path_src: /scratch/project_462000088/members/jrvc/data/train.es-quy.quy
path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-quy.quy
transforms: [onmt_tokenize, filtertoolong, bart]
train_es-czn:
  src_tgt: es-czn
  node_gpu: "0:3"
  # The source side is Spanish, so the encoder uses the "es" group,
  # exactly like every other es-* translation task in this file.
  enc_sharing_group: ["es", "all"]
  dec_sharing_group: ["czn", "all", "czn"]
  path_src: /scratch/project_462000088/members/jrvc/data/train.es-czn.es
  path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-czn.czn
  transforms: [onmt_tokenize, filtertoolong]
train_czn-czn:
  src_tgt: czn-czn
  # czn is listed under the "GPU 3" group, so it runs on device 3 of node 0
  # together with the other tasks of that group.
  node_gpu: "0:3"
  enc_sharing_group: ["czn", "all"]
  dec_sharing_group: ["czn", "all", "czn"]
  path_src: /scratch/project_462000088/members/jrvc/data/train.es-czn.czn
  path_tgt: /scratch/project_462000088/members/jrvc/data/train.es-czn.czn
  transforms: [onmt_tokenize, filtertoolong, bart]
train_es-es:
src_tgt: es-es
node_gpu: "0:3"
enc_sharing_group: ["es", "all"]
dec_sharing_group: ["es", "all", "es"]
path_src: /scratch/project_462000088/members/jrvc/data/all-es.es
path_tgt: /scratch/project_462000088/members/jrvc/data/all-es.es
transforms: [onmt_tokenize, filtertoolong, bart]






### Transform related opts:
#### Filter
src_seq_length: 200
tgt_seq_length: 200
#### Bart
src_subword_type: sentencepiece
tgt_subword_type: sentencepiece
mask_ratio: 0.2
replace_length: 1

# silently ignore empty lines in the data
skip_empty_level: silent

batch_size: 4096
batch_type: tokens
normalization: tokens
valid_batch_size: 4096
max_generator_batches: 2
src_vocab_size: 100000
tgt_vocab_size: 100000
encoder_type: transformer
decoder_type: transformer
rnn_size: 512
word_vec_size: 512
transformer_ff: 2048
heads: 8
enc_layers: [3,3]
dec_layers: [2,2,2]
dropout: 0.1
label_smoothing: 0.1
param_init: 0.0
param_init_glorot: true
position_encoding: true
valid_steps: 1000000
warmup_steps: 20000
#warmup_steps: 40
report_every: 50
#report_every: 5
save_checkpoint_steps: 50000
keep_checkpoint: -1
accum_count: 1
optim: adafactor
decay_method: none
learning_rate: 3.0
weight_decay: 0.05
max_grad_norm: 0.0
seed: 3435
model_type: text
save_all_gpus: false

world_size: 4
gpu_ranks: [0,1,2,3]
node_rank: 0



64 changes: 64 additions & 0 deletions tools/huggingface/convert_mammoth_to_marian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import argparse
import sys
from typing import Optional

import yaml
from transformers import MarianConfig

from onmt.opts import train_opts
from onmt.utils.parse import ArgumentParser


def _get_parser():
    """Build the command-line parser used by the converter.

    The parser carries every option registered by ``train_opts`` plus one
    positional argument, ``hf_ckpt_dir``.

    Returns:
        The configured ``ArgumentParser``.
    """
    parser = ArgumentParser(description='convert_mammoth_to_marian.py')
    train_opts(parser)
    # Positional arguments are mandatory by definition; argparse raises
    # ``TypeError`` when ``required`` is passed for one, so it is left out.
    # NOTE(review): assumes the onmt ArgumentParser keeps argparse's
    # ``add_argument`` semantics -- confirm.
    parser.add_argument('hf_ckpt_dir')
    return parser





def convert_mammoth_to_marian(
    mammoth_config_path: str,
    marian_config_path: Optional[str] = None,
) -> None:
    """Build a Hugging Face ``MarianConfig`` from mammoth training options.

    The options are taken from the command line through ``_get_parser``;
    ``mammoth_config_path`` is accepted but not read yet.

    Args:
        mammoth_config_path: Location of the mammoth config. Currently
            unused.
        marian_config_path: File the resulting config is written to as JSON.
            When ``None`` (the default) nothing is written.
    """
    parser = _get_parser()

    # Arguments the parser does not define are tolerated and discarded.
    opt, _unknown = parser.parse_known_args()

    # argparse-style namespaces are not subscriptable; work on the
    # underlying attribute dict instead.
    # TODO(review): presumably the options should eventually come from
    # yaml.safe_load() of ``mammoth_config_path`` -- confirm.
    config_dict = vars(opt)

    # NOTE(review): some option parsers store ``dropout`` as a list of
    # values -- presumably the first entry is the one wanted here; confirm
    # against train_opts.
    dropout = config_dict['dropout']
    if isinstance(dropout, (list, tuple)):
        dropout = dropout[0]

    marian = MarianConfig(
        vocab_size=config_dict["src_vocab_size"],
        # The decoder side is sized by the target vocabulary.
        decoder_vocab_size=config_dict["tgt_vocab_size"],
        max_position_embeddings=1024,  # default
        # Layer counts are configured per layer-stack group (e.g. [3, 3]),
        # so the total depth is the sum of the groups, not the first entry.
        encoder_layers=sum(config_dict["enc_layers"]),
        encoder_ffn_dim=config_dict["transformer_ff"],
        encoder_attention_heads=config_dict["heads"],
        decoder_layers=sum(config_dict["dec_layers"]),
        decoder_ffn_dim=config_dict["transformer_ff"],
        decoder_attention_heads=config_dict["heads"],
        encoder_layerdrop=0.0,  # default
        decoder_layerdrop=0.0,  # default
        use_cache=True,  # default
        is_encoder_decoder=True,  # default
        activation_function=config_dict['pos_ffn_activation_fn'],
        d_model=config_dict['rnn_size'],
        dropout=dropout,
        attention_dropout=0,  # default
        activation_dropout=0,  # default
        init_std=0.02,  # default
        decoder_start_token_id=58100,  # default
        scale_embedding=False,  # default
        pad_token_id=58100,  # default
        eos_token_id=0,  # default
        forced_eos_token_id=0,  # default
        share_encoder_decoder_embeddings=True,  # default
    )

    # Persist the result; without this the converted config is discarded.
    if marian_config_path is not None:
        marian.to_json_file(marian_config_path)


def parse_args():
    """Parse the converter's own command-line options.

    Arguments this parser does not define are ignored rather than rejected,
    because the converter parses the same command line a second time with
    its own parser.

    Returns:
        A namespace with ``mammoth_config_path`` (``--src``/``-s``, defaults
        to the ``sys.stdin`` stream) and ``marian_config_path``
        (``--tgt``/``-t``, defaults to ``None``). The attribute names match
        the keyword parameters of ``convert_mammoth_to_marian``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", "-s", dest="mammoth_config_path", default=sys.stdin)
    # The converter takes a ``marian_config_path`` keyword as well, so the
    # namespace has to provide it.
    parser.add_argument("--tgt", "-t", dest="marian_config_path", default=None)
    known_args, _unknown = parser.parse_known_args()
    return known_args


if __name__ == "__main__":
    # Forward every parsed command-line option to the converter as a
    # keyword argument of the same name.
    cli_options = vars(parse_args())
    convert_mammoth_to_marian(**cli_options)