diff --git a/phi/vis/_vis.py b/phi/vis/_vis.py index 7b83122ca..6eef12e06 100644 --- a/phi/vis/_vis.py +++ b/phi/vis/_vis.py @@ -409,6 +409,16 @@ def plot_frame(frame: int): def layout_pytree_node(data, wrap_leaf=False): + if isinstance(data, tuple): + return layout(tuple([layout_pytree_node(i) for i in data]), batch('tuple')) + elif isinstance(data, list): + return layout([layout_pytree_node(i) for i in data], batch('list')) + elif isinstance(data, dict): + return layout({k: layout_pytree_node(v) for k, v in data.items()}, batch('dict')) + return wrap(data) if wrap_leaf else data + + +def layout_one_level(data, wrap_leaf=False): if isinstance(data, tuple): return layout(data, batch('tuple')) elif isinstance(data, list): @@ -418,7 +428,7 @@ def layout_pytree_node(data, wrap_leaf=False): return wrap(data) if wrap_leaf else data -def layout_sub_figures(data: Union[Tensor, Field], +def layout_sub_figures(any_data: Union[Tensor, Field], row_dims: DimFilter, col_dims: DimFilter, animate: DimFilter, # do not reduce these dims, has priority @@ -428,9 +438,9 @@ def layout_sub_figures(data: Union[Tensor, Field], positioning: Dict[Tuple[int, int], List], indices: Dict[Tuple[int, int], List[dict]], base_index: Dict[str, Union[int, str]]) -> Tuple[int, int, Shape, Shape]: # rows, cols - if data is None: - raise ValueError(f"Cannot layout figure for '{data}'") - data = layout_pytree_node(data, wrap_leaf=False) + if any_data is None: + raise ValueError(f"Cannot layout figure for '{any_data}'") + data = layout_one_level(any_data, wrap_leaf=False) if isinstance(data, Tensor) and data.dtype.kind == object: # layout rows, cols = 0, 0 non_reduced = math.EMPTY_SHAPE