save/load now store paths and allow for property order change
holl- committed Oct 20, 2024
1 parent 2725e4e commit da6eb3a
Showing 1 changed file with 100 additions and 9 deletions.
109 changes: 100 additions & 9 deletions phiml/math/_tensors.py
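The gist of the change: every native array is written to the .npz under the attribute path it came from (plus a `paths` index), so tensors can be re-matched by name on load even if the saving and the loading class list their properties in a different order. A rough, self-contained sketch of that storage layout with plain NumPy; the paths and shapes here are made up and this is not the phiml API itself:

import numpy as np

# One named .npz entry per native array, keyed by a hypothetical attribute path.
arrays = {'root.velocity:0': np.zeros((4, 3)), 'root.pressure:0': np.zeros(4)}
paths = ['root.velocity', 'root.pressure']  # property order at save time

np.savez('state.npz', paths=paths, **arrays)

data = np.load('state.npz', allow_pickle=True)
print(data['paths'].tolist())                                   # ['root.velocity', 'root.pressure']
print({k: data[k].shape for k in data.files if k != 'paths'})   # arrays looked up by name, not position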
@@ -1951,6 +1951,57 @@ def assemble_tree(obj: PhiTreeNodeType, values: List[Tensor], attr_type=variable
        return obj


+def attr_paths(obj: PhiTreeNodeType, attr_type: Callable, root: str) -> List[str]:
+    if obj is None:
+        return []
+    elif isinstance(obj, Tensor):
+        return [root]
+    elif isinstance(obj, (tuple, list)):
+        paths = []
+        for i, item in enumerate(obj):
+            path = attr_paths(item, attr_type, f'{root}[{i}]')
+            paths.extend(path)
+        return paths
+    elif isinstance(obj, dict):
+        paths = []
+        for name, item in obj.items():
+            path = attr_paths(item, attr_type, f'{root}[{name}]')
+            paths.extend(path)
+        return paths
+    elif isinstance(obj, PhiTreeNode):
+        attributes = attr_type(obj)
+        paths = []
+        for attr in attributes:
+            path = attr_paths(getattr(obj, attr), attr_type, f'{root}.{attr}')
+            paths.extend(path)
+        return paths
+    else:  # native tensor?
+        try:
+            return [] if choose_backend(obj) == OBJECTS else [root]
+        except NoBackendFound:
+            return []
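For orientation, here is a stripped-down stand-in of the path scheme used above, handling only lists, dicts and NumPy arrays instead of Tensor/PhiTreeNode; the example object is invented:

import numpy as np

def demo_paths(obj, root='root'):
    # Simplified illustration of the path naming; not the phiml implementation.
    if isinstance(obj, np.ndarray):
        return [root]
    if isinstance(obj, (tuple, list)):
        return [p for i, v in enumerate(obj) for p in demo_paths(v, f'{root}[{i}]')]
    if isinstance(obj, dict):
        return [p for k, v in obj.items() for p in demo_paths(v, f'{root}[{k}]')]
    return []

print(demo_paths({'velocity': np.zeros(3), 'markers': [np.ones(2), np.ones(2)]}))
# ['root[velocity]', 'root[markers][0]', 'root[markers][1]']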


+def attr_paths_from_container(obj: PhiTreeNodeType, attr_type: Callable, root: str) -> List[str]:
+    if isinstance(obj, str) and obj == MISSING_TENSOR:
+        return []
+    elif isinstance(obj, str) and obj == NATIVE_TENSOR:
+        return [root]
+    elif obj is None:
+        return [root]
+    elif isinstance(obj, (tuple, list)):
+        return sum([attr_paths_from_container(v, attr_type, f'{root}[{i}]') for i, v in enumerate(obj)], [])
+    elif isinstance(obj, dict):
+        return sum([attr_paths_from_container(v, attr_type, f'{root}[{k}]') for k, v in obj.items()], [])
+    elif isinstance(obj, Tensor):
+        raise RuntimeError("Tensor found in container. This should have been set to None by disassemble_tree()")
+    elif isinstance(obj, PhiTreeNode):
+        attributes = attr_type(obj)
+        return sum([attr_paths_from_container(getattr(obj, k), attr_type, f'{root}.{k}') for k in attributes], [])
+    else:
+        return []


def cached(t: TensorOrTree) -> TensorOrTree:
    from ._sparse import SparseCoordinateTensor, CompressedSparseMatrix, CompactSparseTensor
    assert isinstance(t, (Tensor, PhiTreeNode)), f"All arguments must be Tensors but got {type(t)}"
@@ -2802,17 +2853,57 @@ def specs_equal(spec1, spec2):


def save_tree(file: str, obj):
-    tree, tensors = disassemble_tree(obj, cache=False, attr_type=all_attributes)
-    natives, _, specs = disassemble_tensors(tensors, expand=False)
-    np_natives = [choose_backend(n).numpy(n) for n in natives]
-    np.savez(file, *np_natives, tree=tree, specs=specs)
+    tree, tensors = disassemble_tree(obj, False, all_attributes)
+    paths = attr_paths(obj, all_attributes, 'root')
+    assert len(paths) == len(tensors)
+    natives = [t._natives() for t in tensors]
+    specs = [serialize_spec(t._spec_dict()) for t in tensors]
+    native_paths = [[f'{p}:{i}' for i in range(len(ns))] for p, ns in zip(paths, natives)]
+    all_natives = sum(natives, ())
+    all_paths = sum(native_paths, [])
+    all_np = [choose_backend(n).numpy(n) for n in all_natives]
+    np.savez(file, tree=tree, specs=specs, paths=paths, **{p: n for p, n in zip(all_paths, all_np)})


def load_tree(file: str):
    data = np.load(file, allow_pickle=True)
-    np_natives = [data[k] for k in data if k not in ['tree', 'specs']]
-    specs = data['specs'].tolist()
-    tensors = assemble_tensors(np_natives, specs)
+    all_np = {k: data[k] for k in data if k not in ['tree', 'specs', 'paths']}
+    specs = [unserialize_spec(spec) for spec in data['specs'].tolist()]
+    tensors = assemble_tensors(list(all_np.values()), specs)
    tree = data['tree'].tolist()  # this may require outside classes via pickle
-    obj = assemble_tree(tree, tensors, attr_type=all_attributes)
-    return obj
+    stored_paths = data['paths'].tolist()
+    new_paths = attr_paths_from_container(tree, all_attributes, 'root')
+    if tuple(stored_paths) != tuple(new_paths):
+        lookup = {path: t for path, t in zip(stored_paths, tensors)}
+        tensors = [lookup[p] for p in new_paths]
+    return assemble_tree(tree, tensors, attr_type=all_attributes)
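The order-tolerant part of load_tree is the path-based re-pairing at the end. In isolation, with made-up paths and plain values standing in for the assembled tensors, the idea is:

stored_paths = ['root.age', 'root.size']   # property order when the file was written (hypothetical)
stored_values = [42, 7]                    # stands in for the assembled tensors
new_paths = ['root.size', 'root.age']      # order the current class definition expects
lookup = dict(zip(stored_paths, stored_values))
print([lookup[p] for p in new_paths])      # [7, 42]: values follow their paths, not their positions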


+def serialize_spec(spec: dict):
+    from ._sparse import SparseCoordinateTensor, CompactSparseTensor, CompressedSparseMatrix
+    type_names = {NativeTensor: 'dense', TensorStack: 'stack', CompressedSparseMatrix: 'compressed', SparseCoordinateTensor: 'coo', CompactSparseTensor: 'compact'}
+    result = {}
+    for k, v in spec.items():
+        if k == 'type':
+            result[k] = type_names[v]
+        elif isinstance(v, dict):
+            result[k] = serialize_spec(v)
+        else:
+            assert not isinstance(v, type)
+            result[k] = v
+    return result


+def unserialize_spec(spec: dict):
+    from ._sparse import SparseCoordinateTensor, CompactSparseTensor, CompressedSparseMatrix
+    type_names = {NativeTensor: 'dense', TensorStack: 'stack', CompressedSparseMatrix: 'compressed', SparseCoordinateTensor: 'coo', CompactSparseTensor: 'compact'}
+    lookup = {v: k for k, v in type_names.items()}
+    result = {}
+    for k, v in spec.items():
+        if k == 'type':
+            result[k] = lookup[v]
+        elif isinstance(v, dict):
+            result[k] = unserialize_spec(v)
+        else:
+            result[k] = v
+    return result
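serialize_spec/unserialize_spec exist so that the spec dicts written into the .npz hold plain strings such as 'dense' or 'stack' instead of pickled phiml classes. A self-contained sketch of that round trip with stand-in classes (not the real tensor types):

class Dense: pass
class Stack: pass

names = {Dense: 'dense', Stack: 'stack'}
classes = {v: k for k, v in names.items()}

def to_disk(spec):
    # Replace 'type' classes with their names, recursing into nested spec dicts.
    return {k: (names[v] if k == 'type' else to_disk(v) if isinstance(v, dict) else v) for k, v in spec.items()}

def from_disk(spec):
    # Inverse mapping: turn the stored names back into classes.
    return {k: (classes[v] if k == 'type' else from_disk(v) if isinstance(v, dict) else v) for k, v in spec.items()}

spec = {'type': Stack, 'dim': 'b', 'tensors': {'type': Dense, 'shape': ('x', 'y')}}
assert from_disk(to_disk(spec)) == spec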
