Skip to content

Commit

Permalink
Update typehinting code to use primitives and collections types as ma…
Browse files Browse the repository at this point in the history
…in types over typing variants (#33427)

* Refactor: Add convert_collections_from_typing()

Added  to convert typing module collections to built-ins. This function effectively reverses the operation of the  function.  Includes comprehensive unit tests to verify the correct conversion of various typing collections to their builtin counterparts, including nested structures and type variables.

* Flip paradigm for convert_to_beam_type to be primative and collections-centric

* update comment

* fix clobbered import from merge

* formatting

* fix imports

* address comments

* remove extra import artifacts from merge

---------

Co-authored-by: labs-code-app[bot] <161369871+labs-code-app[bot]@users.noreply.github.com>
  • Loading branch information
jrmccluskey and labs-code-app[bot] authored Jan 7, 2025
1 parent 1b78b67 commit 4a5575c
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 27 deletions.
107 changes: 84 additions & 23 deletions sdks/python/apache_beam/typehints/native_type_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# pytype: skip-file

import collections
import collections.abc
import logging
import sys
import types
Expand All @@ -45,7 +46,18 @@
frozenset: typing.FrozenSet,
}

_BUILTINS = [
dict,
list,
tuple,
set,
frozenset,
]

_CONVERTED_COLLECTIONS = [
collections.abc.Iterable,
collections.abc.Iterator,
collections.abc.Generator,
collections.abc.Set,
collections.abc.MutableSet,
collections.abc.Collection,
Expand Down Expand Up @@ -99,14 +111,25 @@ def _match_issubclass(match_against):
return lambda user_type: _safe_issubclass(user_type, match_against)


def _is_primitive(user_type, primitive):
# catch bare primitives
if user_type is primitive:
return True
return getattr(user_type, '__origin__', None) is primitive


def _match_is_primitive(match_against):
return lambda user_type: _is_primitive(user_type, match_against)


def _match_is_exactly_mapping(user_type):
# Avoid unintentionally catching all subtypes (e.g. strings and mappings).
expected_origin = collections.abc.Mapping
return getattr(user_type, '__origin__', None) is expected_origin


def _match_is_exactly_iterable(user_type):
if user_type is typing.Iterable:
if user_type is typing.Iterable or user_type is collections.abc.Iterable:
return True
# Avoid unintentionally catching all subtypes (e.g. strings and mappings).
expected_origin = collections.abc.Iterable
Expand Down Expand Up @@ -152,11 +175,13 @@ def _match_is_union(user_type):
return False


def match_is_set(user_type):
if _safe_issubclass(user_type, typing.Set):
def _match_is_set(user_type):
if _safe_issubclass(user_type, typing.Set) or _is_primitive(user_type, set):
return True
elif getattr(user_type, '__origin__', None) is not None:
return _safe_issubclass(user_type.__origin__, collections.abc.Set)
return _safe_issubclass(
user_type.__origin__, collections.abc.Set) or _safe_issubclass(
user_type.__origin__, collections.abc.MutableSet)
else:
return False

Expand Down Expand Up @@ -197,6 +222,36 @@ def convert_builtin_to_typing(typ):
return typ


def convert_typing_to_builtin(typ):
"""Converts a given typing collections type to its builtin counterpart.
Args:
typ: A typing type (e.g., typing.List[int]).
Returns:
type: The corresponding builtin type (e.g., list[int]).
"""
origin = getattr(typ, '__origin__', None)
args = getattr(typ, '__args__', None)
# Typing types return the primitive type as the origin from 3.9 on
if origin not in _BUILTINS:
return typ
# Early return for bare types
if not args:
return origin
if origin is list:
return list[convert_typing_to_builtin(args[0])]
elif origin is dict:
return dict[convert_typing_to_builtin(args[0]),
convert_typing_to_builtin(args[1])]
elif origin is tuple:
return tuple[tuple(convert_typing_to_builtin(args))]
elif origin is set:
return set[convert_typing_to_builtin(args)]
elif origin is frozenset:
return frozenset[convert_typing_to_builtin(args)]


def convert_collections_to_typing(typ):
"""Converts a given collections.abc type to a typing object.
Expand All @@ -216,6 +271,12 @@ def convert_collections_to_typing(typ):
return typ


def is_builtin(typ):
if typ in _BUILTINS:
return True
return getattr(typ, '__origin__', None) in _BUILTINS


def convert_to_beam_type(typ):
"""Convert a given typing type to a Beam type.
Expand All @@ -238,11 +299,8 @@ def convert_to_beam_type(typ):
sys.version_info.minor >= 10) and (isinstance(typ, types.UnionType)):
typ = typing.Union[typ]

if isinstance(typ, types.GenericAlias):
typ = convert_builtin_to_typing(typ)

if getattr(typ, '__module__', None) == 'collections.abc':
typ = convert_collections_to_typing(typ)
if getattr(typ, '__module__', None) == 'typing':
typ = convert_typing_to_builtin(typ)

typ_module = getattr(typ, '__module__', None)
if isinstance(typ, typing.TypeVar):
Expand All @@ -267,8 +325,16 @@ def convert_to_beam_type(typ):
# TODO(https://github.com/apache/beam/issues/20076): Currently unhandled.
_LOGGER.info('Converting NewType type hint to Any: "%s"', typ)
return typehints.Any
elif (typ_module != 'typing') and (typ_module != 'collections.abc'):
# Only translate types from the typing and collections.abc modules.
elif typ_module == 'apache_beam.typehints.native_type_compatibility' and \
getattr(typ, "__name__", typ.__origin__.__name__) == 'TypedWindowedValue':
# Need to pass through WindowedValue class so that it can be converted
# to the correct type constraint in Beam
# This is needed to fix https://github.com/apache/beam/issues/33356
pass

elif (typ_module != 'typing') and (typ_module !=
'collections.abc') and not is_builtin(typ):
# Only translate primitives and types from collections.abc and typing.
return typ
if (typ_module == 'collections.abc' and
typ.__origin__ not in _CONVERTED_COLLECTIONS):
Expand All @@ -285,39 +351,34 @@ def convert_to_beam_type(typ):
_TypeMapEntry(match=is_forward_ref, arity=0, beam_type=typehints.Any),
_TypeMapEntry(match=is_any, arity=0, beam_type=typehints.Any),
_TypeMapEntry(
match=_match_issubclass(typing.Dict),
arity=2,
beam_type=typehints.Dict),
match=_match_is_primitive(dict), arity=2, beam_type=typehints.Dict),
_TypeMapEntry(
match=_match_is_exactly_iterable,
arity=1,
beam_type=typehints.Iterable),
_TypeMapEntry(
match=_match_issubclass(typing.List),
arity=1,
beam_type=typehints.List),
match=_match_is_primitive(list), arity=1, beam_type=typehints.List),
# FrozenSets are a specific instance of a set, so we check this first.
_TypeMapEntry(
match=_match_issubclass(typing.FrozenSet),
match=_match_is_primitive(frozenset),
arity=1,
beam_type=typehints.FrozenSet),
_TypeMapEntry(match=match_is_set, arity=1, beam_type=typehints.Set),
_TypeMapEntry(match=_match_is_set, arity=1, beam_type=typehints.Set),
# NamedTuple is a subclass of Tuple, but it needs special handling.
# We just convert it to Any for now.
# This MUST appear before the entry for the normal Tuple.
_TypeMapEntry(
match=match_is_named_tuple, arity=0, beam_type=typehints.Any),
_TypeMapEntry(
match=_match_issubclass(typing.Tuple),
arity=-1,
match=_match_is_primitive(tuple), arity=-1,
beam_type=typehints.Tuple),
_TypeMapEntry(match=_match_is_union, arity=-1, beam_type=typehints.Union),
_TypeMapEntry(
match=_match_issubclass(typing.Generator),
match=_match_issubclass(collections.abc.Generator),
arity=3,
beam_type=typehints.Generator),
_TypeMapEntry(
match=_match_issubclass(typing.Iterator),
match=_match_issubclass(collections.abc.Iterator),
arity=1,
beam_type=typehints.Iterator),
_TypeMapEntry(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from apache_beam.typehints.native_type_compatibility import convert_to_beam_types
from apache_beam.typehints.native_type_compatibility import convert_to_typing_type
from apache_beam.typehints.native_type_compatibility import convert_to_typing_types
from apache_beam.typehints.native_type_compatibility import convert_typing_to_builtin
from apache_beam.typehints.native_type_compatibility import is_any

_TestNamedTuple = typing.NamedTuple(
Expand All @@ -43,6 +44,7 @@ class _TestClass(object):


T = typing.TypeVar('T')
U = typing.TypeVar('U')


class _TestGeneric(typing.Generic[T]):
Expand Down Expand Up @@ -140,7 +142,7 @@ def test_convert_to_beam_type_with_builtin_types(self):
(
'builtin nested tuple',
tuple[str, list],
typehints.Tuple[str, typehints.List[typehints.Any]],
typehints.Tuple[str, typehints.List[typehints.TypeVariable('T')]],
)
]

Expand All @@ -159,7 +161,7 @@ def test_convert_to_beam_type_with_collections_types(self):
typehints.Iterable[int]),
(
'collection generator',
collections.abc.Generator[int],
collections.abc.Generator[int, None, None],
typehints.Generator[int]),
(
'collection iterator',
Expand All @@ -177,9 +179,8 @@ def test_convert_to_beam_type_with_collections_types(self):
'mapping not caught',
collections.abc.Mapping[str, int],
collections.abc.Mapping[str, int]),
('set', collections.abc.Set[str], typehints.Set[str]),
('set', collections.abc.Set[int], typehints.Set[int]),
('mutable set', collections.abc.MutableSet[int], typehints.Set[int]),
('enum set', collections.abc.Set[_TestEnum], typehints.Set[_TestEnum]),
(
'enum mutable set',
collections.abc.MutableSet[_TestEnum],
Expand Down Expand Up @@ -337,6 +338,24 @@ def test_is_any(self):
for expected, typ in test_cases:
self.assertEqual(expected, is_any(typ), msg='%s' % typ)

def test_convert_typing_to_builtin(self):
test_cases = [
('list', typing.List[int],
list[int]), ('dict', typing.Dict[str, int], dict[str, int]),
('tuple', typing.Tuple[str, int], tuple[str, int]),
('set', typing.Set[str], set[str]),
('frozenset', typing.FrozenSet[int], frozenset[int]),
(
'nested',
typing.List[typing.Dict[str, typing.Tuple[int]]],
list[dict[str, tuple[int]]]), ('typevar', typing.List[T], list[T]),
('nested_typevar', typing.Dict[T, typing.List[U]], dict[T, list[U]])
]

for description, typing_type, expected_builtin_type in test_cases:
builtin_type = convert_typing_to_builtin(typing_type)
self.assertEqual(builtin_type, expected_builtin_type, description)


if __name__ == '__main__':
unittest.main()

0 comments on commit 4a5575c

Please sign in to comment.