Skip to content

Commit

Permalink
ProtocolMixin: increase flexibility
Browse files Browse the repository at this point in the history
The load_protocol_file method in the protocol utils.py
module uses the path of the module file to obtain the path to the .yaml
file that described the protocols. However, this fails in case the
`ProtocolMixin` class is used by other packages, since the protocol files
of the package's work chains will be stored in that package.

Moreover, the user can only adapt the protocol by specifying
`overrides` as a dictionary, or adapting the builder afterwards.

Here we adapt the `ProtocolMixin` class as follows:

* Add a `get_protocol_filepath()` class method that needs to be
implemented by the work chain classes that use the `ProtocolMixin`.
* Move the `load_protocol_file` method into the `ProtocolMixin` class as
a hidden method that relies on the `get_protocol_filepath` method.
* The `get_protocol_inputs` method is made more flexible by accepting
a `pathlib.Path` for the `overrides` input. This allows the user to
define protocol overrides in a `.yaml` file.

The `Path` to the hardcoded protocols that ship with the package are now
also obtained using `importlib_resources`.
  • Loading branch information
mbercx committed May 6, 2021
1 parent 33cb11e commit a04a055
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 37 deletions.
11 changes: 9 additions & 2 deletions aiida_quantumespresso/workflows/pdos.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,13 @@ def define(cls, spec):
spec.expose_outputs(DosCalculation, namespace='dos')
spec.expose_outputs(ProjwfcCalculation, namespace='projwfc')

@classmethod
def get_protocol_filepath(cls):
"""Return ``pathlib.Path`` to the ``.yaml`` file that defines the protocols."""
from importlib_resources import files
from . import protocols
return files(protocols) / 'pdos.yaml'

@classmethod
def get_builder_from_protocol(
cls, pw_code, dos_code, projwfc_code, structure, protocol=None, overrides=None, **kwargs
Expand All @@ -324,11 +331,10 @@ def get_builder_from_protocol(
:return: a process builder instance with all inputs defined ready for launch.
"""

args = (pw_code, structure, protocol)

inputs = cls.get_protocol_inputs(protocol, overrides)
builder = cls.get_builder()

args = (pw_code, structure, protocol)
scf = PwBaseWorkChain.get_builder_from_protocol(*args, overrides=inputs.get('scf', None), **kwargs)
scf['pw'].pop('structure', None)
scf.pop('clean_workdir', None)
Expand All @@ -338,6 +344,7 @@ def get_builder_from_protocol(
nscf['pw']['parameters']['SYSTEM'].pop('degauss', None)
nscf.pop('clean_workdir', None)

builder = cls.get_builder()
builder.structure = structure
builder.clean_workdir = orm.Bool(inputs['clean_workdir'])
builder.scf = scf
Expand Down
54 changes: 25 additions & 29 deletions aiida_quantumespresso/workflows/protocols/utils.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
# -*- coding: utf-8 -*-
"""Utilities to manipulate the workflow input protocols."""
import functools
import os
import pathlib
from typing import Optional, Union
import yaml


class ProtocolMixin:
"""Utility class for processes to build input mappings for a given protocol based on a YAML configuration file."""

@classmethod
def get_protocol_filepath(cls):
"""Return the ``pathlib.Path`` to the ``.yaml`` file that defines the protocols."""
raise NotImplementedError

@classmethod
def get_default_protocol(cls):
"""Return the default protocol for a given workflow class.
:param cls: the workflow class.
:return: the default protocol.
"""
return load_protocol_file(cls)['default_protocol']
return cls._load_protocol_file()['default_protocol']

@classmethod
def get_available_protocols(cls):
Expand All @@ -26,20 +30,25 @@ def get_available_protocols(cls):
:return: dictionary of available protocols, where each key is a protocol and value is another dictionary that
contains at least the key `description` and optionally other keys with supplementary information.
"""
data = load_protocol_file(cls)
data = cls._load_protocol_file()
return {protocol: {'description': values['description']} for protocol, values in data['protocols'].items()}

@classmethod
def get_protocol_inputs(cls, protocol=None, overrides=None):
def get_protocol_inputs(
cls,
protocol: Optional[dict] = None,
overrides: Union[dict, pathlib.Path, None] = None,
) -> dict:
"""Return the inputs for the given workflow class and protocol.
:param cls: the workflow class.
:param protocol: optional specific protocol, if not specified, the default will be used
:param overrides: dictionary of inputs that should override those specified by the protocol. The mapping should
maintain the exact same nesting structure as the input port namespace of the corresponding workflow class.
:param protocol_file_path: Path to the `.yaml` file where the protocols are stored.
:return: mapping of inputs to be used for the workflow class.
"""
data = load_protocol_file(cls)
data = cls._load_protocol_file()
protocol = protocol or data['default_protocol']

try:
Expand All @@ -48,15 +57,24 @@ def get_protocol_inputs(cls, protocol=None, overrides=None):
raise ValueError(
f'`{protocol}` is not a valid protocol. Call ``get_available_protocols`` to show available protocols.'
) from exception

inputs = recursive_merge(data['default_inputs'], protocol_inputs)
inputs.pop('description')

if isinstance(overrides, pathlib.Path):
with overrides.open() as file:
overrides = yaml.safe_load(file)

if overrides:
return recursive_merge(inputs, overrides)

return inputs

@classmethod
def _load_protocol_file(cls):
"""Return the contents of the protocol file for workflow class."""
with cls.get_protocol_filepath().open() as file:
return yaml.safe_load(file)


def recursive_merge(left, right):
"""Recursively merge two dictionaries into a single dictionary.
Expand Down Expand Up @@ -84,28 +102,6 @@ def recursive_merge(left, right):
return merged


def load_protocol_file(cls):
"""Load the protocol file for the given workflow class.
:param cls: the workflow class.
:return: the contents of the protocol file.
"""
from aiida.plugins.entry_point import get_entry_point_from_class

_, entry_point = get_entry_point_from_class(cls.__module__, cls.__name__)
entry_point_name = entry_point.name
parts = entry_point_name.split('.')
parts.pop(0)
filename = f'{parts.pop()}.yaml'
try:
basepath = functools.reduce(os.path.join, parts)
except TypeError:
basepath = '.'

with (pathlib.Path(__file__).resolve().parent / basepath / filename).open() as handle:
return yaml.safe_load(handle)


def get_magnetization_parameters() -> dict:
"""Return the mapping of suggested initial magnetic moments for each element.
Expand Down
11 changes: 9 additions & 2 deletions aiida_quantumespresso/workflows/pw/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ def define(cls, spec):
help='The computed band structure.')
# yapf: enable

@classmethod
def get_protocol_filepath(cls):
"""Return ``pathlib.Path`` to the ``.yaml`` file that defines the protocols."""
from importlib_resources import files
from ..protocols import pw as pw_protocols
return files(pw_protocols) / 'bands.yaml'

@classmethod
def get_builder_from_protocol(cls, code, structure, protocol=None, overrides=None, **kwargs):
"""Return a builder prepopulated with inputs selected according to the chosen protocol.
Expand All @@ -124,10 +131,9 @@ def get_builder_from_protocol(cls, code, structure, protocol=None, overrides=Non
sub processes that are called by this workchain.
:return: a process builder instance with all inputs defined ready for launch.
"""
args = (code, structure, protocol)
inputs = cls.get_protocol_inputs(protocol, overrides)
builder = cls.get_builder()

args = (code, structure, protocol)
relax = PwRelaxWorkChain.get_builder_from_protocol(*args, overrides=inputs.get('relax', None), **kwargs)
scf = PwBaseWorkChain.get_builder_from_protocol(*args, overrides=inputs.get('scf', None), **kwargs)
bands = PwBaseWorkChain.get_builder_from_protocol(*args, overrides=inputs.get('bands', None), **kwargs)
Expand All @@ -142,6 +148,7 @@ def get_builder_from_protocol(cls, code, structure, protocol=None, overrides=Non
bands.pop('kpoints_distance', None)
bands.pop('kpoints_force_parity', None)

builder = cls.get_builder()
builder.structure = structure
builder.relax = relax
builder.scf = scf
Expand Down
9 changes: 8 additions & 1 deletion aiida_quantumespresso/workflows/pw/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,13 @@ def define(cls, spec):
message='Then ionic minimization cycle converged but the thresholds are exceeded in the final SCF.')
# yapf: enable

@classmethod
def get_protocol_filepath(cls):
"""Return ``pathlib.Path`` to the ``.yaml`` file that defines the protocols."""
from importlib_resources import files
from ..protocols import pw as pw_protocols
return files(pw_protocols) / 'base.yaml'

@classmethod
def get_builder_from_protocol(
cls,
Expand Down Expand Up @@ -158,7 +165,6 @@ def get_builder_from_protocol(
if initial_magnetic_moments is not None and spin_type is not SpinType.COLLINEAR:
raise ValueError(f'`initial_magnetic_moments` is specified but spin type `{spin_type}` is incompatible.')

builder = cls.get_builder()
inputs = cls.get_protocol_inputs(protocol, overrides)

meta_parameters = inputs.pop('meta_parameters')
Expand Down Expand Up @@ -194,6 +200,7 @@ def get_builder_from_protocol(
parameters['SYSTEM']['nspin'] = 2
parameters['SYSTEM']['starting_magnetization'] = starting_magnetization

builder = cls.get_builder()
builder.pw['code'] = code # pylint: disable=no-member
builder.pw['pseudos'] = pseudo_family.get_pseudos(structure=structure) # pylint: disable=no-member
builder.pw['structure'] = structure # pylint: disable=no-member
Expand Down
11 changes: 9 additions & 2 deletions aiida_quantumespresso/workflows/pw/relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ def define(cls, spec):
help='The successfully relaxed structure.')
# yapf: enable

@classmethod
def get_protocol_filepath(cls):
"""Return ``pathlib.Path`` to the ``.yaml`` file that defines the protocols."""
from importlib_resources import files
from ..protocols import pw as pw_protocols
return files(pw_protocols) / 'relax.yaml'

@classmethod
def get_builder_from_protocol(
cls, code, structure, protocol=None, overrides=None, relax_type=RelaxType.POSITIONS_CELL, **kwargs
Expand All @@ -112,10 +119,9 @@ def get_builder_from_protocol(
"""
type_check(relax_type, RelaxType)

args = (code, structure, protocol)
inputs = cls.get_protocol_inputs(protocol, overrides)
builder = cls.get_builder()

args = (code, structure, protocol)
base = PwBaseWorkChain.get_builder_from_protocol(*args, overrides=inputs.get('base', None), **kwargs)
base_final_scf = PwBaseWorkChain.get_builder_from_protocol(
*args, overrides=inputs.get('base_final_scf', None), **kwargs
Expand Down Expand Up @@ -153,6 +159,7 @@ def get_builder_from_protocol(
if relax_type in (RelaxType.CELL, RelaxType.POSITIONS_CELL):
base.pw.parameters['CELL']['cell_dofree'] = 'all'

builder = cls.get_builder()
builder.base = base
builder.base_final_scf = base_final_scf
builder.structure = structure
Expand Down
3 changes: 2 additions & 1 deletion setup.json
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@
"packaging",
"qe-tools~=2.0rc1",
"xmlschema~=1.2,>=1.2.5",
"numpy"
"numpy",
"importlib_resources"
],
"license": "MIT License",
"name": "aiida_quantumespresso",
Expand Down

0 comments on commit a04a055

Please sign in to comment.