diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py index 6f704b37a969..3f57a573b505 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py @@ -20,6 +20,7 @@ # pytype: skip-file import collections +import collections.abc import logging import sys import types @@ -45,7 +46,18 @@ frozenset: typing.FrozenSet, } +_BUILTINS = [ + dict, + list, + tuple, + set, + frozenset, +] + _CONVERTED_COLLECTIONS = [ + collections.abc.Iterable, + collections.abc.Iterator, + collections.abc.Generator, collections.abc.Set, collections.abc.MutableSet, collections.abc.Collection, @@ -99,6 +111,17 @@ def _match_issubclass(match_against): return lambda user_type: _safe_issubclass(user_type, match_against) +def _is_primitive(user_type, primitive): + # catch bare primitives + if user_type is primitive: + return True + return getattr(user_type, '__origin__', None) is primitive + + +def _match_is_primitive(match_against): + return lambda user_type: _is_primitive(user_type, match_against) + + def _match_is_exactly_mapping(user_type): # Avoid unintentionally catching all subtypes (e.g. strings and mappings). expected_origin = collections.abc.Mapping @@ -106,7 +129,7 @@ def _match_is_exactly_mapping(user_type): def _match_is_exactly_iterable(user_type): - if user_type is typing.Iterable: + if user_type is typing.Iterable or user_type is collections.abc.Iterable: return True # Avoid unintentionally catching all subtypes (e.g. strings and mappings). expected_origin = collections.abc.Iterable @@ -152,11 +175,13 @@ def _match_is_union(user_type): return False -def match_is_set(user_type): - if _safe_issubclass(user_type, typing.Set): +def _match_is_set(user_type): + if _safe_issubclass(user_type, typing.Set) or _is_primitive(user_type, set): return True elif getattr(user_type, '__origin__', None) is not None: - return _safe_issubclass(user_type.__origin__, collections.abc.Set) + return _safe_issubclass( + user_type.__origin__, collections.abc.Set) or _safe_issubclass( + user_type.__origin__, collections.abc.MutableSet) else: return False @@ -197,6 +222,36 @@ def convert_builtin_to_typing(typ): return typ +def convert_typing_to_builtin(typ): + """Converts a given typing collections type to its builtin counterpart. + + Args: + typ: A typing type (e.g., typing.List[int]). + + Returns: + type: The corresponding builtin type (e.g., list[int]). + """ + origin = getattr(typ, '__origin__', None) + args = getattr(typ, '__args__', None) + # Typing types return the primitive type as the origin from 3.9 on + if origin not in _BUILTINS: + return typ + # Early return for bare types + if not args: + return origin + if origin is list: + return list[convert_typing_to_builtin(args[0])] + elif origin is dict: + return dict[convert_typing_to_builtin(args[0]), + convert_typing_to_builtin(args[1])] + elif origin is tuple: + return tuple[tuple(convert_typing_to_builtin(args))] + elif origin is set: + return set[convert_typing_to_builtin(args)] + elif origin is frozenset: + return frozenset[convert_typing_to_builtin(args)] + + def convert_collections_to_typing(typ): """Converts a given collections.abc type to a typing object. @@ -216,6 +271,12 @@ def convert_collections_to_typing(typ): return typ +def is_builtin(typ): + if typ in _BUILTINS: + return True + return getattr(typ, '__origin__', None) in _BUILTINS + + def convert_to_beam_type(typ): """Convert a given typing type to a Beam type. @@ -238,11 +299,8 @@ def convert_to_beam_type(typ): sys.version_info.minor >= 10) and (isinstance(typ, types.UnionType)): typ = typing.Union[typ] - if isinstance(typ, types.GenericAlias): - typ = convert_builtin_to_typing(typ) - - if getattr(typ, '__module__', None) == 'collections.abc': - typ = convert_collections_to_typing(typ) + if getattr(typ, '__module__', None) == 'typing': + typ = convert_typing_to_builtin(typ) typ_module = getattr(typ, '__module__', None) if isinstance(typ, typing.TypeVar): @@ -267,8 +325,16 @@ def convert_to_beam_type(typ): # TODO(https://github.com/apache/beam/issues/20076): Currently unhandled. _LOGGER.info('Converting NewType type hint to Any: "%s"', typ) return typehints.Any - elif (typ_module != 'typing') and (typ_module != 'collections.abc'): - # Only translate types from the typing and collections.abc modules. + elif typ_module == 'apache_beam.typehints.native_type_compatibility' and \ + getattr(typ, "__name__", typ.__origin__.__name__) == 'TypedWindowedValue': + # Need to pass through WindowedValue class so that it can be converted + # to the correct type constraint in Beam + # This is needed to fix https://github.com/apache/beam/issues/33356 + pass + + elif (typ_module != 'typing') and (typ_module != + 'collections.abc') and not is_builtin(typ): + # Only translate primitives and types from collections.abc and typing. return typ if (typ_module == 'collections.abc' and typ.__origin__ not in _CONVERTED_COLLECTIONS): @@ -285,39 +351,34 @@ def convert_to_beam_type(typ): _TypeMapEntry(match=is_forward_ref, arity=0, beam_type=typehints.Any), _TypeMapEntry(match=is_any, arity=0, beam_type=typehints.Any), _TypeMapEntry( - match=_match_issubclass(typing.Dict), - arity=2, - beam_type=typehints.Dict), + match=_match_is_primitive(dict), arity=2, beam_type=typehints.Dict), _TypeMapEntry( match=_match_is_exactly_iterable, arity=1, beam_type=typehints.Iterable), _TypeMapEntry( - match=_match_issubclass(typing.List), - arity=1, - beam_type=typehints.List), + match=_match_is_primitive(list), arity=1, beam_type=typehints.List), # FrozenSets are a specific instance of a set, so we check this first. _TypeMapEntry( - match=_match_issubclass(typing.FrozenSet), + match=_match_is_primitive(frozenset), arity=1, beam_type=typehints.FrozenSet), - _TypeMapEntry(match=match_is_set, arity=1, beam_type=typehints.Set), + _TypeMapEntry(match=_match_is_set, arity=1, beam_type=typehints.Set), # NamedTuple is a subclass of Tuple, but it needs special handling. # We just convert it to Any for now. # This MUST appear before the entry for the normal Tuple. _TypeMapEntry( match=match_is_named_tuple, arity=0, beam_type=typehints.Any), _TypeMapEntry( - match=_match_issubclass(typing.Tuple), - arity=-1, + match=_match_is_primitive(tuple), arity=-1, beam_type=typehints.Tuple), _TypeMapEntry(match=_match_is_union, arity=-1, beam_type=typehints.Union), _TypeMapEntry( - match=_match_issubclass(typing.Generator), + match=_match_issubclass(collections.abc.Generator), arity=3, beam_type=typehints.Generator), _TypeMapEntry( - match=_match_issubclass(typing.Iterator), + match=_match_issubclass(collections.abc.Iterator), arity=1, beam_type=typehints.Iterator), _TypeMapEntry( diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py index ae8e1a0b2906..15b5da99fb0c 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py @@ -30,6 +30,7 @@ from apache_beam.typehints.native_type_compatibility import convert_to_beam_types from apache_beam.typehints.native_type_compatibility import convert_to_typing_type from apache_beam.typehints.native_type_compatibility import convert_to_typing_types +from apache_beam.typehints.native_type_compatibility import convert_typing_to_builtin from apache_beam.typehints.native_type_compatibility import is_any _TestNamedTuple = typing.NamedTuple( @@ -43,6 +44,7 @@ class _TestClass(object): T = typing.TypeVar('T') +U = typing.TypeVar('U') class _TestGeneric(typing.Generic[T]): @@ -140,7 +142,7 @@ def test_convert_to_beam_type_with_builtin_types(self): ( 'builtin nested tuple', tuple[str, list], - typehints.Tuple[str, typehints.List[typehints.Any]], + typehints.Tuple[str, typehints.List[typehints.TypeVariable('T')]], ) ] @@ -159,7 +161,7 @@ def test_convert_to_beam_type_with_collections_types(self): typehints.Iterable[int]), ( 'collection generator', - collections.abc.Generator[int], + collections.abc.Generator[int, None, None], typehints.Generator[int]), ( 'collection iterator', @@ -177,9 +179,8 @@ def test_convert_to_beam_type_with_collections_types(self): 'mapping not caught', collections.abc.Mapping[str, int], collections.abc.Mapping[str, int]), - ('set', collections.abc.Set[str], typehints.Set[str]), + ('set', collections.abc.Set[int], typehints.Set[int]), ('mutable set', collections.abc.MutableSet[int], typehints.Set[int]), - ('enum set', collections.abc.Set[_TestEnum], typehints.Set[_TestEnum]), ( 'enum mutable set', collections.abc.MutableSet[_TestEnum], @@ -337,6 +338,24 @@ def test_is_any(self): for expected, typ in test_cases: self.assertEqual(expected, is_any(typ), msg='%s' % typ) + def test_convert_typing_to_builtin(self): + test_cases = [ + ('list', typing.List[int], + list[int]), ('dict', typing.Dict[str, int], dict[str, int]), + ('tuple', typing.Tuple[str, int], tuple[str, int]), + ('set', typing.Set[str], set[str]), + ('frozenset', typing.FrozenSet[int], frozenset[int]), + ( + 'nested', + typing.List[typing.Dict[str, typing.Tuple[int]]], + list[dict[str, tuple[int]]]), ('typevar', typing.List[T], list[T]), + ('nested_typevar', typing.Dict[T, typing.List[U]], dict[T, list[U]]) + ] + + for description, typing_type, expected_builtin_type in test_cases: + builtin_type = convert_typing_to_builtin(typing_type) + self.assertEqual(builtin_type, expected_builtin_type, description) + if __name__ == '__main__': unittest.main()