Skip to content

Commit

Permalink
Save/load trees with Layouts
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Oct 27, 2024
1 parent 16526c0 commit e543bb8
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions phiml/math/_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1894,6 +1894,9 @@ def disassemble_tree(obj: PhiTreeNodeType, cache: bool, attr_type=variable_attri
"""
if obj is None:
return MISSING_TENSOR, []
elif isinstance(obj, Layout):
keys, values = disassemble_tree(obj._obj, cache, attr_type)
return {'__layout__': 1, 'stack_dim': obj._stack_dim._to_dict(False), 'obj': keys}, values
elif isinstance(obj, Tensor):
return None, [cached(obj) if cache else obj]
elif isinstance(obj, (tuple, list)):
Expand Down Expand Up @@ -1951,6 +1954,9 @@ def assemble_tree(obj: PhiTreeNodeType, values: List[Tensor], attr_type=variable
return [assemble_tree(item, values, attr_type) for item in obj]
elif isinstance(obj, tuple):
return tuple([assemble_tree(item, values, attr_type) for item in obj])
elif isinstance(obj, dict) and '__layout__' in obj:
content = assemble_tree(obj['obj'], values, attr_type)
return Layout(content, Shape._from_dict(obj['stack_dim']))
elif isinstance(obj, dict):
return {name: assemble_tree(val, values, attr_type) for name, val in obj.items()}
elif isinstance(obj, Tensor):
Expand All @@ -1966,6 +1972,8 @@ def assemble_tree(obj: PhiTreeNodeType, values: List[Tensor], attr_type=variable
def attr_paths(obj: PhiTreeNodeType, attr_type: Callable, root: str) -> List[str]:
if obj is None:
return []
elif isinstance(obj, Layout):
return attr_paths(obj._obj, attr_type, f'{root}._obj')
elif isinstance(obj, Tensor):
return [root]
elif isinstance(obj, (tuple, list)):
Expand Down Expand Up @@ -2003,6 +2011,8 @@ def attr_paths_from_container(obj: PhiTreeNodeType, attr_type: Callable, root: s
return [root]
elif isinstance(obj, (tuple, list)):
return sum([attr_paths_from_container(v, attr_type, f'{root}[{i}]') for i, v in enumerate(obj)], [])
elif isinstance(obj, dict) and '__layout__' in obj:
return attr_paths_from_container(obj['obj'], attr_type, f'{root}._obj')
elif isinstance(obj, dict):
return sum([attr_paths_from_container(v, attr_type, f'{root}[{k}]') for k, v in obj.items()], [])
elif isinstance(obj, Tensor):
Expand Down

0 comments on commit e543bb8

Please sign in to comment.