Skip to content

Commit

Permalink
Make pint an optional dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
eltos committed Jan 21, 2025
1 parent 34eb72d commit dd57f2e
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 64 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ A plotting library for [Xsuite](https://github.com/xsuite) and simmilar accelera
## Usage

```bash
pip install xplt
pip install xplt[full]
```

Read the docs at https://xsuite.github.io/xplt
Expand Down
6 changes: 5 additions & 1 deletion docs/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
## Installation

```bash
pip install xplt
pip install xplt[full]
```

It is recommended to install the library with all optional dependencies, to make use its full functionality.
For a minimal installation without `[full]`, certain features like automatic unit conversion and resolving or support for pandas dataframes are disabled.


## Gallery

Click on the plots below to show the respective section in the [User guide](usage):
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ dependencies = [
"matplotlib>=3.6",
"numpy>=1.17.0",
"scipy>=1.2.0",
"pint>=0.24.1",
]
dynamic = ["version"]

Expand All @@ -36,7 +35,10 @@ documentation = "https://xsuite.github.io/xplt"
repository = "https://github.com/xsuite/xplt"

[project.optional-dependencies]
all = ["pandas"]
full = [
"pandas",
"pint>=0.24.1",
]


# Build tools
Expand Down
23 changes: 16 additions & 7 deletions xplt/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,17 @@
import matplotlib.patches
import matplotlib.path
import numpy as np
import pint

from .util import defaults, flattened, defaults_for
from .properties import Property, find_property, DataProperty, arb_unit
from .properties import (
Property,
find_property,
DataProperty,
arb_unit,
_convert_value_to_unit,
_fmt_qty,
_has_pint,
)


class ManifoldMultipleLocator(mpl.ticker.MaxNLocator):
Expand Down Expand Up @@ -328,7 +335,9 @@ def __init__(
self._user_properties[name] = (
DataProperty(name, arg) if isinstance(arg, str) else arg
)
self._display_units = defaults(display_units, s="m", x="mm", y="mm", p="mrad")
self._display_units = defaults(
display_units, **dict(s="m", x="mm", y="mm", p="mrad") if _has_pint() else {}
)

if annotation is None:
annotation = ax is None
Expand Down Expand Up @@ -542,7 +551,7 @@ def factor_for(self, p):
Returns:
float: Factor to convert parameter into display unit
"""
return pint.Quantity(1, self.prop(p).unit).to(self.display_unit_for(p)).m
return _convert_value_to_unit(1, self.prop(p).unit, self.display_unit_for(p))

def display_unit_for(self, p):
"""Return display unit for parameter
Expand Down Expand Up @@ -680,9 +689,9 @@ def label_for(self, *pp, unit=True, description=True):
if units[0] == arb_unit: # arbitrary unit
append = " / " + arb_unit
else:
display_unit = pint.Unit(units[0]) # all have the same unit (see above)
if display_unit != pint.Unit("1"):
append = f" / ${display_unit:~X}$" # see "NIST Guide to the SI"
display_unit = units[0] # all have the same unit (see above)
if display_unit and display_unit != "1":
append = f" / ${_fmt_qty(None, display_unit)}$" # see "NIST Guide to the SI"
if append:
# heuristic: if labels contain expressions with +, - or ± then add parentheses
if re.findall(r"([-+±]|\\pm)(?![^(]*\))(?![^{]*\})", label.split(" ")[-1]):
Expand Down
32 changes: 17 additions & 15 deletions xplt/particles.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@


from .util import *
from .properties import _fmt_qty, _convert_value_to_unit, _has_pint
from .base import XManifoldPlot
from .properties import Property, DerivedProperty, find_property

Expand Down Expand Up @@ -79,14 +80,15 @@ def _init_particle_mixin(
kwargs["_properties"] = defaults(
kwargs.get("_properties"), **self._derived_particle_properties
)
kwargs["display_units"] = defaults(
kwargs.get("display_units"),
X="mm^(1/2)",
Y="mm^(1/2)",
P="mm^(1/2)",
J="mm", # Action
Θ="rad", # Angle
)
if _has_pint():
kwargs["display_units"] = defaults(
kwargs.get("display_units"),
X="mm^(1/2)",
Y="mm^(1/2)",
P="mm^(1/2)",
J="mm", # Action
Θ="rad", # Angle
)

return kwargs

Expand Down Expand Up @@ -209,10 +211,11 @@ def _init_particle_histogram_mixin(self, **kwargs):
kwargs["_properties"] = defaults(
kwargs.get("_properties"), **self._histogram_particle_properties
)
kwargs["display_units"] = defaults(
kwargs.get("display_units"),
current="nA",
)
if _has_pint():
kwargs["display_units"] = defaults(
kwargs.get("display_units"),
current="nA",
)

return kwargs

Expand Down Expand Up @@ -392,8 +395,7 @@ def update(self, particles, mask=None, *, autoscale=None, dataset_id=None):
hist /= np.sum(hist)

if p in ("rate", "current"):
factor = pint.Quantity(1, prop_x.unit).to("s").m
hist /= factor * np.diff(edges)
hist /= np.diff(edges) * _convert_value_to_unit(1, prop_x.unit, "s")

# post-processing expression wrappers
if wrap := self.on_y_expression[i][j][k]:
Expand Down Expand Up @@ -437,7 +439,7 @@ def update(self, particles, mask=None, *, autoscale=None, dataset_id=None):
dv = np.unique(list(self._actual_bin_width.values()))
if len(dv) == 1:
x = prop_x.symbol.strip("$")
self.annotate(f"$\\Delta {{{x}}}_\\mathrm{{bin}} = {fmt(dv[0], prop_x.unit)}$")
self.annotate(f"$\\Delta {{{x}}}_\\mathrm{{bin}} = {_fmt_qty(dv[0], prop_x.unit)}$")
else:
self.annotate("")

Expand Down
70 changes: 55 additions & 15 deletions xplt/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,62 @@
__date__ = "2023-11-11"


import pint

try:
import pandas as pd
except ImportError:
# pandas is an optional dependency
pd = None
from .util import *


arb_unit = "arb. unit"


def _import_pint_or_raise(info: str = ""):
try:
import pint
except ImportError as e:
raise ImportError(
f"The pint package is required {info}. Install it with 'pip install xplt[full]'."
) from e
return pint


def _has_pint():
try:
_import_pint_or_raise()
return True
except ImportError:
return False


def _convert_value_to_unit(value, unit: str, to_unit: str):
if unit == to_unit:
return value
pint = _import_pint_or_raise(f"to convert units (from '{unit}' to '{to_unit}')")
return value * pint.Quantity(1, unit).to(to_unit).magnitude


def _deduce_derived_unit(function, input_units: dict) -> str:
"""Deduce the output unit of function, given the units for each input parameter"""
pint = _import_pint_or_raise(f"to automatically deduce units for '{function}'")
return function(**{p: pint.Quantity(1, u) for p, u in input_units.items()}).units


def _fmt_qty(t, unit):
"""Human-readable representation of value in unit (latex syntax)
Args:
t (float | None): the value
unit (str | None): the unit
Returns:
str: Formatted string in latex syntax without '$', e.g. "5.27\\ \\mathrm{mm}"
"""
if t is not None:
t = float(f"{t:g}") # to handle corner cases like 9.999999e-07
try:
pint = _import_pint_or_raise(f"to format the value '{t}' in units of '{unit}'")
s = f"{pint.Unit(unit) if t is None else pint.Quantity(t, unit):#~.4gX}"
return s.rstrip("\\").replace(" ", "\\ ", 1)
except ImportError:
return ("" if t is None else f"{t:g}\\ ") + f"\\mathrm{{{unit}}}"


class Property:
def __init__(self, symbol, unit, description=None):
"""Class holding generic property information
Expand All @@ -36,8 +79,6 @@ def __init__(self, symbol, unit, description=None):
self.unit = unit
self.description = description

pint.Unit(self.unit) # to raise an error if not a valid unit

def values(self, data, mask=None, *, unit=None):
"""Get masked data for this property
Expand Down Expand Up @@ -142,8 +183,7 @@ def mask_callback(mask, get):

# convert to unit
if unit is not None:
factor = pint.Quantity(1, self.unit).to(unit).magnitude
v *= factor
v = _convert_value_to_unit(v, self.unit, unit)

return v

Expand Down Expand Up @@ -191,8 +231,7 @@ def values(self, data, mask=None, *, unit=None):

# convert to unit
if unit is not None:
factor = pint.Quantity(1, self.unit).to(unit).magnitude
v *= factor
v = _convert_value_to_unit(v, self.unit, unit)

return v

Expand Down Expand Up @@ -268,8 +307,9 @@ def register_derived_property(name, function, unit=None, symbol=None, descriptio
if unit is None:
# determine unit from function return value
inputs = inspect.signature(function).parameters.keys()
inputs = {p: pint.Quantity(1, find_property(p).unit) for p in inputs}
unit = function(**inputs).units
input_units = {p: find_property(p).unit for p in inputs}
unit = _deduce_derived_unit(function, input_units)

register_property(name, DerivedProperty(symbol or f"${name}$", unit, function, description))


Expand Down
14 changes: 8 additions & 6 deletions xplt/timestructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import warnings
from dataclasses import dataclass
import scipy.signal

from .util import *
from .properties import _fmt_qty
from .base import XManifoldPlot, TwinFunctionLocator, TransformedLocator
from .particles import (
ParticlePlotMixin,
Expand Down Expand Up @@ -864,9 +866,9 @@ def update(
if len(fs) == 1:
if self.relative:
f = fs[0] / self.frev(particles)
self.annotate(f"$f_\\mathrm{{samp}} = {fmt(f, '1')}\\, f_\\mathrm{{rev}}$")
self.annotate(f"$f_\\mathrm{{samp}} = {_fmt_qty(f, '1')}\\, f_\\mathrm{{rev}}$")
else:
self.annotate(f"$f_\\mathrm{{samp}} = {fmt(fs[ 0 ], 'Hz')}$")
self.annotate(f"$f_\\mathrm{{samp}} = {_fmt_qty(fs[ 0], 'Hz')}$")
else:
self.annotate("")

Expand Down Expand Up @@ -1084,7 +1086,7 @@ def update(self, particles, mask=None, *, autoscale=None, dataset_id=None):
times = self._apply_time_range(times)
delay = self.factor_for("t") * np.diff(sorted(times))

self.annotate(f"$\\Delta t_\\mathrm{{bin}} = {fmt(self.bin_time)}$")
self.annotate(f"$\\Delta t_\\mathrm{{bin}} = {_fmt_qty(self.bin_time, 's')}$")

# update plots
changed = []
Expand Down Expand Up @@ -1330,8 +1332,8 @@ def update(

# annotate plot
self.annotate(
f"$\\Delta t_\\mathrm{{count}} = {fmt(timeseries.dt)}$\n"
f"$\\Delta t_\\mathrm{{evaluate}} = {fmt(timeseries.dt * nebins)}$"
f"$\\Delta t_\\mathrm{{count}} = {_fmt_qty(timeseries.dt, 's')}$\n"
f"$\\Delta t_\\mathrm{{evaluate}} = {_fmt_qty(timeseries.dt * nebins, 's')}$"
)

# display units
Expand Down Expand Up @@ -1656,7 +1658,7 @@ def check_insufficient_statistics():
+ (
f"{self.counting_bins_per_evaluation:g}\\,\\Delta t_\\mathrm{{count}}$"
if self.counting_bins_per_evaluation
else f"{fmt(duration)}$"
else f"{_fmt_qty(duration, 's')}$"
)
)

Expand Down
20 changes: 12 additions & 8 deletions xplt/twiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .util import *
from .base import XManifoldPlot
from .line import KnlPlot
from .properties import Property, DataProperty, DerivedProperty
from .properties import Property, DataProperty, DerivedProperty, _has_pint

from matplotlib.collections import PolyCollection
from numpy.testing import assert_equal
Expand Down Expand Up @@ -92,18 +92,22 @@ def __init__(self, twiss=None, kind="bet-dx,x+y", *, line=None, line_kwargs={},
max_y=DataProperty("max_y", "m", "$y_\\mathrm{max}$"),
)

super().__init__(
on_x="s",
on_y=kind,
on_y_subs=subs,
display_units=defaults(
kwargs.pop("display_units", None),
display_units = kwargs.pop("display_units", None)
if _has_pint():
display_units = defaults(
display_units,
bet="m",
d="m",
sigma_="mm",
envelope_="mm",
envelope3_="mm",
),
)

super().__init__(
on_x="s",
on_y=kind,
on_y_subs=subs,
display_units=display_units,
**kwargs,
)

Expand Down
9 changes: 0 additions & 9 deletions xplt/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import re

import numpy as np
import pint
import scipy.signal
import matplotlib as mpl
import matplotlib.collections
Expand Down Expand Up @@ -63,14 +62,6 @@ def val(obj):
return obj


def fmt(t, unit="s"):
"""Human-readable representation of value in unit (latex syntax)"""
t = float(f"{t:g}") # to handle corner cases like 9.999999e-07
s = f"{pint.Quantity(t, unit):#~.4gX}".rstrip("\\")
s = s.replace(" ", "\\ ", 1)
return s


#
def ieee_mod(values, m):
"""Return the IEEE remainder (in range -x/2 .. x/2)"""
Expand Down

0 comments on commit dd57f2e

Please sign in to comment.