From c17c5b79ec2ad0d83bdf9356b67c2663186215ad Mon Sep 17 00:00:00 2001 From: yukinarit Date: Mon, 3 Jun 2024 08:49:22 +0900 Subject: [PATCH] Cache typing functions to speed up @serde codegen --- serde/compat.py | 78 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 66 insertions(+), 12 deletions(-) diff --git a/serde/compat.py b/serde/compat.py index b841d6dd..b94c9df2 100644 --- a/serde/compat.py +++ b/serde/compat.py @@ -17,10 +17,10 @@ from collections import defaultdict from collections.abc import Iterator from dataclasses import is_dataclass -from typing import TypeVar, Generic, Any, ClassVar, Optional, NewType, Union +from typing import TypeVar, Generic, Any, ClassVar, Optional, NewType, Union, Hashable, Callable import typing_inspect -from typing_extensions import TypeGuard +from typing_extensions import TypeGuard, ParamSpec from .sqlalchemy import is_sqlalchemy_inspectable @@ -29,7 +29,7 @@ def get_np_origin(tp: type[Any]) -> Optional[Any]: return None -def get_np_args(tp: Any) -> tuple[Any, ...]: +def get_np_args(tp: type[Any]) -> tuple[Any, ...]: return () @@ -93,6 +93,32 @@ class SerdeSkip(Exception): """ +def is_hashable(typ: Any) -> TypeGuard[Hashable]: + """ + Test is an object hashable + """ + try: + hash(typ) + except TypeError: + return False + return True + + +P = ParamSpec("P") + + +def cache(f: Callable[P, T]) -> Callable[P, T]: + """ + Wrapper for `functools.cache` to avoid `Hashable` related type errors. + """ + + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + return f(*args, **kwargs) + + return functools.cache(wrapper) # type: ignore + + +@cache def get_origin(typ: Any) -> Optional[Any]: """ Provide `get_origin` that works in all python versions. @@ -100,13 +126,15 @@ def get_origin(typ: Any) -> Optional[Any]: return typing.get_origin(typ) or get_np_origin(typ) -def get_args(typ: Any) -> tuple[Any, ...]: +@cache +def get_args(typ: type[Any]) -> tuple[Any, ...]: """ Provide `get_args` that works in all python versions. """ return typing.get_args(typ) or get_np_args(typ) +@cache def typename(typ: Any, with_typing_module: bool = False) -> str: """ >>> from typing import Any @@ -268,16 +296,16 @@ def dataclass_fields(cls: type[Any]) -> Iterator[dataclasses.Field]: # type: ig TypeLike = Union[type[Any], typing.Any] -def iter_types(cls: TypeLike) -> list[TypeLike]: +def iter_types(cls: type[Any]) -> list[type[Any]]: """ Iterate field types recursively. The correct return type is `Iterator[Union[Type, typing._specialform]], but `typing._specialform` doesn't exist for python 3.6. Use `Any` instead. """ - lst: set[TypeLike] = set() + lst: set[type[Any]] = set() - def recursive(cls: TypeLike) -> None: + def recursive(cls: type[Any]) -> None: if cls in lst: return @@ -288,12 +316,12 @@ def recursive(cls: TypeLike) -> None: elif isinstance(cls, str): lst.add(cls) elif is_opt(cls): - lst.add(Optional) + lst.add(Optional) # type: ignore args = type_args(cls) if args: recursive(args[0]) elif is_union(cls): - lst.add(Union) + lst.add(Union) # type: ignore for arg in type_args(cls): recursive(arg) elif is_list(cls): @@ -366,13 +394,13 @@ def recursive(cls: TypeLike) -> None: return list(lst) -def iter_literals(cls: TypeLike) -> list[TypeLike]: +def iter_literals(cls: type[Any]) -> list[TypeLike]: """ Iterate over all literals that are used in the dataclass """ - lst: set[TypeLike] = set() + lst: set[type[Any]] = set() - def recursive(cls: TypeLike) -> None: + def recursive(cls: type[Any]) -> None: if cls in lst: return @@ -406,6 +434,7 @@ def recursive(cls: TypeLike) -> None: return list(lst) +@cache def is_union(typ: Any) -> bool: """ Test if the type is `typing.Union`. @@ -433,6 +462,7 @@ def is_union(typ: Any) -> bool: return typing_inspect.is_union_type(typ) # type: ignore +@cache def is_opt(typ: Any) -> bool: """ Test if the type is `typing.Optional`. @@ -469,6 +499,7 @@ def is_opt(typ: Any) -> bool: return typ is Optional +@cache def is_bare_opt(typ: Any) -> bool: """ Test if the type is `typing.Optional` without type args. @@ -482,6 +513,7 @@ def is_bare_opt(typ: Any) -> bool: return not type_args(typ) and typ is Optional +@cache def is_list(typ: type[Any]) -> bool: """ Test if the type is `list`. @@ -497,6 +529,7 @@ def is_list(typ: type[Any]) -> bool: return typ is list +@cache def is_bare_list(typ: type[Any]) -> bool: """ Test if the type is `list` without type args. @@ -509,6 +542,7 @@ def is_bare_list(typ: type[Any]) -> bool: return typ is list +@cache def is_tuple(typ: Any) -> bool: """ Test if the type is tuple. @@ -519,6 +553,7 @@ def is_tuple(typ: Any) -> bool: return typ is tuple +@cache def is_bare_tuple(typ: type[Any]) -> bool: """ Test if the type is tuple without type args. @@ -531,6 +566,7 @@ def is_bare_tuple(typ: type[Any]) -> bool: return typ is tuple +@cache def is_variable_tuple(typ: type[Any]) -> bool: """ Test if the type is a variable length of tuple tuple[T, ...]`. @@ -547,6 +583,7 @@ def is_variable_tuple(typ: type[Any]) -> bool: return istuple and len(args) == 2 and is_ellipsis(args[1]) +@cache def is_set(typ: type[Any]) -> bool: """ Test if the type is `set` or `frozenset`. @@ -564,6 +601,7 @@ def is_set(typ: type[Any]) -> bool: return typ in (set, frozenset) +@cache def is_bare_set(typ: type[Any]) -> bool: """ Test if the type is `set` without type args. @@ -576,6 +614,7 @@ def is_bare_set(typ: type[Any]) -> bool: return typ in (set, frozenset) +@cache def is_frozen_set(typ: type[Any]) -> bool: """ Test if the type is `frozenset`. @@ -591,6 +630,7 @@ def is_frozen_set(typ: type[Any]) -> bool: return typ is frozenset +@cache def is_dict(typ: type[Any]) -> bool: """ Test if the type is dict. @@ -608,6 +648,7 @@ def is_dict(typ: type[Any]) -> bool: return typ in (dict, defaultdict) +@cache def is_bare_dict(typ: type[Any]) -> bool: """ Test if the type is `dict` without type args. @@ -620,6 +661,7 @@ def is_bare_dict(typ: type[Any]) -> bool: return typ is dict +@cache def is_default_dict(typ: type[Any]) -> bool: """ Test if the type is `defaultdict`. @@ -635,6 +677,7 @@ def is_default_dict(typ: type[Any]) -> bool: return typ is defaultdict +@cache def is_none(typ: type[Any]) -> bool: """ >>> is_none(int) @@ -650,6 +693,7 @@ def is_none(typ: type[Any]) -> bool: PRIMITIVES = [int, float, bool, str] +@cache def is_enum(typ: type[Any]) -> TypeGuard[enum.Enum]: """ Test if the type is `enum.Enum`. @@ -660,6 +704,7 @@ def is_enum(typ: type[Any]) -> TypeGuard[enum.Enum]: return isinstance(typ, enum.Enum) +@cache def is_primitive_subclass(typ: type[Any]) -> bool: """ Test if the type is a subclass of primitive type. @@ -674,6 +719,7 @@ def is_primitive_subclass(typ: type[Any]) -> bool: return is_primitive(typ) and typ not in PRIMITIVES and not is_new_type_primitive(typ) +@cache def is_primitive(typ: Union[type[Any], NewType]) -> bool: """ Test if the type is primitive. @@ -691,6 +737,7 @@ def is_primitive(typ: Union[type[Any], NewType]) -> bool: return is_new_type_primitive(typ) +@cache def is_new_type_primitive(typ: Union[type[Any], NewType]) -> bool: """ Test if the type is a NewType of primitives. @@ -702,10 +749,12 @@ def is_new_type_primitive(typ: Union[type[Any], NewType]) -> bool: return any(isinstance(typ, ty) for ty in PRIMITIVES) +@cache def has_generic_base(typ: Any) -> bool: return Generic in getattr(typ, "__mro__", ()) or Generic in getattr(typ, "__bases__", ()) +@cache def is_generic(typ: Any) -> bool: """ Test if the type is derived from `typing.Generic`. @@ -722,6 +771,7 @@ def is_generic(typ: Any) -> bool: return origin is not None and has_generic_base(origin) +@cache def is_class_var(typ: type[Any]) -> bool: """ Test if the type is `typing.ClassVar`. @@ -734,6 +784,7 @@ def is_class_var(typ: type[Any]) -> bool: return get_origin(typ) is ClassVar or typ is ClassVar # type: ignore +@cache def is_literal(typ: type[Any]) -> bool: """ Test if the type is derived from `typing.Literal`. @@ -750,6 +801,7 @@ def is_literal(typ: type[Any]) -> bool: return origin is not None and origin is typing.Literal +@cache def is_any(typ: type[Any]) -> bool: """ Test if the type is `typing.Any`. @@ -757,6 +809,7 @@ def is_any(typ: type[Any]) -> bool: return typ is Any +@cache def is_str_serializable(typ: type[Any]) -> bool: """ Test if the type is serializable to `str`. @@ -787,6 +840,7 @@ def is_ellipsis(typ: Any) -> bool: return typ is Ellipsis +@cache def get_type_var_names(cls: type[Any]) -> Optional[list[str]]: """ Get type argument names of a generic class.