Skip to content

Commit

Permalink
[vis] Fix automatic colors for connected point clouds
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Jan 31, 2024
1 parent 526cbc3 commit 1229800
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions phi/vis/_matplotlib/_matplotlib_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def _get_range(bounds: Box, index: int):

def _next_line_color(axes: Axes, kind: str = None, get_index=False):
kind = ['patches', 'lines', 'collections', 'containers'] if kind is None else kind.split(',')
next_index = max([len(getattr(axes, a)) for a in kind])
next_index = max([len(getattr(axes, k)) for k in kind])
if get_index:
return next_index
return _default_color(next_index)
Expand Down Expand Up @@ -573,6 +573,7 @@ def plot(self, data: Field, figure, subplot, space: Box, min_val: float, max_val

@staticmethod
def _plot_points(axis: Axes, data: Field, dims: tuple, vector: Shape, color: Tensor, alpha: Tensor, err: Tensor, min_val, max_val, label):
connected = spatial(data.points)
if isinstance(data.sampled_elements, GeometryStack):
for idx in data.sampled_elements.object_dims.meshgrid():
PointCloud2D._plot_points(axis, data[idx], dims, vector, color[idx], alpha[idx], err[idx], min_val, max_val, label)
Expand All @@ -581,10 +582,13 @@ def _plot_points(axis: Axes, data: Field, dims: tuple, vector: Shape, color: Ten
x, y = reshaped_numpy(data.points.vector[dims], [vector, non_channel(data)])
if (color == None).all:
if not math.is_finite(data.values).any:
mpl_colors = [_next_line_color(axis)] * non_channel(data).volume
mpl_colors = [_next_line_color(axis, 'lines' if connected else 'collections')] * non_channel(data).volume
else:
values = reshaped_numpy(data.values, [non_channel(data)])
mpl_colors = add_color_bar(axis, values, min_val, max_val)
if np.any(values != values[0]):
mpl_colors = add_color_bar(axis, values, min_val, max_val)
else:
mpl_colors = [_next_line_color(axis, 'lines' if connected else 'collections')] * non_channel(data).volume
else:
mpl_colors = matplotlib_colors(color, non_channel(data), default=0)
alphas = reshaped_numpy(alpha, [non_channel(data)])
Expand Down Expand Up @@ -635,13 +639,13 @@ def _plot_points(axis: Axes, data: Field, dims: tuple, vector: Shape, color: Ten
rad = reshaped_numpy(data.geometry.bounding_radius(), [data.shape.non_channel])
shapes = [plt.Circle((xi, yi), radius=ri, linewidth=0, alpha=a, facecolor=ci) for xi, yi, ri, ci, a in zip(x, y, rad, mpl_colors, alphas)]
axis.add_collection(matplotlib.collections.PatchCollection(shapes, match_original=True))
if spatial(data.points): # Connect by line
if connected: # Connect by line
for i, idx in enumerate(instance(data).meshgrid()):
for sp_dim in spatial(data):
other_sp = spatial(data).without(sp_dim)
xs, ys = reshaped_numpy(data[idx].points.vector[dims], [vector, sp_dim, other_sp])
if (color == None).all:
col = _next_line_color(axis)
col = _next_line_color(axis, 'lines' if connected else 'collections')
else:
col = _plt_col(color)
alpha_f = float(alpha[idx].max)
Expand Down

0 comments on commit 1229800

Please sign in to comment.