Skip to content

Commit

Permalink
[vis] Label axes for points with item names
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Jan 5, 2024
1 parent 3359223 commit 6dd13b5
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions phi/vis/_matplotlib/_matplotlib_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,23 +631,35 @@ def _plot_points(axis: Axes, data: Field, dims: tuple, vector: Shape, color: Ten
axis.scatter(xs, ys, s=3, marker='o', c=col, alpha=alphas)

if any(non_channel(data).item_names):
PointCloud2D._annotate_points(axis, data.points, color, alpha)
PointCloud2D._annotate_points(axis, data.points, color, alpha, dims)

@staticmethod
def _annotate_points(axis, points: math.Tensor, color: Tensor, alpha: Tensor):
labelled_dims = non_channel(points)
labelled_dims = math.concat_shapes(*[d for d in labelled_dims if d.item_names[0]])
if not labelled_dims:
def _annotate_points(axis, points: math.Tensor, color: Tensor, alpha: Tensor, dims: Tuple[str], label_axis=True, max_axis_labels=10):
labeled_dims = non_channel(points)
labeled_dims = math.concat_shapes(*[d for d in labeled_dims if d.item_names[0]])
if not labeled_dims:
return
if all(dim.name in points.shape.get_item_names('vector') for dim in labelled_dims):
if all(dim.name in points.shape.get_item_names('vector') for dim in labeled_dims):
if label_axis:
for labeled_dim in labeled_dims:
if len(labeled_dim.item_names[0]) <= max_axis_labels:
which_axis = dims.index(labeled_dim.name)
if which_axis == 0:
axis.set_xticks(reshaped_numpy(points.vector[labeled_dim.name], [shape]))
if axis.get_xscale() == 'log':
axis.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
elif which_axis == 1:
axis.set_yticks(reshaped_numpy(points.vector[labeled_dim.name], [shape]))
if axis.get_yscale() == 'log':
axis.get_yaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
return # The point labels match one of the figure axes, so they are redundant
if points.shape['vector'].size == 2:
xs, ys = reshaped_numpy(points, ['vector', points.shape.without('vector')])
x_view = axis.get_xlim()[1] - axis.get_xlim()[0]
y_view = axis.get_ylim()[1] - axis.get_ylim()[0]
x_c = .95 * axis.get_xlim()[1] + .1 * axis.get_xlim()[0]
y_c = .95 * axis.get_ylim()[1] + .1 * axis.get_ylim()[0]
for x, y, idx, idx_n in zip(xs, ys, labelled_dims.meshgrid(), labelled_dims.meshgrid(names=True)):
for x, y, idx, idx_n in zip(xs, ys, labeled_dims.meshgrid(), labeled_dims.meshgrid(names=True)):
if axis.get_xscale() == 'log':
offset_x = x * (1 + .0003 * x_view) if x < x_c else x * (1 - .0003 * x_view)
else:
Expand Down

0 comments on commit 6dd13b5

Please sign in to comment.