Skip to content

Commit

Permalink
Provide @sliceable instead of @dataclass to preserve IDE hints
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Dec 7, 2024
1 parent 4b82a58 commit 2e597de
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 56 deletions.
62 changes: 31 additions & 31 deletions phiml/dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,40 +10,40 @@
Dataclass fields can additionally be specified as being *variable* and *value*.
This affects which attributes are optimized / traced by functions like `phiml.math.jit_compile` or `phiml.math.minimize`.
**Template for custom classes**
```python
from phiml.dataclasses import dataclass, cached_property
from phiml.math import Tensor, Shape, shape
@dataclass(frozen=True)
class MyClass:
# --- Attributes ---
attribute1: Tensor
attribute2: 'MyClass' = None
# --- Additional fields ---
field1: str = 'x'
# --- Special fields declaring attribute types ---
variable_attrs = ('attribute1', 'attribute2')
value_attrs = ()
def __post_init__(self):
assert self.field1 in 'xyz'
@cached_property
def shape(self) -> Shape: # override the default shape which is merged from all attribute shapes
return self.attribute1.shape & shape(self.attribute2)
@cached_property # the cache will be copied to derived instances unless attribute1 changes (this is analyzed from the code)
def derived_property(self) -> Tensor:
return self.attribute1 + 1
```
**Template for custom classes:**
>>> from dataclasses import dataclass
>>> from phiml.dataclasses import sliceable, cached_property
>>> from phiml.math import Tensor, Shape, shape
>>>
>>> @sliceable
>>> @dataclass(frozen=True)
>>> class MyClass:
>>> # --- Attributes ---
>>> attribute1: Tensor
>>> attribute2: 'MyClass' = None
>>>
>>> # --- Additional fields ---
>>> field1: str = 'x'
>>>
>>> # --- Special fields declaring attribute types ---
>>> variable_attrs = ('attribute1', 'attribute2')
>>> value_attrs = ()
>>>
>>> def __post_init__(self):
>>> assert self.field1 in 'xyz'
>>>
>>> @cached_property
>>> def shape(self) -> Shape: # override the default shape which is merged from all attribute shapes
>>> return self.attribute1.shape & shape(self.attribute2)
>>>
>>> @cached_property # the cache will be copied to derived instances unless attribute1 changes (this is analyzed from the code)
>>> def derived_property(self) -> Tensor:
>>> return self.attribute1 + 1
"""

from functools import cached_property

from ._dataclasses import dataclass, attributes, replace, getitem
from ._dataclasses import sliceable, attributes, replace, getitem

__all__ = [key for key in globals().keys() if not key.startswith('_')]
40 changes: 15 additions & 25 deletions phiml/dataclasses/_dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections
import dataclasses
import inspect
from dataclasses import dataclass
from functools import cached_property
from typing import TypeVar, Callable, Tuple, List, Set, Iterable, Optional, get_origin, get_args, Dict, Sequence

Expand All @@ -12,35 +13,25 @@
PhiMLDataclass = TypeVar("PhiMLDataclass")


def dataclass(cls=None, /, *, getitem=True, dim_attrs=True, keepdims=None, dim_repr=True,
init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=True, match_args=True, kw_only=False, slots=False, weakref_slot=False):
def sliceable(cls=None, /, *, dim_attrs=True, keepdims=None, dim_repr=True):
"""
Convenience decorator for PhiML dataclasses.
This builds a regular dataclass but adds additional options for slicing.
If you don't require the additional features or want to implement them yourself, you may also use `@dataclass` from the `dataclasses` module.
Decorator for frozen dataclasses, adding slicing functionality by defining `__getitem__`.
This enables slicing similar to tensors, gathering and boolean masking.
Args:
getitem: Whether to generate the `__getitem__` method for slice / gather / boolean_mask, depending on the argument.
dim_attrs: Whether to generate `__getattr__` that allows slicing via the syntax `instance.dim[...]` where `dim` is the name of any dim present on `instance`.
keepdims: Which dimensions should be kept with size 1 when taking a single slice along them. This will preserve item names.
dim_repr: Whether to replace the default `repr` of a dataclass by a simplified one based on the object's shape.
All other args are passed on to `dataclasses.dataclass`.
Note however, that `frozen` must be `True` which is the default in this decorator.
"""

assert frozen, f"PhiML dataclasses must be frozen."

def wrap(cls):
overridden = [pair[0] for pair in inspect.getmembers(cls, inspect.isfunction)]
dataclasses.dataclass(cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen, match_args=match_args, kw_only=kw_only, slots=slots, weakref_slot=weakref_slot)
assert dataclasses.is_dataclass(cls), f"@sliceable must be used on a @dataclass, i.e. declared above it."
assert cls.__dataclass_params__.frozen, f"@sliceable dataclasses must be frozen. Declare as @dataclass(frozen=True)"
assert attributes(cls), f"PhiML dataclasses must have at least one field storing a Shaped object, such as a Tensor, tree of Tensors or compatible dataclass."
if getitem and '__getitem__' not in overridden:
if not hasattr(cls, '__getitem__'):
def __dataclass_getitem__(obj, item):
return getitem_impl(obj, item, keepdims=keepdims)
return getitem(obj, item, keepdims=keepdims)
cls.__getitem__ = __dataclass_getitem__
if dim_attrs and '__getattr__' not in overridden:
if dim_attrs and not hasattr(cls, '__getattr__'):
def __dataclass_getattr__(obj, name: str):
if name in ('shape', '__shape__', '__all_attrs__', '__variable_attrs__', '__value_attrs__'): # these can cause infinite recursion
raise AttributeError(f"'{type(obj)}' instance has no attribute '{name}'")
Expand All @@ -49,13 +40,15 @@ def __dataclass_getattr__(obj, name: str):
else:
raise AttributeError(f"'{type(obj)}' instance has no attribute '{name}'")
cls.__getattr__ = __dataclass_getattr__
if dim_repr and '__repr__' not in overridden:
if dim_repr:
def __dataclass_repr__(obj):
try:
obj_shape = shape(obj)
content = shape(obj)
if not content:
content = f"{', '.join([f'{f.name}={getattr(obj, f.name)}' for f in dataclasses.fields(cls)])}"
except BaseException as err:
obj_shape = f"Unknown shape: {type(err).__name__}"
return f"{type(obj).__name__}[{obj_shape}]"
content = f"Unknown shape: {type(err).__name__}"
return f"{type(obj).__name__}[{content}]"
cls.__repr__ = __dataclass_repr__
return cls

Expand Down Expand Up @@ -169,9 +162,6 @@ def __getitem__(self, item):
return new_obj


getitem_impl = getitem


def stack(objs: Sequence, dim):
    """
    Placeholder that is not implemented yet; calling it always fails.

    Args:
        objs: Sequence of objects; currently unused.
        dim: Currently unused.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError # stack cached properties

Expand Down

0 comments on commit 2e597de

Please sign in to comment.