Skip to content

Commit

Permalink
Fix save/load for new dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Holl committed Dec 16, 2024
1 parent cd20034 commit dde633b
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions phiml/math/_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2044,11 +2044,15 @@ def attr_paths_from_container(obj: PhiTreeNodeType, attr_type: Callable, root: s
return sum([attr_paths_from_container(v, attr_type, f'{root}[{k}]') for k, v in obj.items()], [])
elif isinstance(obj, Tensor):
raise RuntimeError("Tensor found in container. This should have been set to None by disassemble_tree()")
elif isinstance(obj, PhiTreeNode):
elif dataclasses.is_dataclass(obj):
from ..dataclasses._dataclasses import DataclassTreeNode
if isinstance(obj, DataclassTreeNode):
assert attr_type == obj.attr_type
return sum([attr_paths_from_container(v, attr_type, f'{root}.{k}') for k, v in obj.extracted.items()], [])
if isinstance(obj, PhiTreeNode):
attributes = attr_type(obj)
return sum([attr_paths_from_container(getattr(obj, k), attr_type, f'{root}.{k}') for k in attributes], [])
else:
return []
return []


def cached(t: TensorOrTree) -> TensorOrTree:
Expand Down

0 comments on commit dde633b

Please sign in to comment.