Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

typing #2123

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open

typing #2123

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions .mypy

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _unpickle_measurement(cls, *args):
return _unpickle(application_registry.Measurement, *args)


def set_application_registry(registry):
def set_application_registry(registry) -> None:
"""Set the application registry, which is used for unpickling operations
and when invoking pint.Quantity or pint.Unit directly.

Expand Down
5 changes: 5 additions & 0 deletions pint/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,10 @@


class Handler(Protocol):
@overload
def __getitem__(self, Never, /) -> Never:
...

@overload
def __getitem__(self, item: type[T]) -> Callable[[T], None]:
...
12 changes: 8 additions & 4 deletions pint/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ class BehaviorChangeWarning(UserWarning):
else:
NUMERIC_TYPES = (Number, Decimal, ndarray, np.number)

def _to_magnitude(value, force_ndarray=False, force_ndarray_like=False):
def _to_magnitude(
value, force_ndarray: bool = False, force_ndarray_like: bool = False
):
if isinstance(value, (dict, bool)) or value is None:
raise TypeError(f"Invalid magnitude for Quantity: {value!r}")
elif isinstance(value, str) and value == "":
Expand All @@ -117,12 +119,12 @@ def _to_magnitude(value, force_ndarray=False, force_ndarray_like=False):
return np.asarray(value)
return value

def _test_array_function_protocol():
def _test_array_function_protocol() -> bool:
# Test if the __array_function__ protocol is enabled
try:

class FakeArray:
def __array_function__(self, *args, **kwargs):
def __array_function__(self, *args, **kwargs) -> None:
return

np.concatenate([FakeArray()])
Expand All @@ -149,7 +151,9 @@ class np_datetime64:
HAS_NUMPY_ARRAY_FUNCTION = False
NP_NO_VALUE = None

def _to_magnitude(value, force_ndarray=False, force_ndarray_like=False):
def _to_magnitude(
value, force_ndarray: bool = False, force_ndarray_like: bool = False
):
if force_ndarray or force_ndarray_like:
raise ValueError(
"Cannot force to ndarray or ndarray-like when NumPy is not present."
Expand Down
2 changes: 1 addition & 1 deletion pint/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def to_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude:
def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude:
return value

def __init_subclass__(cls, **kwargs: Any):
def __init_subclass__(cls, **kwargs: Any) -> None:
# Get constructor parameters
super().__init_subclass__(**kwargs)
cls._subclasses.append(cls)
Expand Down
8 changes: 4 additions & 4 deletions pint/delegates/formatter/full.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class FullFormatter(BaseFormatter):

locale: Locale | None = None

def __init__(self, registry: UnitRegistry | None = None):
def __init__(self, registry: UnitRegistry | None = None) -> None:
super().__init__(registry)

self._formatters = {}
Expand Down Expand Up @@ -130,7 +130,7 @@ def format_unit(
unit: PlainUnit | Iterable[tuple[str, Any]],
uspec: str = "",
sort_func: SortFunc | None = None,
**babel_kwds: Unpack[BabelKwds],
**babel_kwds: Unpack[BabelKwds] | Unpack[dict[str, None]],
) -> str:
uspec = uspec or self.default_format
sort_func = sort_func or self.default_sort_func
Expand All @@ -142,7 +142,7 @@ def format_quantity(
self,
quantity: PlainQuantity[MagnitudeT],
spec: str = "",
**babel_kwds: Unpack[BabelKwds],
**babel_kwds: Unpack[BabelKwds] | Unpack[dict[str, None]],
) -> str:
spec = spec or self.default_format
# If Compact is selected, do it at the beginning
Expand Down Expand Up @@ -179,7 +179,7 @@ def format_measurement(
self,
measurement: Measurement,
meas_spec: str = "",
**babel_kwds: Unpack[BabelKwds],
**babel_kwds: Unpack[BabelKwds] | Unpack[dict[str, None]],
) -> str:
meas_spec = meas_spec or self.default_format
# If Compact is selected, do it at the beginning
Expand Down
2 changes: 1 addition & 1 deletion pint/delegates/formatter/plain.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@


class BaseFormatter:
def __init__(self, registry: UnitRegistry | None = None):
def __init__(self, registry: UnitRegistry | None = None) -> None:
self._registry = registry


Expand Down
2 changes: 1 addition & 1 deletion pint/delegates/txt_defparser/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class DefinitionSyntaxError(errors.DefinitionSyntaxError, fp.ParsingError):

msg: str

def __init__(self, msg: str, location: str = ""):
def __init__(self, msg: str, location: str = "") -> None:
self.msg = msg
self.location = location

Expand Down
6 changes: 4 additions & 2 deletions pint/delegates/txt_defparser/defparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class _PintParser(fp.Parser[PintRootBlock, ParserConfig]):

_diskcache: fc.DiskCache | None

def __init__(self, config: ParserConfig, *args: ty.Any, **kwargs: ty.Any):
def __init__(self, config: ParserConfig, *args: ty.Any, **kwargs: ty.Any) -> None:
self._diskcache = kwargs.pop("diskcache", None)
super().__init__(config, *args, **kwargs)

Expand All @@ -68,7 +68,9 @@ class DefParser:
plain.CommentDefinition,
)

def __init__(self, default_config: ParserConfig, diskcache: fc.DiskCache):
def __init__(
self, default_config: ParserConfig, diskcache: fc.DiskCache | None
) -> None:
self._default_config = default_config
self._diskcache = diskcache

Expand Down
26 changes: 13 additions & 13 deletions pint/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ class DefinitionError(ValueError, PintError):
definition_type: type
msg: str

def __init__(self, name: str, definition_type: type, msg: str):
def __init__(self, name: str, definition_type: type, msg: str) -> None:
self.name = name
self.definition_type = definition_type
self.msg = msg

def __str__(self):
def __str__(self) -> str:
msg = f"Cannot define '{self.name}' ({self.definition_type}): {self.msg}"
return msg

Expand All @@ -109,10 +109,10 @@ class DefinitionSyntaxError(ValueError, PintError):

msg: str

def __init__(self, msg: str):
def __init__(self, msg: str) -> None:
self.msg = msg

def __str__(self):
def __str__(self) -> str:
return self.msg

def __reduce__(self):
Expand All @@ -125,11 +125,11 @@ class RedefinitionError(ValueError, PintError):
name: str
definition_type: type

def __init__(self, name: str, definition_type: type):
def __init__(self, name: str, definition_type: type) -> None:
self.name = name
self.definition_type = definition_type

def __str__(self):
def __str__(self) -> str:
msg = f"Cannot redefine '{self.name}' ({self.definition_type})"
return msg

Expand All @@ -142,13 +142,13 @@ class UndefinedUnitError(AttributeError, PintError):

unit_names: tuple[str, ...]

def __init__(self, unit_names: str | ty.Iterable[str]):
def __init__(self, unit_names: str | ty.Iterable[str]) -> None:
if isinstance(unit_names, str):
self.unit_names = (unit_names,)
else:
self.unit_names = tuple(unit_names)

def __str__(self):
def __str__(self) -> str:
if len(self.unit_names) == 1:
return f"'{tuple(self.unit_names)[0]}' is not defined in the unit registry"
return f"{tuple(self.unit_names)} are not defined in the unit registry"
Expand Down Expand Up @@ -184,7 +184,7 @@ def __init__(
self.dim2 = dim2
self.extra_msg = extra_msg

def __str__(self):
def __str__(self) -> str:
if self.dim1 or self.dim2:
dim1 = f" ({self.dim1})"
dim2 = f" ({self.dim2})"
Expand Down Expand Up @@ -222,7 +222,7 @@ def yield_units(self):
if self.units2:
yield self.units2

def __str__(self):
def __str__(self) -> str:
return (
"Ambiguous operation with offset unit (%s)."
% ", ".join(str(u) for u in self.yield_units())
Expand Down Expand Up @@ -250,7 +250,7 @@ def yield_units(self):
if self.units2:
yield self.units2

def __str__(self):
def __str__(self) -> str:
return (
"Ambiguous operation with logarithmic unit (%s)."
% ", ".join(str(u) for u in self.yield_units())
Expand All @@ -266,7 +266,7 @@ def __reduce__(self):
class UnitStrippedWarning(UserWarning, PintError):
msg: str

def __init__(self, msg: str):
def __init__(self, msg: str) -> None:
self.msg = msg

def __reduce__(self):
Expand All @@ -280,7 +280,7 @@ class UnexpectedScaleInContainer(Exception):
class UndefinedBehavior(UserWarning, PintError):
msg: str

def __init__(self, msg: str):
def __init__(self, msg: str) -> None:
self.msg = msg

def __reduce__(self):
Expand Down
8 changes: 4 additions & 4 deletions pint/facets/context/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def redefine(self, definition: str) -> None:
if isinstance(definition, UnitDefinition):
self._redefine(definition)

def _redefine(self, definition: UnitDefinition):
def _redefine(self, definition: UnitDefinition) -> None:
self.redefinitions.append(definition)

def hashable(
Expand Down Expand Up @@ -274,13 +274,13 @@ class ContextChain(ChainMap[SrcDst, Context]):
to transform from one dimension to another.
"""

def __init__(self):
def __init__(self) -> None:
super().__init__()
self.contexts: list[Context] = []
self.maps.clear() # Remove default empty map
self._graph: dict[SrcDst, set[UnitsContainer]] | None = None

def insert_contexts(self, *contexts: Context):
def insert_contexts(self, *contexts: Context) -> None:
"""Insert one or more contexts in reversed order the chained map.
(A rule in last context will take precedence)

Expand All @@ -292,7 +292,7 @@ def insert_contexts(self, *contexts: Context):
self.maps = [ctx.relation_to_context for ctx in reversed(contexts)] + self.maps
self._graph = None

def remove_contexts(self, n: int | None = None):
def remove_contexts(self, n: int | None = None) -> None:
"""Remove the last n inserted contexts from the chain.

Parameters
Expand Down
2 changes: 1 addition & 1 deletion pint/facets/dask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def persist(self, **kwargs):
return result

@check_dask_array
def visualize(self, **kwargs):
def visualize(self, **kwargs) -> None:
"""Produce a visual representation of the Dask graph.

The graphviz library is required.
Expand Down
2 changes: 1 addition & 1 deletion pint/facets/group/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class Group(SharedRegistryObject):
If not given, a root Group will be created.
"""

def __init__(self, name: str):
def __init__(self, name: str) -> None:
# The name of the group.
self.name = name

Expand Down
4 changes: 2 additions & 2 deletions pint/facets/group/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class GenericGroupRegistry(
# to enjoy typing goodies
Group = type[objects.Group]

def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
#: Map group name to group.
self._groups: dict[str, objects.Group] = {}
Expand Down Expand Up @@ -81,7 +81,7 @@ def _register_definition_adders(self) -> None:
super()._register_definition_adders()
self._register_adder(GroupDefinition, self._add_group)

def _add_unit(self, definition: UnitDefinition):
def _add_unit(self, definition: UnitDefinition) -> None:
super()._add_unit(definition)
# TODO: delta units are missing
self.get_group("root").add_units(definition.name)
Expand Down
8 changes: 4 additions & 4 deletions pint/facets/measurement/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

class MeasurementQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]):
# Measurement support
def plus_minus(self, error, relative=False):
def plus_minus(self, error, relative: bool = False):
if isinstance(error, self.__class__):
if relative:
raise ValueError(f"{error} is not a valid relative error.")
Expand Down Expand Up @@ -97,15 +97,15 @@ def __reduce__(self):

return _unpickle_measurement, (Measurement, self.magnitude, self._units)

def __repr__(self):
def __repr__(self) -> str:
return "<Measurement({}, {}, {})>".format(
self.magnitude.nominal_value, self.magnitude.std_dev, self.units
)

def __str__(self):
def __str__(self) -> str:
return f"{self}"

def __format__(self, spec):
def __format__(self, spec) -> str:
spec = spec or self._REGISTRY.default_format
return self._REGISTRY.formatter.format_measurement(self, spec)

Expand Down
4 changes: 2 additions & 2 deletions pint/facets/nonmultiplicative/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ class LogarithmicConverter(ScaleConverter):
logfactor: float

@property
def is_multiplicative(self):
def is_multiplicative(self) -> bool:
return False

@property
def is_logarithmic(self):
def is_logarithmic(self) -> bool:
return True

def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude:
Expand Down
2 changes: 2 additions & 0 deletions pint/facets/nonmultiplicative/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def _add_unit(self, definition: UnitDefinition) -> None:
"delta_" + alias for alias in definition.aliases
)

assert isinstance(definition.reference, UnitsContainer)
delta_reference = self.UnitsContainer(
{ref: value for ref, value in definition.reference.items()}
)
Expand Down Expand Up @@ -198,6 +199,7 @@ def _add_ref_of_log_or_offset_unit(

# TODO: Check that reference is None

assert isinstance(slct_ref, UnitsContainer)
# If reference unit is not dimensionless
if slct_ref != UnitsContainer():
# Extract reference unit
Expand Down
Loading
Loading