diff --git a/map2loop/project.py b/map2loop/project.py index 58f6b5af..c7a1f2db 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -357,6 +357,21 @@ def set_sampler(self, datatype: Datatype, sampler: Sampler): sampler (Sampler): The sampler to use """ + allowed_samplers = { + Datatype.STRUCTURE: SamplerDecimator, + Datatype.GEOLOGY: SamplerSpacing, + Datatype.FAULT: SamplerSpacing, + Datatype.FOLD: SamplerSpacing, + Datatype.DTM: SamplerSpacing, + } + + # Check for wrong sampler + if datatype in allowed_samplers: + allowed_sampler_type = allowed_samplers[datatype] + if not isinstance(sampler, allowed_sampler_type): + raise ValueError( + f"Got wrong argument for this datatype: {type(sampler).__name__}, please use {allowed_sampler_type.__name__} instead" + ) ## does the enum print the number or the label? logger.info(f"Setting sampler for {datatype} to {sampler.sampler_label}") self.samplers[datatype] = sampler diff --git a/map2loop/sampler.py b/map2loop/sampler.py index 01600566..77950b13 100644 --- a/map2loop/sampler.py +++ b/map2loop/sampler.py @@ -9,7 +9,7 @@ import pandas import shapely import numpy -from typing import Optional +from typing import Optional, Union class Sampler(ABC):