From 999695de7ee49e1ac85107b8235604036916163c Mon Sep 17 00:00:00 2001 From: Axel Huebl Date: Sun, 19 Nov 2023 23:42:28 -0800 Subject: [PATCH] MPL: `ImpactXParticleContainer.plot_phasespace()` Add an interactive plotter to quickly check the current phase space of the particle bunch. This is heavily inspired by Wake-T's `bunch.show()` functionality. --- .../impactx/ImpactXParticleContainer.py | 239 ++++++++++++++++++ tests/python/test_dataframe.py | 6 + 2 files changed, 245 insertions(+) diff --git a/src/python/impactx/ImpactXParticleContainer.py b/src/python/impactx/ImpactXParticleContainer.py index 8547348db..b1a8f3918 100644 --- a/src/python/impactx/ImpactXParticleContainer.py +++ b/src/python/impactx/ImpactXParticleContainer.py @@ -47,8 +47,247 @@ def ix_pc_to_df(self, local=True, comm=None, root_rank=0): return df +def ix_pc_plot_mpl_phasespace(self, num_bins=50, root_rank=0): + """ + Plot the longitudinal and transverse phase space projections with matplotlib. + + Parameters + ---------- + self : ImpactXParticleContainer_* + The particle container class in ImpactX + num_bins : int, default=50 + The number of bins for spatial and momentum directions per plot axis. + root_rank : int + MPI root rank to reduce to + + Returns + ------- + A matplotlib figure with containing the plot. + """ + import matplotlib.pyplot as plt + import numpy as np + + # Matplotlib canvas: figure and plottable axes areas + fig, axes = plt.subplots(1, 3, figsize=(16, 4), constrained_layout=True) + (ax_xpx, ax_ypy, ax_tpt) = axes + + # projected axes + ax_x, ax_px = ax_xpx.twinx(), ax_xpx.twiny() + ax_y, ax_py = ax_ypy.twinx(), ax_ypy.twiny() + ax_t, ax_pt = ax_tpt.twinx(), ax_tpt.twiny() + + # Beam Characteristics + rbc = self.reduced_beam_characteristics() + + # Data Histogramming + df = self.to_df(local=True) + + # update for plot unit system + # TODO: normalize to t/z to um and mc depending on s or t + m2u = 1.0e6 + df.position_x = df.position_x.multiply(m2u) + df.position_y = df.position_y.multiply(m2u) + df.position_t = df.position_t.multiply(m2u) + + xpx, x_edges, px_edges = np.histogram2d( + df["position_x"], + df["momentum_x"], + bins=num_bins, + range=[ + [rbc["x_min"] * m2u, rbc["x_max"] * m2u], + [rbc["px_min"], rbc["px_max"]], + ], + ) + + ypy, y_edges, py_edges = np.histogram2d( + df["position_y"], + df["momentum_y"], + bins=num_bins, + range=[ + [rbc["y_min"] * m2u, rbc["y_max"] * m2u], + [rbc["py_min"], rbc["py_max"]], + ], + ) + + tpt, t_edges, pt_edges = np.histogram2d( + df["position_t"], + df["momentum_t"], + bins=num_bins, + range=[ + [rbc["t_min"] * m2u, rbc["t_max"] * m2u], + [rbc["pt_min"], rbc["pt_max"]], + ], + ) + + # MPI reduce + # nothing to do for non-MPI runs + from inspect import getmodule + + ix = getmodule(self) + if ix.Config.have_mpi: + from mpi4py import MPI + + comm = MPI.COMM_WORLD # TODO: get currently used ImpactX communicator here + rank = comm.Get_rank() + + # MPI_Reduce the node-local histogram data + combined_data = np.concatenate([xpx, ypy, tpt]) + comm.Reduce( + MPI.IN_PLACE, + combined_data, + op=MPI.SUM, + root=root_rank, + ) + + if rank != root_rank: + return + + [xpx, ypy, tpt] = np.split( + combined_data, + [ + len(xpx), + len(xpx) + len(ypy), + # len(xpx) + len(ypy) + len(tpt) + ], + ) + + # histograms per axis + x = np.sum(xpx, axis=1) + px = np.sum(xpx, axis=0) + y = np.sum(ypy, axis=1) + py = np.sum(ypy, axis=0) + t = np.sum(tpt, axis=1) + pt = np.sum(tpt, axis=0) + + # Plotting + def plot_2d(hist, r, p, r_edges, p_edges, ax_r, ax_p, ax_rp): + hist = np.ma.masked_where(hist == 0, hist) + im = ax_rp.imshow( + hist.T, + origin="lower", + aspect="auto", + extent=[r_edges[0], r_edges[-1], p_edges[0], p_edges[-1]], + ) + cbar = fig.colorbar(im, ax=ax_rp) + + r_mids = (r_edges[:-1] + r_edges[1:]) / 2 + p_mids = (p_edges[:-1] + p_edges[1:]) / 2 + ax_r.plot(r_mids, r, c="w", lw=0.8, alpha=0.7) + ax_r.plot(r_mids, r, c="k", lw=0.5, alpha=0.7) + ax_r.fill_between(r_mids, r, facecolor="k", alpha=0.2) + ax_p.plot(p, p_mids, c="w", lw=0.8, alpha=0.7) + ax_p.plot(p, p_mids, c="k", lw=0.5, alpha=0.7) + ax_p.fill_betweenx(p_mids, p, facecolor="k", alpha=0.2) + + return cbar + + cbar_xpx = plot_2d(xpx, x, px, x_edges, px_edges, ax_x, ax_px, ax_xpx) + cbar_ypy = plot_2d(ypy, y, py, y_edges, py_edges, ax_y, ax_py, ax_ypy) + cbar_tpt = plot_2d(tpt, t, pt, t_edges, pt_edges, ax_t, ax_pt, ax_tpt) + + # Limits + def set_limits(r, p, r_edges, p_edges, ax_r, ax_p, ax_rp): + pad = 0.1 + len_r = r_edges[-1] - r_edges[0] + len_p = p_edges[-1] - p_edges[0] + ax_rp.set_xlim(r_edges[0] - len_r * pad, r_edges[-1] + len_r * pad) + ax_rp.set_ylim(p_edges[0] - len_p * pad, p_edges[-1] + len_p * pad) + + # ensure zoom does not change value axis for projections + def on_xlims_change(axes): + if not axes.xlim_reset_in_progress: + pad = 6.0 + axes.xlim_reset_in_progress = True + axes.set_xlim(0, np.max(p) * pad) + axes.xlim_reset_in_progress = False + + ax_p.xlim_reset_in_progress = False + ax_p.callbacks.connect("xlim_changed", on_xlims_change) + on_xlims_change(ax_p) + + def on_ylims_change(axes): + if not axes.ylim_reset_in_progress: + pad = 6.0 + axes.ylim_reset_in_progress = True + axes.set_ylim(0, np.max(r) * pad) + axes.ylim_reset_in_progress = False + + ax_r.ylim_reset_in_progress = False + ax_r.callbacks.connect("ylim_changed", on_ylims_change) + on_ylims_change(ax_r) + + set_limits(x, px, x_edges, px_edges, ax_x, ax_px, ax_xpx) + set_limits(y, py, y_edges, py_edges, ax_y, ax_py, ax_ypy) + set_limits(t, pt, t_edges, pt_edges, ax_t, ax_pt, ax_tpt) + + # Annotations + fig.canvas.manager.set_window_title("Phase Space") + ax_xpx.set_xlabel(r"$\Delta x$ [$\mu$m]") + ax_xpx.set_ylabel(r"$\Delta p_x$ [mc]") + cbar_xpx.set_label(r"$Q$ [C/bin]") + # ax_x.patch.set_alpha(0) + ax_x.set_yticks([]) + ax_px.set_xticks([]) + + ax_ypy.set_xlabel(r"$\Delta y$ [$\mu$m]") + ax_ypy.set_ylabel(r"$\Delta p_y$ [mc]") + cbar_ypy.set_label(r"$Q$ [C/bin]") + ax_y.set_yticks([]) + ax_py.set_xticks([]) + + # TODO: update depending on s or t + ax_tpt.set_xlabel(r"$\Delta ct$ [$\mu$m]") + ax_tpt.set_ylabel(r"$\Delta p_t$ [mc]") + cbar_tpt.set_label(r"$Q$ [C/bin]") + ax_t.set_yticks([]) + ax_pt.set_xticks([]) + + # TODO: write an auto-formatter that picks m, mu, n, p, f, k, M, G + # automatically + ax_xpx.legend( + title=r"$\epsilon_{n,x}=$" + f"{rbc['emittance_x']*1e6:.3f} µm" + "\n" + rf"$\sigma_x=${rbc['sig_x']*1e6:.3f} µm" + "\n" + rf"$\beta_x=${rbc['beta_x']*1e3:.3f} mm" + "\n" + rf"$\alpha_x=${rbc['alpha_x']:.3f}", + loc="upper right", + framealpha=0.8, + handles=[], + ) + ax_ypy.legend( + title=r"$\epsilon_{n,y}=$" + f"{rbc['emittance_y']*1e6:.3f} µm" + "\n" + rf"$\sigma_x=${rbc['sig_y']*1e6:.3f} µm" + "\n" + rf"$\beta_x=${rbc['beta_y']*1e3:.3f} mm" + "\n" + rf"$\alpha_x=${rbc['alpha_y']:.3f}", + loc="upper right", + framealpha=0.8, + handles=[], + ) + ax_tpt.legend( + title=r"$\epsilon_{n,t}=$" + f"{rbc['emittance_t']*1e6:.3f}" + r" MeV$\cdot$s" + "\n" + rf"$\sigma_t=${rbc['sig_t']*1e15:.3f} fs", + # TODO: sigma_pz, I_peak, t_FWHM, ... + loc="upper right", + framealpha=0.8, + handles=[], + ) + + return fig + + def register_ImpactXParticleContainer_extension(ixpc): """ImpactXParticleContainer helper methods""" # register member functions for ImpactXParticleContainer ixpc.to_df = ix_pc_to_df + ixpc.plot_phasespace = ix_pc_plot_mpl_phasespace diff --git a/tests/python/test_dataframe.py b/tests/python/test_dataframe.py index a209317b4..24fea522a 100644 --- a/tests/python/test_dataframe.py +++ b/tests/python/test_dataframe.py @@ -64,6 +64,12 @@ def test_df_pandas(): if df is not None: assert npart == len(df) + # plot + fig = pc.plot_phasespace() + import matplotlib.pyplot as plt + + plt.show() + if __name__ == "__main__": test_df_pandas()