From 9503884d492662e9ab6bcd40ef7c166920bd6592 Mon Sep 17 00:00:00 2001 From: AngRodrigues Date: Thu, 28 Nov 2024 15:23:46 +1100 Subject: [PATCH] fix issue #155 (simple fix) --- map2loop/project.py | 17 ++++++++++++++++- map2loop/sampler.py | 4 ++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/map2loop/project.py b/map2loop/project.py index abd70ced..61213080 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -356,7 +356,22 @@ def set_sampler(self, datatype: Datatype, sampler: Sampler): sampler (Sampler): The sampler to use """ - ## does the enum print the number or the label? + allowed_samplers = { + Datatype.STRUCTURE: SamplerDecimator, + Datatype.GEOLOGY: SamplerSpacing, + Datatype.FAULT: SamplerSpacing, + Datatype.FOLD: SamplerSpacing, + Datatype.DTM: SamplerSpacing, + } + + # Check for wrong sampler + if datatype in allowed_samplers: + allowed_sampler_type = allowed_samplers[datatype] + if not isinstance(sampler, allowed_sampler_type): + raise ValueError( + f"Got wrong argument for this datatype: {type(sampler).__name__}, please use {allowed_sampler_type.__name__} instead" + ) + logger.info(f"Setting sampler for {datatype} to {sampler.sampler_label}") self.samplers[datatype] = sampler diff --git a/map2loop/sampler.py b/map2loop/sampler.py index 01600566..2273de29 100644 --- a/map2loop/sampler.py +++ b/map2loop/sampler.py @@ -9,7 +9,7 @@ import pandas import shapely import numpy -from typing import Optional +from typing import Optional, Union class Sampler(ABC): @@ -60,7 +60,7 @@ class SamplerDecimator(Sampler): """ @beartype.beartype - def __init__(self, decimation: int = 1): + def __init__(self, decimation: Union[int, float] = 1): """ Initialiser for decimator sampler