From 38f99e0cd11170431fdaad144935a5959bcd47d4 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Wed, 9 Oct 2024 03:54:43 +0200 Subject: [PATCH] fix(enums): ensure __eq__ gives a numpy array (#1267) --- openfisca_core/indexed_enums/enum_array.py | 6 ++++-- openfisca_core/indexed_enums/types.py | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/openfisca_core/indexed_enums/enum_array.py b/openfisca_core/indexed_enums/enum_array.py index 06fc1fbc9..2e9ebf148 100644 --- a/openfisca_core/indexed_enums/enum_array.py +++ b/openfisca_core/indexed_enums/enum_array.py @@ -141,8 +141,10 @@ def __eq__(self, other: object) -> bool: """ if other.__class__.__name__ is self.possible_values.__name__: return self.view(numpy.ndarray) == other.index - - return self.view(numpy.ndarray) == other + is_eq = self.view(numpy.ndarray) == other + if isinstance(is_eq, numpy.ndarray): + return is_eq + return numpy.array([is_eq], dtype=t.BoolDType) def __ne__(self, other: object) -> bool: """Inequality. diff --git a/openfisca_core/indexed_enums/types.py b/openfisca_core/indexed_enums/types.py index a16b03750..ffc2cc9f2 100644 --- a/openfisca_core/indexed_enums/types.py +++ b/openfisca_core/indexed_enums/types.py @@ -4,6 +4,7 @@ from openfisca_core.types import ( Array, ArrayLike, + DTypeBool as BoolDType, DTypeEnum as EnumDType, DTypeGeneric as AnyDType, DTypeInt as IntDType, @@ -49,6 +50,7 @@ __all__ = [ "Array", "ArrayLike", + "BoolDType", "DTypeLike", "Enum", "EnumArray",