From dde633b16413dc78d99981d55d006b75e3bda88b Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Mon, 16 Dec 2024 21:47:45 +0100 Subject: [PATCH] Fix save/load for new dataclasses --- phiml/math/_tensors.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/phiml/math/_tensors.py b/phiml/math/_tensors.py index 04fa4d7..5afa059 100644 --- a/phiml/math/_tensors.py +++ b/phiml/math/_tensors.py @@ -2044,11 +2044,15 @@ def attr_paths_from_container(obj: PhiTreeNodeType, attr_type: Callable, root: s return sum([attr_paths_from_container(v, attr_type, f'{root}[{k}]') for k, v in obj.items()], []) elif isinstance(obj, Tensor): raise RuntimeError("Tensor found in container. This should have been set to None by disassemble_tree()") - elif isinstance(obj, PhiTreeNode): + elif dataclasses.is_dataclass(obj): + from ..dataclasses._dataclasses import DataclassTreeNode + if isinstance(obj, DataclassTreeNode): + assert attr_type == obj.attr_type + return sum([attr_paths_from_container(v, attr_type, f'{root}.{k}') for k, v in obj.extracted.items()], []) + if isinstance(obj, PhiTreeNode): attributes = attr_type(obj) return sum([attr_paths_from_container(getattr(obj, k), attr_type, f'{root}.{k}') for k in attributes], []) - else: - return [] + return [] def cached(t: TensorOrTree) -> TensorOrTree: