-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
79 additions
and
124 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,4 @@ | ||
"""Module defining models.""" | ||
from mammoth.models.model_saver import build_model_saver, ModelSaver | ||
from mammoth.models.model import NMTModel | ||
|
||
__all__ = ["build_model_saver", "ModelSaver", "NMTModel"] | ||
__all__ = ["NMTModel"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,70 +1,16 @@ | ||
from collections import OrderedDict | ||
from typing import Dict | ||
from typing import Dict, Any | ||
|
||
from mammoth.models import NMTModel | ||
from mammoth.distributed.tasks import LocalTaskQueueManager | ||
|
||
def _combine_ordered_dicts(input_dicts: Dict[str, OrderedDict]) -> OrderedDict: | ||
result = [] | ||
for prefix, input_dict in input_dicts.items(): | ||
for key, item in input_dict.items(): | ||
result.append((f'{prefix}{key}', item)) | ||
return OrderedDict(result) | ||
|
||
|
||
def explode_model(full_ab_model):
    """Split a full checkpoint into per-module state dicts plus a frame.

    Args:
        full_ab_model: mapping with the entries ``"whole_model"`` (an object
            exposing ``encoder``, ``decoder``, ``generator`` and
            ``attention_bridge``), ``"vocab"``, ``"opts"`` and ``"optim"``.

    Returns:
        A ``(modules, model_frame)`` tuple. ``modules`` maps a generated
        module name to that module's ``state_dict()``, inserted in the order
        source embeddings, target embeddings, encoder stacks (each followed
        by its adapters), decoder stacks (likewise), generators, attention
        bridge. ``model_frame`` carries the ``vocab``, ``opts`` and ``optim``
        entries of ``full_ab_model`` unchanged.
    """
    # FIXME: saving and loading are broken
    whole_model = full_ab_model["whole_model"]
    encoder = whole_model.encoder
    decoder = whole_model.decoder

    modules = {}

    # embeddings: the 'embeddings_' marker is stripped from the key to get the
    # language, which is then re-prefixed with the side (src / tgt)
    for side, xcoder in (('src', encoder), ('tgt', decoder)):
        for embedding_key, embeddings in xcoder.embeddings.items():
            lang = embedding_key.replace('embeddings_', '')
            modules[f'{side}_embeddings_{lang}'] = embeddings.state_dict()

    # encoders, then decoders (same naming scheme, different prefix)
    _explode_layer_stacks(modules, 'encoder', encoder.encoders)
    _explode_layer_stacks(modules, 'decoder', decoder.decoders)

    # generators keep their own keys
    for generator_key, generator in whole_model.generator.items():
        modules[generator_key] = generator.state_dict()

    # attention bridge
    modules['attention_bridge'] = whole_model.attention_bridge.state_dict()

    # stuff necessary to build bilingual models combining modules
    model_frame = {
        "vocab": full_ab_model["vocab"],
        "opts": full_ab_model["opts"],
        "optim": full_ab_model["optim"],
    }

    return modules, model_frame


def _explode_layer_stacks(modules, prefix, layer_stack_dicts):
    """Add the state dicts of every layer stack and its adapters to *modules*.

    Args:
        modules: dict that is updated in place.
        prefix: ``'encoder'`` or ``'decoder'``; used to build the keys
            ``{prefix}_{idx}_{stack_key}`` and
            ``{prefix}_adapter_{idx}_{stack_key}_{adapter_key}``.
        layer_stack_dicts: sequence of mappings from stack key to layer stack;
            the position in the sequence becomes ``idx``.
    """
    for layer_stack_idx, layer_stack_dict in enumerate(layer_stack_dicts):
        for layer_stack_key, layer_stack in layer_stack_dict.items():
            # the xcoder itself; adapters are stored under their own keys below
            stack_name = f'{prefix}_{layer_stack_idx}_{layer_stack_key}'
            modules[stack_name] = layer_stack.state_dict(include_adapters=False)

            # the adapters for this xcoder
            for adapter_key, adapter in layer_stack.adapters.items():
                adapter_key = adapter_key.replace('adapter_', '')
                adapter_name = f'{prefix}_adapter_{layer_stack_idx}_{layer_stack_key}_{adapter_key}'
                modules[adapter_name] = adapter.state_dict()
def explode_model(model: NMTModel, task_queue_manager: LocalTaskQueueManager) -> Dict[str, Any]:
    """Collect the state dicts of the components this rank should save.

    Args:
        model: passed through to each component's ``state_dict`` call.
        task_queue_manager: supplies this process's distributed components
            (``get_my_distributed_components()``) and its ``global_rank``.

    Returns:
        An ``OrderedDict`` mapping ``component.get_name()`` to
        ``component.state_dict(model)``, in the order the components were
        returned, restricted to components whose ``min_rank`` equals this
        process's global rank.
    """
    candidates = task_queue_manager.get_my_distributed_components()
    own_rank = task_queue_manager.global_rank
    exploded = OrderedDict()
    for part in candidates:
        if part.min_rank != own_rank:
            # Only the lowest ranked device saves a component
            continue
        exploded[part.get_name()] = part.state_dict(model)
    return exploded