Skip to content

Commit

Permalink
[vis] Fix nested values for color, alpha
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Jan 6, 2024
1 parent e3efbff commit c3b61ab
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions phi/vis/_vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,16 @@ def plot_frame(frame: int):


def layout_pytree_node(data, wrap_leaf=False):
if isinstance(data, tuple):
return layout(tuple([layout_pytree_node(i) for i in data]), batch('tuple'))
elif isinstance(data, list):
return layout([layout_pytree_node(i) for i in data], batch('list'))
elif isinstance(data, dict):
return layout({k: layout_pytree_node(v) for k, v in data.items()}, batch('dict'))
return wrap(data) if wrap_leaf else data


def layout_one_level(data, wrap_leaf=False):
if isinstance(data, tuple):
return layout(data, batch('tuple'))
elif isinstance(data, list):
Expand All @@ -418,7 +428,7 @@ def layout_pytree_node(data, wrap_leaf=False):
return wrap(data) if wrap_leaf else data


def layout_sub_figures(data: Union[Tensor, Field],
def layout_sub_figures(any_data: Union[Tensor, Field],
row_dims: DimFilter,
col_dims: DimFilter,
animate: DimFilter, # do not reduce these dims, has priority
Expand All @@ -428,9 +438,9 @@ def layout_sub_figures(data: Union[Tensor, Field],
positioning: Dict[Tuple[int, int], List],
indices: Dict[Tuple[int, int], List[dict]],
base_index: Dict[str, Union[int, str]]) -> Tuple[int, int, Shape, Shape]: # rows, cols
if data is None:
raise ValueError(f"Cannot layout figure for '{data}'")
data = layout_pytree_node(data, wrap_leaf=False)
if any_data is None:
raise ValueError(f"Cannot layout figure for '{any_data}'")
data = layout_one_level(any_data, wrap_leaf=False)
if isinstance(data, Tensor) and data.dtype.kind == object: # layout
rows, cols = 0, 0
non_reduced = math.EMPTY_SHAPE
Expand Down

0 comments on commit c3b61ab

Please sign in to comment.