Skip to content

Commit

Permalink
Model splitting and saving
Browse files Browse the repository at this point in the history
  • Loading branch information
Waino committed Aug 12, 2024
1 parent 476869c commit 70ac57d
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 124 deletions.
2 changes: 1 addition & 1 deletion mammoth/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# from mammoth.modules.embeddings import prepare_pretrained_embeddings
from mammoth.utils.logging import init_logger, logger

from mammoth.models.model_saver import load_checkpoint
from mammoth.utils.model_saver import load_checkpoint
from mammoth.train_single import main as single_main
from mammoth.inputters import DynamicDatasetIter

Expand Down
15 changes: 15 additions & 0 deletions mammoth/distributed/components.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch.nn as nn
from abc import ABC, abstractmethod
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum, auto
from typing import Set, Any, Optional, Dict
Expand Down Expand Up @@ -69,6 +70,10 @@ def named_parameters(self, model: NMTModel):
module = self.get_module(model)
yield from module.named_parameters()

def state_dict(self, model: NMTModel, prefix='', keep_vars=False):
    """Return the state dict of the module that this component refers to.

    Args:
        model: the model in which this component's module is looked up
            (via ``self.get_module``).
        prefix: forwarded unchanged to the module's ``state_dict``.
        keep_vars: forwarded unchanged to the module's ``state_dict``.

    Returns:
        Whatever the looked-up module's ``state_dict`` returns.
    """
    return self.get_module(model).state_dict(prefix=prefix, keep_vars=keep_vars)

@property
def min_rank(self) -> int:
return min(self.global_ranks)
Expand Down Expand Up @@ -100,6 +105,16 @@ def named_parameters(self, model: NMTModel):
if 'embeddings' not in name and 'adapter' not in name:
yield name, p

def state_dict(self, model: NMTModel, prefix='', keep_vars=False) -> Dict[str, Any]:
    """Collect the state dict of this component's module, leaving out its adapters.

    The child modules of the looked-up module are visited one by one and each
    writes its entries into a shared ordered dict, keyed by
    ``prefix + <child name> + '.' + <entry name>``.

    Only child modules are visited: anything registered directly on the
    looked-up module itself (rather than on one of its children) is not
    collected by this method.

    Args:
        model: the model in which this component's module is looked up
            (via ``self.get_module``).
        prefix: string prepended to every key.
        keep_vars: forwarded unchanged to each child's ``state_dict``.

    Returns:
        An ordered dict holding the entries of every child except ``adapters``.
    """
    module = self.get_module(model)
    destination: Dict[str, Any] = OrderedDict()
    for name, sub_module in module._modules.items():
        if name == 'adapters':
            # Adapters are stored separately
            continue
        if sub_module is None:
            # A child slot may hold None (nn.Module permits registering None as
            # a placeholder); there is nothing to save for it, and calling
            # .state_dict on it would raise AttributeError.
            continue
        sub_module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
    return destination


@dataclass
class DistributedEncoder(DistributedXCoder):
Expand Down
11 changes: 9 additions & 2 deletions mammoth/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""
import torch
import torch.nn as nn
from collections import defaultdict
from collections import defaultdict, OrderedDict
from functools import partial
from pathlib import Path
from torch.nn.init import xavier_uniform_
Expand All @@ -31,11 +31,18 @@
from mammoth.modules.layer_stack import AdaptedAttentionLayersStack, StackXcoder
from mammoth.utils.logging import logger
from mammoth.utils.misc import use_gpu
from mammoth.utils.module_splitter import _combine_ordered_dicts
from mammoth.utils.parse import ArgumentParser
from mammoth.utils.transformer_wrapper import TransformerWrapper


def _combine_ordered_dicts(input_dicts: Dict[str, OrderedDict]) -> OrderedDict:
    """Flatten a mapping of ``prefix -> ordered dict`` into one ordered dict.

    Every key of every inner dict is prepended with the prefix under which
    that inner dict is stored. Entries keep the iteration order of the outer
    mapping, then of each inner dict.
    """
    return OrderedDict(
        (f'{prefix}{key}', item)
        for prefix, input_dict in input_dicts.items()
        for key, item in input_dict.items()
    )


def uses_adapters(opts):
    """Tell whether the given options configure adapters.

    Returns ``False`` when ``opts`` has no ``adapters`` entry; otherwise
    returns the value of ``opts.adapters`` itself (not coerced to bool), so
    callers should treat the result as a truthy/falsy flag.
    """
    if 'adapters' not in opts:
        return False
    return opts.adapters

Expand Down
3 changes: 1 addition & 2 deletions mammoth/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Module defining models."""
from mammoth.models.model_saver import build_model_saver, ModelSaver
from mammoth.models.model import NMTModel

__all__ = ["build_model_saver", "ModelSaver", "NMTModel"]
__all__ = ["NMTModel"]
11 changes: 9 additions & 2 deletions mammoth/train_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mammoth.utils.optimizers import MultipleOptimizer
from mammoth.utils.misc import set_random_seed
from mammoth.trainer import build_trainer
from mammoth.models import build_model_saver
from mammoth.utils.model_saver import build_model_saver
from mammoth.utils.logging import init_logger, logger
from mammoth.utils.parse import ArgumentParser

Expand Down Expand Up @@ -127,7 +127,14 @@ def main(
)

# Build model saver
model_saver = build_model_saver(model_opts, opts, model, vocabs_dict, optim, device_context)
model_saver = build_model_saver(
model_opts,
opts,
model,
vocabs_dict,
optim,
task_queue_manager=task_queue_manager,
)

logger.info("{} - Build trainer".format(device_context.id))
trainer = build_trainer(
Expand Down
83 changes: 32 additions & 51 deletions mammoth/models/model_saver.py → mammoth/utils/model_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
from mammoth.utils.module_splitter import explode_model


def build_model_saver(model_opts, opts, model, vocabs_dict, optim, device_context):
def build_model_saver(model_opts, opts, model, vocabs_dict, optim, task_queue_manager):
# _check_save_model_path
save_model_path = os.path.abspath(opts.save_model)
os.makedirs(os.path.dirname(save_model_path), exist_ok=True)

model_saver = ModelSaver(
opts.save_model, model, model_opts, vocabs_dict, optim, opts.keep_checkpoint, device_context, opts.save_all_gpus
opts.save_model, model, model_opts, vocabs_dict, optim, opts.keep_checkpoint, task_queue_manager, opts.save_all_gpus
)
return model_saver

Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(
vocabs_dict,
optim,
keep_checkpoint=-1,
device_context=None,
task_queue_manager=None,
all_gpus=False,
):
self.base_path = base_path
Expand All @@ -62,8 +62,8 @@ def __init__(
self.keep_checkpoint = keep_checkpoint
if keep_checkpoint > 0:
self.checkpoint_queue = deque([], maxlen=keep_checkpoint)
assert device_context is not None
self.device_context = device_context
assert task_queue_manager is not None
self.task_queue_manager = task_queue_manager
self.all_gpus = all_gpus

def save(self, step, data_state, moving_average=None):
Expand All @@ -83,7 +83,7 @@ def save(self, step, data_state, moving_average=None):
model_params_data.append(param.data)
param.data = avg.data

chkpt_names = self._save(step, save_model, data_state, self.device_context)
chkpt_names = self._save(step, save_model, data_state, self.task_queue_manager)
self.last_saved_step = step

if moving_average:
Expand All @@ -96,14 +96,14 @@ def save(self, step, data_state, moving_average=None):
self._rm_checkpoint(todel)
self.checkpoint_queue.append(chkpt_names)

def _save(self, step, save_model, data_state, device_context):
def _save(self, step, save_model, data_state, task_queue_manager):
"""Save a resumable checkpoint.
Args:
step (int): step number
save_model (nn.Module): torch model to save
data_state (dict): data streaming info
device_context: runtime info
task_queue_manager: distributed structure of modules
Returns:
(object, str):
Expand All @@ -128,60 +128,41 @@ def _rm_checkpoint(self, name):
class ModelSaver(ModelSaverBase):
"""Simple model saver to filesystem"""

def _save(self, step, model, data_state, device_context):
real_model = model.module if isinstance(model, nn.DataParallel) else model

model_state_dict = real_model.state_dict()

checkpoint = {
"model": model_state_dict,
# 'generator': generator_state_dict,
"vocab": self.vocabs_dict,
"opts": self.model_opts,
"optim": self.optim.state_dict(),
"whole_model": self.model,
}
def _save(self, step, model, data_state, task_queue_manager):
model = model.module if isinstance(model, nn.DataParallel) else model
device_context = task_queue_manager.device_context

tmp_checkpoint_paths = []

if self.all_gpus:
# save models trained in each gpu
checkpoint_path = "{}_step_{}_gpu_{}.pt".format(self.base_path, step, device_context.global_rank)
logger.info("Saving full checkpoint {}".format(checkpoint_path))
torch.save(checkpoint, checkpoint_path)
tmp_checkpoint_paths.append(checkpoint_path)

modules, model_frame = explode_model(checkpoint)
module_state_dicts = explode_model(model, task_queue_manager)

for key, module in modules.items():
# All processes will try to save the modules present on that device
# Note that a race condition is possible:
# the process can be preempted after the check for existence, but before the save.
# This shouldn't be a problem, if writes are atomic.
checkpoint_path = f'{self.base_path}_step_{step}_{key}.pt'
if os.path.isfile(checkpoint_path):
logger.debug("{} - not saving {} as it is already present".format(device_context.id, checkpoint_path))
else:
logger.info("Saving module checkpoint {}".format(checkpoint_path))
torch.save(module, checkpoint_path)
tmp_checkpoint_paths.append(checkpoint_path)
# The master device stores the frame
if device_context.is_master():
module_state_dicts['frame'] = {
'vocab': self.vocabs_dict,
'opts': self.model_opts,
'optim': self.optim.state_dict(),
}

# In a distributed context, aggregate all data states for corpus restoration
if device_context.is_distributed():
data_states = [None for _ in range(device_context.world_size)]
torch.distributed.all_gather_object(data_states, data_state)
data_state = {k: v for state in data_states for k, v in state.items()}
if device_context.is_master():
module_state_dicts['frame']['data_state'] = data_state

model_frame['data_state'] = data_state
if device_context.is_master():
# TODO: not sure how to deal with model_state_dict, fields, model_opts and optim.state_dict() in a multi-gpu
# setting. Is it OK to save only from master?

# model frame
checkpoint_path = "{}_step_{}_frame.pt".format(self.base_path, step)
logger.info("Saving model frame checkpoint {}".format(checkpoint_path))
torch.save(model_frame, checkpoint_path)
tmp_checkpoint_paths.append(checkpoint_path)
for key, state_dict in module_state_dicts.items():
# The state_dicts across different devices only contain one copy of each module:
# on the lowest ranked device having that module.
# There is no race condition.
checkpoint_path = f'{self.base_path}_step_{step}_{key}.pt'
if os.path.isfile(checkpoint_path):
logger.debug("{} - not saving {} as it is already present".format(device_context.id, checkpoint_path))
else:
logger.info("Saving module checkpoint {}".format(checkpoint_path))
torch.save(state_dict, checkpoint_path)
tmp_checkpoint_paths.append(checkpoint_path)

return tmp_checkpoint_paths

Expand Down
78 changes: 12 additions & 66 deletions mammoth/utils/module_splitter.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,16 @@
from collections import OrderedDict
from typing import Dict
from typing import Dict, Any

from mammoth.models import NMTModel
from mammoth.distributed.tasks import LocalTaskQueueManager

def _combine_ordered_dicts(input_dicts: Dict[str, OrderedDict]) -> OrderedDict:
    """Merge several ordered dicts into one, prefixing each key.

    ``input_dicts`` maps a prefix string to an ordered dict; each key of that
    dict appears in the result as ``prefix + key``. Order follows the outer
    mapping first, then each inner dict.
    """
    combined = OrderedDict()
    for prefix, input_dict in input_dicts.items():
        for key, item in input_dict.items():
            combined[f'{prefix}{key}'] = item
    return combined


def explode_model(full_ab_model):
    """Split a full checkpoint dict into per-module state dicts plus a model frame.

    Args:
        full_ab_model: a dict with the keys ``"whole_model"``, ``"vocab"``,
            ``"opts"`` and ``"optim"``.

    Returns:
        A pair ``(modules, model_frame)``. ``modules`` maps a generated key to
        a state dict, in this order: source embeddings, target embeddings,
        encoder layer stacks (each followed by its adapters), decoder layer
        stacks (each followed by its adapters), generators, attention bridge.
        ``model_frame`` holds the ``vocab``, ``opts`` and ``optim`` entries
        of the input.
    """
    # FIXME: saving and loading are broken
    whole_model = full_ab_model["whole_model"]
    encoder = whole_model.encoder
    decoder = whole_model.decoder

    modules = {}

    # embeddings: the source side lives on the encoder, the target side on the decoder
    for side, xcoder in (('src', encoder), ('tgt', decoder)):
        for embedding_key, embeddings in xcoder.embeddings.items():
            lang = embedding_key.replace('embeddings_', '')
            modules[f'{side}_embeddings_{lang}'] = embeddings.state_dict()

    # encoders first, then decoders
    for kind, xcoder in (('encoder', encoder), ('decoder', decoder)):
        layer_stack_dicts = xcoder.encoders if kind == 'encoder' else xcoder.decoders
        for layer_stack_idx, layer_stack_dict in enumerate(layer_stack_dicts):
            for layer_stack_key, layer_stack in layer_stack_dict.items():
                # the xcoder itself
                xcoder_key = f'{kind}_{layer_stack_idx}_{layer_stack_key}'
                modules[xcoder_key] = layer_stack.state_dict(include_adapters=False)

                # the adapters for this xcoder
                for adapter_key, adapter in layer_stack.adapters.items():
                    adapter_name = adapter_key.replace('adapter_', '')
                    full_key = f'{kind}_adapter_{layer_stack_idx}_{layer_stack_key}_{adapter_name}'
                    modules[full_key] = adapter.state_dict()

    # generators
    for generator_key, generator in whole_model.generator.items():
        modules[generator_key] = generator.state_dict()

    # attention bridge
    modules['attention_bridge'] = whole_model.attention_bridge.state_dict()

    # stuff necessary to build bilingual models combining modules
    model_frame = {field: full_ab_model[field] for field in ("vocab", "opts", "optim")}

    return modules, model_frame
def explode_model(model: NMTModel, task_queue_manager: LocalTaskQueueManager) -> Dict[str, Any]:
    """Gather the state dicts of the components this process should save.

    Each distributed component reported by ``task_queue_manager`` for this
    process is included only when its ``min_rank`` equals this process's
    global rank, so that only the lowest ranked device saves a component.

    Args:
        model: the model handed to each component's ``state_dict``.
        task_queue_manager: supplies this process's distributed components
            and its global rank.

    Returns:
        An ordered dict mapping component name to that component's state dict.
    """
    components = task_queue_manager.get_my_distributed_components()
    rank = task_queue_manager.global_rank
    return OrderedDict(
        (component.get_name(), component.state_dict(model))
        for component in components
        # Only the lowest ranked device saves a component
        if component.min_rank == rank
    )

0 comments on commit 70ac57d

Please sign in to comment.