# New-side ('+') definitions of the change to phiml/math/_tensors.py, re-formatted as plain Python.
# Diff headers, removed lines and partial context of neighbouring functions are not reproduced here.


def attr_paths(obj: 'PhiTreeNodeType', attr_type: Callable, root: str) -> List[str]:
    """
    Lists one path string for every tensor found while walking `obj`.

    Paths are built from `root` by appending `[i]` for sequence items, `[key]` for dict entries
    and `.attr` for the attributes that `attr_type(obj)` returns for a `PhiTreeNode`.

    Args:
        obj: Tree to walk. `None` contributes no path.
        attr_type: Function returning the attribute names to visit on a `PhiTreeNode`.
        root: Path prefix of `obj` itself.

    Returns:
        Paths in traversal order.
    """
    if obj is None:
        return []
    elif isinstance(obj, Tensor):
        return [root]
    elif isinstance(obj, (tuple, list)):
        paths = []
        for i, item in enumerate(obj):
            paths.extend(attr_paths(item, attr_type, f'{root}[{i}]'))
        return paths
    elif isinstance(obj, dict):
        paths = []
        for name, item in obj.items():
            paths.extend(attr_paths(item, attr_type, f'{root}[{name}]'))
        return paths
    elif isinstance(obj, PhiTreeNode):
        paths = []
        for attr in attr_type(obj):
            paths.extend(attr_paths(getattr(obj, attr), attr_type, f'{root}.{attr}'))
        return paths
    else:  # native tensor?
        try:
            return [] if choose_backend(obj) == OBJECTS else [root]
        except NoBackendFound:
            return []


def attr_paths_from_container(obj: 'PhiTreeNodeType', attr_type: Callable, root: str) -> List[str]:
    """
    Lists the tensor paths of a tree whose tensors have already been replaced by placeholders.

    A `None` entry or the `NATIVE_TENSOR` marker yields a path, the `MISSING_TENSOR` marker yields none.
    Path syntax matches `attr_paths()`: `[i]` for sequence items, `[key]` for dict entries, `.attr` for attributes.

    Args:
        obj: Placeholder tree.
        attr_type: Function returning the attribute names to visit on a `PhiTreeNode`.
        root: Path prefix of `obj` itself.

    Returns:
        Paths in traversal order.

    Raises:
        RuntimeError: If an actual `Tensor` is still present in the tree.
    """
    if isinstance(obj, str) and obj == MISSING_TENSOR:
        return []
    if isinstance(obj, str) and obj == NATIVE_TENSOR:
        return [root]
    if obj is None:
        return [root]
    # Children are accumulated with extend(); sum(lists, []) would re-copy the result for every child.
    if isinstance(obj, (tuple, list)):
        paths = []
        for i, v in enumerate(obj):
            paths.extend(attr_paths_from_container(v, attr_type, f'{root}[{i}]'))
        return paths
    if isinstance(obj, dict):
        paths = []
        for k, v in obj.items():
            paths.extend(attr_paths_from_container(v, attr_type, f'{root}[{k}]'))
        return paths
    if isinstance(obj, Tensor):
        raise RuntimeError("Tensor found in container. This should have been set to None by disassemble_tree()")
    if isinstance(obj, PhiTreeNode):
        paths = []
        for k in attr_type(obj):
            paths.extend(attr_paths_from_container(getattr(obj, k), attr_type, f'{root}.{k}'))
        return paths
    return []


def save_tree(file: str, obj):
    """
    Writes `obj` to a NumPy `.npz` archive.

    The archive holds the entries `tree`, `specs` and `paths` plus one array per native of every tensor,
    stored under the key `<tensor path>:<native index>`.

    Args:
        file: Target passed on to `np.savez`.
        obj: Tree to store.

    Raises:
        ValueError: If two natives would be stored under the same key, which would silently drop one of them.
    """
    tree, tensors = disassemble_tree(obj, False, all_attributes)
    paths = attr_paths(obj, all_attributes, 'root')
    assert len(paths) == len(tensors), f"Found {len(paths)} tensor paths for {len(tensors)} tensors"
    natives = [t._natives() for t in tensors]
    specs = [serialize_spec(t._spec_dict()) for t in tensors]
    arrays = {}
    for path, tensor_natives in zip(paths, natives):
        for i, native in enumerate(tensor_natives):
            key = f'{path}:{i}'
            if key in arrays:  # e.g. dict keys 1 and '1' are formatted identically
                raise ValueError(f"Cannot save tree because the tensor path '{key}' is not unique")
            arrays[key] = choose_backend(native).numpy(native)
    np.savez(file, tree=tree, specs=specs, paths=paths, **arrays)


def load_tree(file: str):
    """
    Restores an object written by `save_tree()`.

    If the tensor paths derived from the unpickled tree differ from the stored `paths` entry,
    the tensors are re-ordered by path before the tree is assembled.

    Args:
        file: Source passed on to `np.load`.

    Returns:
        The re-assembled tree.

    Raises:
        KeyError: If the unpickled tree requires a tensor path that is not listed in the file.
    """
    # NOTE: allow_pickle=True unpickles arbitrary objects. Only load files from trusted sources.
    with np.load(file, allow_pickle=True) as data:  # closes the underlying archive once everything is read
        all_np = {k: data[k] for k in data if k not in ('tree', 'specs', 'paths')}
        specs = [unserialize_spec(spec) for spec in data['specs'].tolist()]
        tree = data['tree'].tolist()  # this may require outside classes via pickle
        stored_paths = data['paths'].tolist()
    tensors = assemble_tensors(list(all_np.values()), specs)
    new_paths = attr_paths_from_container(tree, all_attributes, 'root')
    if tuple(stored_paths) != tuple(new_paths):
        lookup = dict(zip(stored_paths, tensors))
        tensors = [lookup[p] for p in new_paths]
    return assemble_tree(tree, tensors, attr_type=all_attributes)


def _spec_type_names() -> dict:
    """Maps each tensor class that may appear under the `'type'` key of a spec to its serialized name."""
    from ._sparse import SparseCoordinateTensor, CompactSparseTensor, CompressedSparseMatrix
    return {NativeTensor: 'dense', TensorStack: 'stack', CompressedSparseMatrix: 'compressed', SparseCoordinateTensor: 'coo', CompactSparseTensor: 'compact'}


def serialize_spec(spec: dict):
    """
    Returns a copy of `spec` in which every `'type'` entry is replaced by its string name.

    Nested `dict` values are converted recursively, all other values are copied by reference.

    Args:
        spec: Spec dictionary, e.g. as returned by `Tensor._spec_dict()`.

    Returns:
        New `dict` with the same keys.

    Raises:
        KeyError: If a `'type'` entry is not one of the known tensor classes.
    """
    type_names = _spec_type_names()
    result = {}
    for k, v in spec.items():
        if k == 'type':
            result[k] = type_names[v]
        elif isinstance(v, dict):
            result[k] = serialize_spec(v)
        else:
            assert not isinstance(v, type), f"Unexpected class {v} under key '{k}'"
            result[k] = v
    return result


def unserialize_spec(spec: dict):
    """
    Inverse of `serialize_spec()`: replaces every `'type'` name by the corresponding tensor class.

    Nested `dict` values are converted recursively, all other values are copied by reference.

    Args:
        spec: Serialized spec dictionary.

    Returns:
        New `dict` with the same keys.

    Raises:
        KeyError: If a `'type'` entry is not a known type name.
    """
    lookup = {name: cls for cls, name in _spec_type_names().items()}
    result = {}
    for k, v in spec.items():
        if k == 'type':
            result[k] = lookup[v]
        elif isinstance(v, dict):
            result[k] = unserialize_spec(v)
        else:
            result[k] = v
    return result