From 6a97bc07906becaca7c27d5b2bda3290905f27f8 Mon Sep 17 00:00:00 2001 From: holl- Date: Tue, 7 Nov 2023 16:43:01 +0100 Subject: [PATCH] [vis] Fix stream plots, only use for non-zero fields --- phi/vis/_matplotlib/_matplotlib_plots.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/phi/vis/_matplotlib/_matplotlib_plots.py b/phi/vis/_matplotlib/_matplotlib_plots.py index e9065e4fb..098174793 100644 --- a/phi/vis/_matplotlib/_matplotlib_plots.py +++ b/phi/vis/_matplotlib/_matplotlib_plots.py @@ -431,7 +431,7 @@ def plot(self, data: Field, figure, subplot, space: Box, min_val: float, max_val class StreamPlot2D(Recipe): def can_plot(self, data: Field, space: Box) -> bool: - return data.spatial_rank == 2 and 'vector' in channel(data) and data.is_grid + return data.spatial_rank == 2 and 'vector' in channel(data) and data.is_grid and (data.values != 0).any def plot(self, data: Field, figure, subplot, space: Box, min_val: float, max_val: float, show_color_bar: bool, color: Tensor, alpha: Tensor, err: Tensor): vector = data.geometry.shape['vector'] @@ -441,14 +441,14 @@ def plot(self, data: Field, figure, subplot, space: Box, min_val: float, max_val y = y[0, :] u, v = reshaped_numpy(data.values, [vector, *data.shape.without('vector')]) if (color == None).all: - col = reshaped_numpy(math.vec_length(data.values), [*data.shape.without('vector')]) + col = reshaped_numpy(math.vec_length(data.values), [*data.shape.without('vector')]).T else: if color.shape: col = [_plt_col(c) for c in color.numpy(data.shape.non_channel).reshape(-1)] else: col = _plt_col(color) # alphas = reshaped_numpy(alpha, [data.shape.without('vector')]) alpha not supported yet - subplot.streamplot(x, y, u, v, color=col, cmap=plt.cm.get_cmap('viridis')) + subplot.streamplot(x, y, u.T, v.T, color=col, cmap=plt.cm.get_cmap('viridis')) class EmbeddedPoint2D(Recipe):