From 2e597de279a2335e45cdee969c079846ac60916b Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Sat, 7 Dec 2024 13:09:26 +0100 Subject: [PATCH] Provide @sliceable instead of @dataclass to preserve IDE hints --- phiml/dataclasses/__init__.py | 62 +++++++++++++++---------------- phiml/dataclasses/_dataclasses.py | 40 ++++++++------------ 2 files changed, 46 insertions(+), 56 deletions(-) diff --git a/phiml/dataclasses/__init__.py b/phiml/dataclasses/__init__.py index d2ae87e..9bf2527 100644 --- a/phiml/dataclasses/__init__.py +++ b/phiml/dataclasses/__init__.py @@ -10,40 +10,40 @@ Dataclass fields can additionally be specified as being *variable* and *value*. This affects which attributes are optimized / traced by functions like `phiml.math.jit_compile` or `phiml.math.minimize`. -**Template for custom classes** -```python - -from phiml.dataclasses import dataclass, cached_property -from phiml.math import Tensor, Shape, shape - -@dataclass(frozen=True) -class MyClass: - # --- Attributes --- - attribute1: Tensor - attribute2: 'MyClass' = None - - # --- Additional fields --- - field1: str = 'x' - - # --- Special fields declaring attribute types --- - variable_attrs = ('attribute1', 'attribute2') - value_attrs = () - - def __post_init__(self): - assert self.field1 in 'xyz' - - @cached_property - def shape(self) -> Shape: # override the default shape which is merged from all attribute shapes - return self.attribute1.shape & shape(self.attribute2) - - @cached_property # the cache will be copied to derived instances unless attribute1 changes (this is analyzed from the code) - def derived_property(self) -> Tensor: - return self.attribute1 + 1 -``` +**Template for custom classes:** + +>>> from dataclasses import dataclass +>>> from phiml.dataclasses import sliceable, cached_property +>>> from phiml.math import Tensor, Shape, shape +>>> +>>> @sliceable +>>> @dataclass(frozen=True) +>>> class MyClass: +>>> # --- Attributes --- +>>> attribute1: Tensor +>>> attribute2: 'MyClass' 
= None +>>> +>>> # --- Additional fields --- +>>> field1: str = 'x' +>>> +>>> # --- Special fields declaring attribute types --- +>>> variable_attrs = ('attribute1', 'attribute2') +>>> value_attrs = () +>>> +>>> def __post_init__(self): +>>> assert self.field1 in 'xyz' +>>> +>>> @cached_property +>>> def shape(self) -> Shape: # override the default shape which is merged from all attribute shapes +>>> return self.attribute1.shape & shape(self.attribute2) +>>> +>>> @cached_property # the cache will be copied to derived instances unless attribute1 changes (this is analyzed from the code) +>>> def derived_property(self) -> Tensor: +>>> return self.attribute1 + 1 """ from functools import cached_property -from ._dataclasses import dataclass, attributes, replace, getitem +from ._dataclasses import sliceable, attributes, replace, getitem __all__ = [key for key in globals().keys() if not key.startswith('_')] diff --git a/phiml/dataclasses/_dataclasses.py b/phiml/dataclasses/_dataclasses.py index 277a4c2..14b0800 100644 --- a/phiml/dataclasses/_dataclasses.py +++ b/phiml/dataclasses/_dataclasses.py @@ -1,6 +1,7 @@ import collections import dataclasses import inspect +from dataclasses import dataclass from functools import cached_property from typing import TypeVar, Callable, Tuple, List, Set, Iterable, Optional, get_origin, get_args, Dict, Sequence @@ -12,35 +13,25 @@ PhiMLDataclass = TypeVar("PhiMLDataclass") -def dataclass(cls=None, /, *, getitem=True, dim_attrs=True, keepdims=None, dim_repr=True, - init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=True, match_args=True, kw_only=False, slots=False, weakref_slot=False): +def sliceable(cls=None, /, *, dim_attrs=True, keepdims=None, dim_repr=True): """ - Convenience decorator for PhiML dataclasses. - This builds a regular dataclass but adds additional options for slicing. 
- - If you don't require the additional features or want to implement them yourself, you may also use `@dataclass` from the `dataclasses` module. + Decorator for frozen dataclasses, adding slicing functionality by defining `__getitem__`. + This enables slicing similar to tensors, gathering and boolean masking. Args: - getitem: Whether to generate the `__getitem__` method for slice / gather / boolean_mask, depending on the argument. dim_attrs: Whether to generate `__getattr__` that allows slicing via the syntax `instance.dim[...]` where `dim` is the name of any dim present on `instance`. keepdims: Which dimensions should be kept with size 1 taking a single slice along them. This will preserve item names. dim_repr: Whether to replace the default `repr` of a dataclass by a simplified one based on the object's shape. - - All other args are passed on to `dataclasses.dataclass`. - Note however, that `frozen` must be `True` which is the default in this decorator. """ - - assert frozen, f"PhiML dataclasses must be frozen." - def wrap(cls): - overridden = [pair[0] for pair in inspect.getmembers(cls, inspect.isfunction)] - dataclasses.dataclass(cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen, match_args=match_args, kw_only=kw_only, slots=slots, weakref_slot=weakref_slot) + assert dataclasses.is_dataclass(cls), f"@sliceable must be used on a @dataclass, i.e. declared above it." + assert cls.__dataclass_params__.frozen, f"@sliceable dataclasses must be frozen. Declare as @dataclass(frozen=True)" assert attributes(cls), f"PhiML dataclasses must have at least one field storing a Shaped object, such as a Tensor, tree of Tensors or compatible dataclass." 
- if getitem and '__getitem__' not in overridden: + if not hasattr(cls, '__getitem__'): def __dataclass_getitem__(obj, item): - return getitem_impl(obj, item, keepdims=keepdims) + return getitem(obj, item, keepdims=keepdims) cls.__getitem__ = __dataclass_getitem__ - if dim_attrs and '__getattr__' not in overridden: + if dim_attrs and not hasattr(cls, '__getattr__'): def __dataclass_getattr__(obj, name: str): if name in ('shape', '__shape__', '__all_attrs__', '__variable_attrs__', '__value_attrs__'): # these can cause infinite recursion raise AttributeError(f"'{type(obj)}' instance has no attribute '{name}'") @@ -49,13 +40,15 @@ def __dataclass_getattr__(obj, name: str): else: raise AttributeError(f"'{type(obj)}' instance has no attribute '{name}'") cls.__getattr__ = __dataclass_getattr__ - if dim_repr and '__repr__' not in overridden: + if dim_repr: def __dataclass_repr__(obj): try: - obj_shape = shape(obj) + content = shape(obj) + if not content: + content = f"{', '.join([f'{f.name}={getattr(obj, f.name)}' for f in dataclasses.fields(cls)])}" except BaseException as err: - obj_shape = f"Unknown shape: {type(err).__name__}" - return f"{type(obj).__name__}[{obj_shape}]" + content = f"Unknown shape: {type(err).__name__}" + return f"{type(obj).__name__}[{content}]" cls.__repr__ = __dataclass_repr__ return cls @@ -169,9 +162,6 @@ def __getitem__(self, item): return new_obj -getitem_impl = getitem - - def stack(objs: Sequence, dim): raise NotImplementedError # stack cached properties