diff --git a/phiml/math/_tensors.py b/phiml/math/_tensors.py index e4b55f6..c6c8934 100644 --- a/phiml/math/_tensors.py +++ b/phiml/math/_tensors.py @@ -3,9 +3,7 @@ import traceback import warnings from contextlib import contextmanager -import typing -from types import EllipsisType -from typing import Union, TypeVar, Sequence, Any +from typing import Union, TypeVar, Sequence, Any, Literal from dataclasses import dataclass from typing import Tuple, Callable, List @@ -2447,7 +2445,7 @@ def reshaped_native(value: Tensor, def process_groups_for(shape: Shape, groups: Any) -> List[Shape]: if callable(groups): return list(groups(shape)) - def process_group(g) -> Union[Shape, EllipsisType]: + def process_group(g): # returns Shape or Ellipsis if g is None or (isinstance(g, tuple) and len(g) == 0): return EMPTY_SHAPE if isinstance(g, SHAPE_TYPES):