Skip to content

Commit

Permalink
[vis] Fix stream plots, only use for non-zero fields
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Nov 7, 2023
1 parent 6a3ea32 commit 6a97bc0
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions phi/vis/_matplotlib/_matplotlib_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def plot(self, data: Field, figure, subplot, space: Box, min_val: float, max_val
class StreamPlot2D(Recipe):

def can_plot(self, data: Field, space: Box) -> bool:
return data.spatial_rank == 2 and 'vector' in channel(data) and data.is_grid
return data.spatial_rank == 2 and 'vector' in channel(data) and data.is_grid and (data.values != 0).any

def plot(self, data: Field, figure, subplot, space: Box, min_val: float, max_val: float, show_color_bar: bool, color: Tensor, alpha: Tensor, err: Tensor):
vector = data.geometry.shape['vector']
Expand All @@ -441,14 +441,14 @@ def plot(self, data: Field, figure, subplot, space: Box, min_val: float, max_val
y = y[0, :]
u, v = reshaped_numpy(data.values, [vector, *data.shape.without('vector')])
if (color == None).all:
col = reshaped_numpy(math.vec_length(data.values), [*data.shape.without('vector')])
col = reshaped_numpy(math.vec_length(data.values), [*data.shape.without('vector')]).T
else:
if color.shape:
col = [_plt_col(c) for c in color.numpy(data.shape.non_channel).reshape(-1)]
else:
col = _plt_col(color)
# alphas = reshaped_numpy(alpha, [data.shape.without('vector')]) alpha not supported yet
subplot.streamplot(x, y, u, v, color=col, cmap=plt.cm.get_cmap('viridis'))
subplot.streamplot(x, y, u.T, v.T, color=col, cmap=plt.cm.get_cmap('viridis'))


class EmbeddedPoint2D(Recipe):
Expand Down

0 comments on commit 6a97bc0

Please sign in to comment.