
Merge branch 'main' into feats/bucket-lord
Mickus Timothee committed Oct 16, 2023
2 parents c03d7cc + 165038e · commit f981d26
Showing 12 changed files with 127 additions and 38 deletions.
5 changes: 5 additions & 0 deletions docs/source/config_config.md
@@ -85,6 +85,11 @@ Note that if both `use_weight` and `use_introduce_at_training_step` are specified

 Note that high-resource language pairs (would train for over 75% of the training time) all start at 0. This avoids starting training with only one GPU doing work, while the other GPUs are idle waiting for their LPs to start.
 
+#### `use_src_lang_token`
+
+Only has an effect when using the `prefix` transform.
+Normally, the prefix transform includes only a target language selector token: `<to_yyy>`, where `yyy` is the code of the target language.
+If this flag is set, the source language is also specified, e.g. `<from_xxx> <to_yyy>`.
 
 #### `translation_config_dir`
 
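As a minimal sketch of the two prefix shapes this flag toggles (the helper function is hypothetical, for illustration only; the language codes are invented):

def make_prefix(src_lang, tgt_lang, use_src_lang_token=False):
    # With the flag set, both selector tokens are emitted;
    # otherwise only the target-language one.
    if use_src_lang_token:
        return f'<from_{src_lang}> <to_{tgt_lang}>'
    return f'<to_{tgt_lang}>'

assert make_prefix('eng', 'fin') == '<to_fin>'
assert make_prefix('eng', 'fin', use_src_lang_token=True) == '<from_eng> <to_fin>'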
2 changes: 1 addition & 1 deletion mammoth/bin/train.py
@@ -144,7 +144,7 @@ def validate_slurm_node_opts(current_env, world_context, opts):


 def train(opts):
-    init_logger(opts.log_file)
+    init_logger(opts.log_file, structured_log_file=opts.structured_log_file)
     ArgumentParser.validate_train_opts(opts)
     ArgumentParser.update_model_opts(opts)
     ArgumentParser.validate_model_opts(opts)
2 changes: 1 addition & 1 deletion mammoth/inputters/dataset.py
@@ -179,7 +179,7 @@ def get_corpus(opts, task, src_vocab: Vocab, tgt_vocab: Vocab, is_train: bool =
     vocabs = {'src': src_vocab, 'tgt': tgt_vocab}
     corpus_opts = opts.tasks[task.corpus_id]
     transforms_to_apply = corpus_opts.get('transforms', None)
-    transforms_to_apply = transforms_to_apply or opts.get('transforms', None)
+    transforms_to_apply = transforms_to_apply or opts.transforms
     transforms_to_apply = transforms_to_apply or []
     transforms_cls = make_transforms(
         opts,
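The replaced line would have raised at runtime: here `opts` is an argparse-style `Namespace`, which has no `.get` method; that API belongs to dicts. A minimal sketch of the distinction (the transform name is invented):

from argparse import Namespace

opts = Namespace(transforms=['filtertoolong'])
print(opts.transforms)     # attribute access works: ['filtertoolong']
# opts.get('transforms')   # AttributeError: 'Namespace' object has no attribute 'get'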
23 changes: 20 additions & 3 deletions mammoth/opts.py
@@ -23,6 +23,13 @@ def config_opts(parser):
 def _add_logging_opts(parser, is_train=True):
     group = parser.add_argument_group('Logging')
     group.add('--log_file', '-log_file', type=str, default="", help="Output logs to a file under this path.")
+    group.add(
+        '--structured_log_file',
+        '-structured_log_file',
+        type=str,
+        default="",
+        help="Output machine-readable structured logs to a file under this path."
+    )
     group.add(
         '--log_file_level',
         '-log_file_level',
@@ -1193,9 +1200,6 @@ def translate_opts(parser, dynamic=False):
"Ex: {'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}",
) # noqa: E501
group.add('--tgt', '-tgt', help='True target sequence (optional)')
group.add(
'--tgt_prefix', '-tgt_prefix', action='store_true', help='Generate predictions using provided `-tgt` as prefix.'
)
group.add(
'--shard_size',
'-shard_size',
@@ -1248,6 +1252,19 @@
     # Adding options related to Transforms
     _add_dynamic_transform_opts(parser)
 
+    group.add(
+        "--src_prefix",
+        "-src_prefix",
+        default="",
+        help="The encoder prefix, i.e. language selector token",
+    )
+    group.add(
+        "--tgt_prefix",
+        "-tgt_prefix",
+        default="",
+        help="The decoder prefix (FIXME: does not work, but must be set nevertheless)",
+    )
+
 
 def build_bilingual_model(parser):
     """options for modular translation"""
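At translation time these two strings feed the prefix transform's fallback path (see mammoth/transforms/misc.py below). A minimal sketch with invented values:

from argparse import Namespace

# As if parsed from: -src_prefix '<from_eng> <to_fin>' -tgt_prefix ''
opts = Namespace(src_prefix='<from_eng> <to_fin>', tgt_prefix='')
# Mirrors the except-branch added to the transform's warm_up() below:
prefix_dict = {'trans': {'src': opts.src_prefix, 'tgt': opts.tgt_prefix}}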
2 changes: 1 addition & 1 deletion mammoth/tests/test_models.py
@@ -14,7 +14,7 @@
 mammoth.opts._add_train_general_opts(parser)
 
 # -data option is required, but not used in this test, so dummy.
-opts = parser.parse_known_args(['-tasks', 'dummy', '-node_rank', '0', '-model_dim', '500'])[0]
+opts = parser.parse_known_args(['-tasks', 'dummy', '-node_rank', '0', '-model_dim', '500'], strict=False)[0]
 
 
 class TestModel(unittest.TestCase):
11 changes: 6 additions & 5 deletions mammoth/tests/test_transform.py
@@ -85,7 +85,7 @@ def test_transform_pipe(self):
"tgt": ["Bonjour", "le", "monde", "."],
}
# 4. apply transform pipe for example
ex_after = transform_pipe.apply(copy.deepcopy(ex), corpus_name="trainset")
ex_after = transform_pipe.apply(copy.deepcopy(ex), is_train=True, corpus_name="trainset")
# 5. example after the pipe exceed the length limit, thus filtered
self.assertIsNone(ex_after)
# 6. Transform statistics registed (here for filtertoolong)
@@ -121,8 +121,9 @@ def test_prefix(self):
         }
         with self.assertRaises(ValueError):
             prefix_transform.apply(ex_in)
-        prefix_transform.apply(ex_in, corpus_name="validset")
-        ex_out = prefix_transform.apply(ex_in, corpus_name="trainset")
+        with self.assertRaises(ValueError):
+            prefix_transform.apply(ex_in, is_train=False, corpus_name="validset")
+        ex_out = prefix_transform.apply(ex_in, is_train=True, corpus_name="trainset")
         self.assertEqual(ex_out["src"][0], "⦅_pf_src⦆")
         self.assertEqual(ex_out["tgt"][0], "⦅_pf_tgt⦆")
 
@@ -135,10 +136,10 @@ def test_filter_too_long(self):
"src": ["Hello", "world", "."],
"tgt": ["Bonjour", "le", "monde", "."],
}
ex_out = filter_transform.apply(ex_in)
ex_out = filter_transform.apply(ex_in, is_train=True)
self.assertIs(ex_out, ex_in)
filter_transform.tgt_seq_length = 2
ex_out = filter_transform.apply(ex_in)
ex_out = filter_transform.apply(ex_in, is_train=True)
self.assertIsNone(ex_out)


35 changes: 27 additions & 8 deletions mammoth/transforms/misc.py
@@ -43,25 +43,44 @@ def get_specials(cls, opts):
     def warm_up(self, vocabs=None):
         """Warm up to get prefix dictionary."""
         super().warm_up(None)
-        self.prefix_dict = self.get_prefix_dict(self.opts)
+        # TODO: The following try/except is a hack to work around the different
+        # structure of opts during training vs translation, and the fact that the
+        # transform does not know whether it is being warmed up for training or translation.
+        # This is most elegantly fixed by redesigning and unifying the formats of opts.
+        try:
+            # This should succeed during training
+            self.prefix_dict = self.get_prefix_dict(self.opts)
+        except AttributeError:
+            # Normal during translation
+            src_prefix = self.opts.src_prefix
+            tgt_prefix = self.opts.tgt_prefix
+            self.prefix_dict = {
+                'trans': {'src': src_prefix, 'tgt': tgt_prefix}
+            }
 
     def _prepend(self, example, prefix):
         """Prepend `prefix` to `tokens`."""
         for side, side_prefix in prefix.items():
-            example[side] = side_prefix.split() + example[side]
+            if example[side] is not None:
+                example[side] = side_prefix.split() + example[side]
         return example
 
     def apply(self, example, is_train=False, stats=None, **kwargs):
         """Apply prefix prepend to example.
         Should provide `corpus_name` to get the corresponding prefix.
         """
-        corpus_name = kwargs.get('corpus_name', None)
-        if corpus_name is None:
-            raise ValueError('corpus_name is required.')
-        corpus_prefix = self.prefix_dict.get(corpus_name, None)
-        if corpus_prefix is None:
-            raise ValueError(f'prefix for {corpus_name} does not exist.')
+        if is_train:
+            corpus_name = kwargs.get('corpus_name', None)
+            if corpus_name is None:
+                raise ValueError('corpus_name is required.')
+            corpus_prefix = self.prefix_dict.get(corpus_name, None)
+            if corpus_prefix is None:
+                raise ValueError(f'prefix for {corpus_name} does not exist.')
+        else:
+            corpus_prefix = self.prefix_dict.get('trans', None)
+            if corpus_prefix is None:
+                raise ValueError('failed to set prefixes for translation')
         return self._prepend(example, corpus_prefix)
 
     def _repr_args(self):
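To make the new None-guard in _prepend concrete, a standalone sketch (example values invented; at inference the target side can legitimately be absent):

ex = {'src': ['Hello', 'world'], 'tgt': None}
prefix = {'src': '<from_eng> <to_fin>', 'tgt': ''}
for side, side_prefix in prefix.items():
    if ex[side] is not None:
        ex[side] = side_prefix.split() + ex[side]
assert ex['src'] == ['<from_eng>', '<to_fin>', 'Hello', 'world']
assert ex['tgt'] is None  # left untouched instead of raising a TypeError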
3 changes: 1 addition & 2 deletions mammoth/translate/translator.py
@@ -350,8 +350,6 @@ def translate_dynamic(
         if batch_size is None:
             raise ValueError("batch_size must be set")
 
-        if self.tgt_prefix and tgt is None:
-            raise ValueError("Prefix should be feed to tgt if -tgt_prefix.")
         #
         # data_iter = InferenceDataIterator(src, tgt, src_feats, transform)
         #
@@ -474,6 +472,7 @@ def _translate(
             transforms=transforms,  # I suppose you might want *some* transforms
             # batch_size=batch_size,
             # batch_type=batch_type,
+            task=self.task,
         ).to(self._dev)
 
         batches = build_dataloader(
40 changes: 27 additions & 13 deletions mammoth/utils/logging.py
@@ -1,8 +1,8 @@
 # -*- coding: utf-8 -*-
-import os
 import json
 import logging
 from logging.handlers import RotatingFileHandler
+from typing import Dict, Union
 
 logger = logging.getLogger()
 
@@ -13,6 +13,7 @@ def init_logger(
     rotate=False,
     log_level=logging.INFO,
     gpu_id='',
+    structured_log_file=None,
 ):
     log_format = logging.Formatter(f"[%(asctime)s %(process)s {gpu_id} %(levelname)s] %(message)s")
     logger = logging.getLogger()
@@ -31,18 +32,31 @@
     file_handler.setFormatter(log_format)
     logger.addHandler(file_handler)
 
+    if structured_log_file:
+        init_structured_logger(structured_log_file)
+
     return logger
 
 
-def log_lca_values(step, lca_logs, lca_params, opath, dump_logs=False):
-    for k, v in lca_params.items():
-        lca_sum = v.sum().item()
-        lca_mean = v.mean().item()
-        lca_logs[k][f'STEP_{step}'] = {'sum': lca_sum, 'mean': lca_mean}
-
-    if dump_logs:
-        if os.path.exists(opath):
-            os.system(f'cp {opath} {opath}.previous')
-        with open(opath, 'w+') as f:
-            json.dump(lca_logs, f)
-        logger.info(f'dumped LCA logs in {opath}')
+def init_structured_logger(
+    log_file=None,
+):
+    # Log should be parseable as a jsonl file. Format should not include anything extra.
+    log_format = logging.Formatter("%(message)s")
+    logger = logging.getLogger("structured_logger")
+    logger.setLevel(logging.INFO)
+    file_handler = logging.FileHandler(log_file, mode='a', delay=True)
+    file_handler.setLevel(logging.INFO)
+    file_handler.setFormatter(log_format)
+    logger.handlers = [file_handler]
+    logger.propagate = False
+
+
+def structured_logging(obj: Dict[str, Union[str, int, float]]):
+    structured_logger = logging.getLogger("structured_logger")
+    if not structured_logger.hasHandlers():
+        return
+    try:
+        structured_logger.info(json.dumps(obj))
+    except Exception:
+        pass
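A usage sketch for the new structured channel (file names invented): init_logger wires it up once, after which structured_logging appends one JSON object per line:

from mammoth.utils.logging import init_logger, structured_logging

init_logger('train.log', structured_log_file='train.jsonl')
structured_logging({'type': 'note', 'step': 0, 'message': 'hello'})
# train.jsonl now contains: {"type": "note", "step": 0, "message": "hello"}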
6 changes: 6 additions & 0 deletions mammoth/utils/parse.py
@@ -269,6 +269,12 @@ def defaults(cls, *args):
         defaults = dummy_parser.parse_known_args([])[0]
         return defaults
 
+    def parse_known_args(self, *args, strict=True, **kwargs):
+        opts, unknown = super().parse_known_args(*args, **kwargs)
+        if strict and unknown:
+            raise ValueError(f'unknown arguments provided:\n{unknown}')
+        return opts, unknown
+
     @classmethod
     def update_model_opts(cls, model_opts):
         cls._validate_adapters(model_opts)
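A behavioral sketch of the new strict mode (imports assumed to resolve as in this repo; the bogus flag is invented; see mammoth/tests/test_models.py above for the opt-out):

import mammoth.opts
from mammoth.utils.parse import ArgumentParser

parser = ArgumentParser()
mammoth.opts._add_logging_opts(parser, is_train=True)

try:
    parser.parse_known_args(['--log_file', 'x.log', '--bogus'])
except ValueError as e:
    print(e)  # unknown arguments provided: ['--bogus']

# strict=False restores the old lenient behaviour:
opts, unknown = parser.parse_known_args(['--log_file', 'x.log', '--bogus'], strict=False)
print(unknown)  # ['--bogus']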
17 changes: 13 additions & 4 deletions mammoth/utils/report_manager.py
@@ -4,7 +4,7 @@

 import mammoth
 
-from mammoth.utils.logging import logger
+from mammoth.utils.logging import logger, structured_logging
 
 
 def build_report_manager(opts, node_rank, local_rank):
@@ -140,7 +140,16 @@ def _report_step(self, lr, patience, step, train_stats=None, valid_stats=None):
         self.maybe_log_tensorboard(train_stats, "train", lr, patience, step)
 
         if valid_stats is not None:
-            self.log('Validation perplexity: %g' % valid_stats.ppl())
-            self.log('Validation accuracy: %g' % valid_stats.accuracy())
-
+            ppl = valid_stats.ppl()
+            acc = valid_stats.accuracy()
+            self.log('Validation perplexity: %g', ppl)
+            self.log('Validation accuracy: %g', acc)
+            structured_logging({
+                'type': 'validation',
+                'step': step,
+                'learning_rate': lr,
+                'perplexity': ppl,
+                'accuracy': acc,
+                'crossentropy': valid_stats.xent(),
+            })
             self.maybe_log_tensorboard(valid_stats, "valid", lr, patience, step)
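Given the fields above, each validation report appends one record of this shape to the structured log (values invented):

{"type": "validation", "step": 5000, "learning_rate": 0.0002, "perplexity": 14.1, "accuracy": 62.3, "crossentropy": 2.65}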
19 changes: 19 additions & 0 deletions tools/config_config.py
@@ -436,6 +436,14 @@ def set_transforms(opts):
         else:
             corpus['transforms'] = list(transforms)
 
+        if 'prefix' in corpus['transforms']:
+            if cc_opts.get('use_src_lang_token', False):
+                prefix = f'<from_{src}> <to_{tgt}>'
+            else:
+                prefix = f'<to_{tgt}>'
+            corpus['src_prefix'] = prefix
+            corpus['tgt_prefix'] = ''  # does not work, but must be set nonetheless
+
     duration = time.time() - start
     logger.info(f'step took {duration} s')
 
@@ -589,6 +597,7 @@ def translation_configs(opts):

     src_subword_model = opts.in_config[0].get('src_subword_model', None)
     tgt_subword_model = opts.in_config[0].get('tgt_subword_model', None)
+    use_src_lang_token = cc_opts.get('use_src_lang_token', False)
 
     os.makedirs(translation_config_dir, exist_ok=True)
     encoder_stacks = defaultdict(dict)
@@ -637,6 +646,7 @@
                 tgt_subword_model,
                 'supervised',
                 translation_config_dir,
+                use_src_lang_token,
             )
             supervised_pairs.add((src_lang, tgt_lang))
             if zero_shot:
@@ -671,6 +681,7 @@
                     tgt_subword_model,
                     'zeroshot',
                     translation_config_dir,
+                    use_src_lang_token,
                 )
 
     duration = time.time() - start
@@ -687,6 +698,7 @@ def _write_translation_config(
     tgt_subword_model,
     supervision,
     translation_config_dir,
+    use_src_lang_token,
 ):
     # specify on command line: --model, --src
     result = {
@@ -699,6 +711,13 @@
     }
     if transforms:
         result['transforms'] = transforms
+        if 'prefix' in transforms:
+            if use_src_lang_token:
+                prefix = f'<from_{src_lang}> <to_{tgt_lang}>'
+            else:
+                prefix = f'<to_{tgt_lang}>'
+            result['src_prefix'] = prefix
+            result['tgt_prefix'] = ''  # does not work, but must be set nonetheless
     if src_subword_model:
         result['src_subword_model'] = src_subword_model
     if tgt_subword_model:
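For a supervised eng→fin task with the prefix transform and use_src_lang_token enabled, the emitted translation config would carry entries along these lines (a sketch; unrelated keys elided):

result = {
    # ... model, language and subword-model keys elided ...
    'transforms': ['prefix'],
    'src_prefix': '<from_eng> <to_fin>',
    'tgt_prefix': '',  # placeholder: the decoder-side prefix is not yet functional
}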
