Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PwBaseWorkChain: make magnetism from overrides absolute #731

Merged
merged 3 commits into from
Sep 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions aiida_quantumespresso/workflows/pw/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,9 @@ def get_builder_from_protocol(
:param electronic_type: indicate the electronic character of the system through ``ElectronicType`` instance.
:param spin_type: indicate the spin polarization type to use through a ``SpinType`` instance.
:param initial_magnetic_moments: optional dictionary that maps the initial magnetic moment of each kind to a
desired value for a spin polarized calculation. Note that for ``spin_type == SpinType.COLLINEAR`` an initial
guess for the magnetic moment is automatically set in case this argument is not provided.
desired value for a spin polarized calculation. Note that this takes precedence over any
``starting_magnetization`` provided in the ``overrides``, and that for ``spin_type == SpinType.COLLINEAR``
an initial guess for the magnetic moment is automatically set in case neither is provided.
:return: a process builder instance with all inputs defined ready for launch.
"""
from aiida_quantumespresso.workflows.protocols.utils import get_starting_magnetization
Expand Down Expand Up @@ -205,10 +206,10 @@ def get_builder_from_protocol(
parameters['SYSTEM'].pop('smearing')

if spin_type is SpinType.COLLINEAR:
starting_magnetization = get_starting_magnetization(structure, pseudo_family, initial_magnetic_moments)

parameters['SYSTEM']['nspin'] = 2
parameters['SYSTEM']['starting_magnetization'] = starting_magnetization
if 'starting_magnetization' not in parameters['SYSTEM'] or initial_magnetic_moments is not None:
starting_magnetization = get_starting_magnetization(structure, pseudo_family, initial_magnetic_moments)
parameters['SYSTEM']['starting_magnetization'] = starting_magnetization

# pylint: disable=no-member
builder = cls.get_builder()
Expand Down
9 changes: 6 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def fixture_localhost(aiida_localhost):

@pytest.fixture
def fixture_code(fixture_localhost):
"""Return a `Code` instance configured to run calculations of given entry point on localhost `Computer`."""
"""Return a ``Code`` instance configured to run calculations of given entry point on localhost ``Computer``."""

def _fixture_code(entry_point_name):
from aiida.common import exceptions
Expand Down Expand Up @@ -297,10 +297,13 @@ def _generate_upf_data(element):

@pytest.fixture
def generate_structure():
"""Return a `StructureData` representing bulk silicon."""
"""Return a ``StructureData`` representing either bulk silicon or a water molecule."""

def _generate_structure(structure_id='silicon'):
"""Return a `StructureData` representing bulk silicon or a snapshot of a single water molecule dynamics."""
"""Return a ``StructureData`` representing bulk silicon or a snapshot of a single water molecule dynamics.

:param structure_id: identifies the ``StructureData`` you want to generate. Either 'silicon' or 'water'.
"""
from aiida.orm import StructureData

if structure_id == 'silicon':
Expand Down
71 changes: 61 additions & 10 deletions tests/workflows/protocols/pw/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_get_default_protocol():
def test_default(fixture_code, generate_structure, data_regression, serialize_builder):
"""Test ``PwBaseWorkChain.get_builder_from_protocol`` for the default protocol."""
code = fixture_code('quantumespresso.pw')
structure = generate_structure()
structure = generate_structure('silicon')
sphuber marked this conversation as resolved.
Show resolved Hide resolved
builder = PwBaseWorkChain.get_builder_from_protocol(code, structure)

assert isinstance(builder, ProcessBuilder)
Expand All @@ -33,7 +33,7 @@ def test_default(fixture_code, generate_structure, data_regression, serialize_bu
def test_electronic_type(fixture_code, generate_structure):
"""Test ``PwBaseWorkChain.get_builder_from_protocol`` with ``electronic_type`` keyword."""
code = fixture_code('quantumespresso.pw')
structure = generate_structure()
structure = generate_structure('silicon')

with pytest.raises(NotImplementedError):
for electronic_type in [ElectronicType.AUTOMATIC]:
Expand All @@ -50,7 +50,12 @@ def test_electronic_type(fixture_code, generate_structure):
def test_spin_type(fixture_code, generate_structure):
"""Test ``PwBaseWorkChain.get_builder_from_protocol`` with ``spin_type`` keyword."""
code = fixture_code('quantumespresso.pw')
structure = generate_structure()
structure = generate_structure('silicon')

# Test specifying no magnetic inputs
builder = PwBaseWorkChain.get_builder_from_protocol(code, structure)
assert 'starting_magnetization' not in builder.pw.parameters['SYSTEM']
assert 'nspin' not in builder.pw.parameters['SYSTEM']

with pytest.raises(NotImplementedError):
for spin_type in [SpinType.NON_COLLINEAR, SpinType.SPIN_ORBIT]:
Expand All @@ -67,7 +72,7 @@ def test_spin_type(fixture_code, generate_structure):
def test_initial_magnetic_moments_invalid(fixture_code, generate_structure, initial_magnetic_moments):
"""Test ``PwBaseWorkChain.get_builder_from_protocol`` with invalid ``initial_magnetic_moments`` keyword."""
code = fixture_code('quantumespresso.pw')
structure = generate_structure()
structure = generate_structure('silicon')

with pytest.raises(
ValueError, match=r'`initial_magnetic_moments` is specified but spin type `.*` is incompatible.'
Expand All @@ -83,22 +88,68 @@ def test_initial_magnetic_moments_invalid(fixture_code, generate_structure, init
def test_initial_magnetic_moments(fixture_code, generate_structure):
"""Test ``PwBaseWorkChain.get_builder_from_protocol`` with ``initial_magnetic_moments`` keyword."""
code = fixture_code('quantumespresso.pw')
structure = generate_structure()
structure = generate_structure('silicon')

initial_magnetic_moments = {'Si': 1.0}
builder = PwBaseWorkChain.get_builder_from_protocol(
code, structure, initial_magnetic_moments=initial_magnetic_moments, spin_type=SpinType.COLLINEAR
)
parameters = builder.pw.parameters.get_dict()

assert parameters['SYSTEM']['nspin'] == 2
assert parameters['SYSTEM']['starting_magnetization'] == {'Si': 0.25}


def test_magnetization_overrides(fixture_code, generate_structure):
"""Test magnetization ``overrides`` for the ``PwBaseWorkChain.get_builder_from_protocol`` method."""
code = fixture_code('quantumespresso.pw')
structure = generate_structure('silicon')
initial_magnetic_moments = {'Si': 1.0}
initial_starting_magnetization = {'Si': 0.5}
overrides = {'pw': {'parameters': {'SYSTEM': {'starting_magnetization': initial_starting_magnetization}}}}

# Test specifying `starting_magnetization` via the `overrides`
builder = PwBaseWorkChain.get_builder_from_protocol(
code, structure, overrides=overrides, spin_type=SpinType.COLLINEAR
)
assert builder.pw.parameters['SYSTEM']['starting_magnetization'] == initial_starting_magnetization
assert builder.pw.parameters['SYSTEM']['nspin'] == 2

# Test that specifying `initial_magnetic_moments` overrides the `overrides`
builder = PwBaseWorkChain.get_builder_from_protocol(
code,
structure,
overrides=overrides,
spin_type=SpinType.COLLINEAR,
initial_magnetic_moments=initial_magnetic_moments
)
assert builder.pw.parameters['SYSTEM']['starting_magnetization'] == {'Si': 0.25}
assert builder.pw.parameters['SYSTEM']['nspin'] == 2


def test_parameter_overrides(fixture_code, generate_structure):
"""Test specifying parameter ``overrides`` for the ``get_builder_from_protocol()`` method."""
code = fixture_code('quantumespresso.pw')
structure = generate_structure('silicon')

overrides = {'pw': {'parameters': {'SYSTEM': {'nbnd': 123}}}}
builder = PwBaseWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)
assert builder.pw.parameters['SYSTEM']['nbnd'] == 123


def test_settings_overrides(fixture_code, generate_structure):
"""Test specifying settings ``overrides`` for the ``get_builder_from_protocol()`` method."""
code = fixture_code('quantumespresso.pw')
structure = generate_structure('silicon')

overrides = {'pw': {'settings': {'cmdline': ['--kickass-mode']}}}
builder = PwBaseWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)
assert builder.pw.settings['cmdline'] == ['--kickass-mode']


def test_metadata_overrides(fixture_code, generate_structure):
"""Test that pw metadata is correctly passed through overrides."""
"""Test specifying metadata ``overrides`` for the ``get_builder_from_protocol()`` method."""
code = fixture_code('quantumespresso.pw')
structure = generate_structure()
structure = generate_structure('silicon')

overrides = {'pw': {'metadata': {'options': {'resources': {'num_machines': 1e90}, 'max_wallclock_seconds': 1}}}}
builder = PwBaseWorkChain.get_builder_from_protocol(
Expand All @@ -113,9 +164,9 @@ def test_metadata_overrides(fixture_code, generate_structure):


def test_parallelization_overrides(fixture_code, generate_structure):
"""Test that pw parallelization settings are correctly passed through overrides."""
"""Test specifying parallelization ``overrides`` for the ``get_builder_from_protocol()`` method."""
code = fixture_code('quantumespresso.pw')
structure = generate_structure()
structure = generate_structure('silicon')

overrides = {'pw': {'parallelization': {'npool': 4, 'ndiag': 12}}}
builder = PwBaseWorkChain.get_builder_from_protocol(
Expand Down