diff --git a/src/scanpy/external/pl.py b/src/scanpy/external/pl.py index 662bc88eb3..f387082476 100644 --- a/src/scanpy/external/pl.py +++ b/src/scanpy/external/pl.py @@ -25,6 +25,8 @@ from collections.abc import Collection from typing import Any + from matplotlib.colors import Colormap + __all__ = [ "phate", @@ -166,7 +168,7 @@ def sam( projection: str | np.ndarray = "X_umap", *, c: str | np.ndarray | None = None, - cmap: str = "Spectral_r", + cmap: Colormap | str | None = "Spectral_r", linewidth: float = 0.0, edgecolor: str = "k", axes: Axes | None = None, diff --git a/src/scanpy/plotting/_baseplot_class.py b/src/scanpy/plotting/_baseplot_class.py index 6e5c8cd2c5..a60788de43 100644 --- a/src/scanpy/plotting/_baseplot_class.py +++ b/src/scanpy/plotting/_baseplot_class.py @@ -3,10 +3,12 @@ from __future__ import annotations from collections.abc import Mapping -from typing import TYPE_CHECKING, NamedTuple +from dataclasses import KW_ONLY, InitVar, dataclass, field +from typing import TYPE_CHECKING, ClassVar, NamedTuple from warnings import warn import numpy as np +import pandas as pd from legacy_api_wrap import legacy_api from matplotlib import gridspec from matplotlib import pyplot as plt @@ -15,19 +17,25 @@ from .._compat import old_positionals from .._utils import _empty from ._anndata import _get_dendrogram_key, _plot_dendrogram, _prepare_dataframe -from ._utils import check_colornorm, make_grid_spec +from ._utils import ( + ClassDescriptorEnabled, + DefaultProxy, + _dk, + check_colornorm, + make_grid_spec, +) if TYPE_CHECKING: - from collections.abc import Iterable, Sequence + from collections.abc import Iterable, MutableMapping, Sequence from typing import Literal, Self, Union - import pandas as pd from anndata import AnnData from matplotlib.axes import Axes from matplotlib.colors import Colormap, Normalize + from matplotlib.figure import Figure from .._utils import Empty - from ._utils import ColorLike, _AxesSubplot + from ._utils import ColorLike _VarNames = Union[str, Sequence[str]] @@ -59,7 +67,8 @@ class VBoundNorm(NamedTuple): """ -class BasePlot: +@dataclass +class BasePlot(metaclass=ClassDescriptorEnabled): """\ Generic class for the visualization of AnnData categories and selected `var` (features or genes). @@ -74,141 +83,137 @@ class BasePlot: BasePlot(adata, ...).legend(title='legend').style(cmap='binary').show() """ - DEFAULT_SAVE_PREFIX = "baseplot_" - MIN_FIGURE_HEIGHT = 2.5 - DEFAULT_CATEGORY_HEIGHT = 0.35 - DEFAULT_CATEGORY_WIDTH = 0.37 - + DEFAULT_SAVE_PREFIX: ClassVar[str] = "baseplot_" + # maximum number of categories allowed to be plotted + MAX_NUM_CATEGORIES: ClassVar[int] = 500 + + adata: AnnData + var_names: _VarNames | Mapping[str, _VarNames] + groupby: str | Sequence[str] + _: KW_ONLY + use_raw: bool | None = None + log: bool = True + num_categories: int = 5 + categories_order: Sequence[str] | None = None + title: str | None = None + figsize: tuple[float, float] | None = None + gene_symbols: str | None = None + var_group_positions: Sequence[tuple[int, int]] | None = None + var_group_labels: Sequence[str] | None = None + var_group_rotation: float | None = None + layer: str | None = None + ax: Axes | None = None + vmin: float | None = None + vmax: float | None = None + vcenter: float | None = None + norm: Normalize | None = None + + # convenience + dendrogram: InitVar[bool | str | None] = None + with_swapped_axes: InitVar[bool] = False + + # minimum height required for legends to plot properly + min_figure_height: float = 2.5 + category_height = 0.35 + category_width = 0.37 # gridspec parameter. Sets the space between mainplot, dendrogram and legend - DEFAULT_WSPACE = 0 - - DEFAULT_COLORMAP = "winter" - DEFAULT_LEGENDS_WIDTH = 1.5 - DEFAULT_COLOR_LEGEND_TITLE = "Expression\nlevel in group" - - MAX_NUM_CATEGORIES = 500 # maximum number of categories allowed to be plotted - - @old_positionals( - "use_raw", - "log", - "num_categories", - "categories_order", - "title", - "figsize", - "gene_symbols", - "var_group_positions", - "var_group_labels", - "var_group_rotation", - "layer", - "ax", - "vmin", - "vmax", - "vcenter", - "norm", + wspace: float = 0 + cmap: Colormap | str | None = "winter" + legends_width: float = 1.5 + color_legend_title: str = "Expression\nlevel in group" + are_axes_swapped: bool = False + var_names_idx_order: Sequence[int] | None = None + group_extra_size: float = 0.0 + plot_group_extra: dict[str, object] | None = None + # after .render() is called the fig value is assigned and ax_dict + # contains a dictionary of the axes used in the plot + fig: Figure | None = None + ax_dict: dict[str, Axes] | None = None + + kwds: MutableMapping[str, object] = field(default_factory=dict) + + # properties aliasing fields + @property + def has_var_groups(self) -> bool: + return len(self.var_group_positions or ()) > 0 + + @property + def fig_title(self) -> str | None: + return self.title + + @property + def width(self) -> float | None: + return self.figsize[0] if self.figsize is not None else None + + @property + def height(self) -> float | None: + return self.figsize[1] if self.figsize is not None else None + + @property + def vboundnorm(self) -> VBoundNorm: + return VBoundNorm(self.vmin, self.vmax, self.vcenter, self.norm) + + # deprecated class vars + MIN_FIGURE_HEIGHT: ClassVar[DefaultProxy[float]] = DefaultProxy("min_figure_height") + DEFAULT_CATEGORY_HEIGHT: ClassVar[DefaultProxy[float]] = DefaultProxy( + "category_height" + ) + DEFAULT_CATEGORY_WIDTH: ClassVar[DefaultProxy[float]] = DefaultProxy( + "category_width" + ) + DEFAULT_WSPACE: ClassVar[DefaultProxy[float]] = DefaultProxy("wspace") + DEFAULT_COLORMAP: ClassVar[DefaultProxy[Colormap | str | None]] = DefaultProxy( + "cmap" + ) + DEFAULT_LEGENDS_WIDTH: ClassVar[DefaultProxy[float]] = DefaultProxy("legends_width") + DEFAULT_COLOR_LEGEND_TITLE: ClassVar[DefaultProxy[str]] = DefaultProxy( + "color_legend_title" ) - def __init__( - self, - adata: AnnData, - var_names: _VarNames | Mapping[str, _VarNames], - groupby: str | Sequence[str], - *, - use_raw: bool | None = None, - log: bool = False, - num_categories: int = 7, - categories_order: Sequence[str] | None = None, - title: str | None = None, - figsize: tuple[float, float] | None = None, - gene_symbols: str | None = None, - var_group_positions: Sequence[tuple[int, int]] | None = None, - var_group_labels: Sequence[str] | None = None, - var_group_rotation: float | None = None, - layer: str | None = None, - ax: _AxesSubplot | None = None, - vmin: float | None = None, - vmax: float | None = None, - vcenter: float | None = None, - norm: Normalize | None = None, - **kwds, - ): - self.var_names = var_names - self.var_group_labels = var_group_labels - self.var_group_positions = var_group_positions - self.var_group_rotation = var_group_rotation - self.width, self.height = figsize if figsize is not None else (None, None) - - self.has_var_groups = ( - True - if var_group_positions is not None and len(var_group_positions) > 0 - else False - ) + def __post_init__(self, dendrogram: bool | str | None, with_swapped_axes: bool): + cls = type(self) self._update_var_groups() self.categories, self.obs_tidy = _prepare_dataframe( - adata, + self.adata, self.var_names, - groupby, - use_raw=use_raw, - log=log, - num_categories=num_categories, - layer=layer, - gene_symbols=gene_symbols, + self.groupby, + use_raw=self.use_raw, + log=self.log, + num_categories=self.num_categories, + layer=self.layer, + gene_symbols=self.gene_symbols, ) - if len(self.categories) > self.MAX_NUM_CATEGORIES: + if len(self.categories) > cls.MAX_NUM_CATEGORIES: warn( - f"Over {self.MAX_NUM_CATEGORIES} categories found. " + f"Over {cls.MAX_NUM_CATEGORIES} categories found. " "Plot would be very large." ) - if categories_order is not None: - if set(self.obs_tidy.index.categories) != set(categories_order): + if self.categories_order is not None: + assert isinstance(self.obs_tidy.index, pd.CategoricalIndex) + if set(self.obs_tidy.index.categories) != set(self.categories_order): logg.error( "Please check that the categories given by " "the `order` parameter match the categories that " "want to be reordered.\n\n" "Mismatch: " - f"{set(self.obs_tidy.index.categories).difference(categories_order)}\n\n" - f"Given order categories: {categories_order}\n\n" - f"{groupby} categories: {list(self.obs_tidy.index.categories)}\n" + f"{set(self.obs_tidy.index.categories).difference(self.categories_order)}\n\n" + f"Given order categories: {self.categories_order}\n\n" + f"{self.groupby} categories: {list(self.obs_tidy.index.categories)}\n" ) return - self.adata = adata - self.groupby = [groupby] if isinstance(groupby, str) else groupby - self.log = log - self.kwds = kwds - - self.vboundnorm = VBoundNorm(vmin=vmin, vmax=vmax, vcenter=vcenter, norm=norm) + if isinstance(self.groupby, str): + self.groupby = [self.groupby] - # set default values for legend - self.color_legend_title = self.DEFAULT_COLOR_LEGEND_TITLE - self.legends_width = self.DEFAULT_LEGENDS_WIDTH - - # set style defaults - self.cmap = self.DEFAULT_COLORMAP - - # style default parameters - self.are_axes_swapped = False - self.categories_order = categories_order - self.var_names_idx_order = None - - self.wspace = self.DEFAULT_WSPACE - - # minimum height required for legends to plot properly - self.min_figure_height = self.MIN_FIGURE_HEIGHT - - self.fig_title = title - - self.group_extra_size = 0 - self.plot_group_extra = None - # after .render() is called the fig value is assigned and ax_dict - # contains a dictionary of the axes used in the plot - self.fig = None - self.ax_dict = None - self.ax = ax + if dendrogram: + self.add_dendrogram(dendrogram_key=_dk(dendrogram)) + if with_swapped_axes: + self.swap_axes() @legacy_api("swap_axes") - def swap_axes(self, *, swap_axes: bool | None = True) -> Self: + def swap_axes(self, *, swap_axes: bool = True) -> Self: """ Plots a transposed image. @@ -221,17 +226,16 @@ def swap_axes(self, *, swap_axes: bool | None = True) -> Self: swap_axes Boolean to turn on (True) or off (False) 'swap_axes'. Default True - Returns ------- Returns `self` for method chaining. """ - self.DEFAULT_CATEGORY_HEIGHT, self.DEFAULT_CATEGORY_WIDTH = ( - self.DEFAULT_CATEGORY_WIDTH, - self.DEFAULT_CATEGORY_HEIGHT, + # TODO: this doesn’t make much sense + self.category_height, self.category_width = ( + self.category_width, + self.category_height, ) - self.are_axes_swapped = swap_axes return self @@ -241,7 +245,7 @@ def add_dendrogram( *, show: bool | None = True, dendrogram_key: str | None = None, - size: float | None = 0.8, + size: float = 0.8, ) -> Self: r"""\ Show dendrogram based on the hierarchical clustering between the `groupby` @@ -328,7 +332,7 @@ def add_totals( *, show: bool | None = True, sort: Literal["ascending", "descending"] | None = None, - size: float | None = 0.8, + size: float = 0.8, color: ColorLike | Sequence[ColorLike] | None = None, ) -> Self: r"""\ @@ -429,8 +433,8 @@ def legend( self, *, show: bool | None = True, - title: str | None = DEFAULT_COLOR_LEGEND_TITLE, - width: float | None = DEFAULT_LEGENDS_WIDTH, + title: str | Empty = _empty, + width: float | Empty = _empty, ) -> Self: r"""\ Configure legend parameters @@ -468,14 +472,17 @@ def legend( # turn of legends by setting width to 0 self.legends_width = 0 else: - self.color_legend_title = title - self.legends_width = width + if title is not _empty: + self.color_legend_title = title + if width is not _empty: + self.legends_width = width return self def get_axes(self) -> dict[str, Axes]: if self.ax_dict is None: self.make_figure() + assert self.ax_dict is not None return self.ax_dict def _plot_totals( @@ -670,13 +677,10 @@ def make_figure(self): >>> sc.pl.DotPlot(adata, markers, groupby='bulk_labels', ax=ax1).make_figure() """ - category_height = self.DEFAULT_CATEGORY_HEIGHT - category_width = self.DEFAULT_CATEGORY_WIDTH - if self.height is None: - mainplot_height = len(self.categories) * category_height + mainplot_height = len(self.categories) * self.category_height mainplot_width = ( - len(self.var_names) * category_width + self.group_extra_size + len(self.var_names) * self.category_width + self.group_extra_size ) if self.are_axes_swapped: mainplot_height, mainplot_width = mainplot_width, mainplot_height @@ -685,8 +689,10 @@ def make_figure(self): # if the number of categories is small use # a larger height, otherwise the legends do not fit - self.height = max([self.min_figure_height, height]) - self.width = mainplot_width + self.legends_width + self.figsize = ( + mainplot_width + self.legends_width, + max([self.min_figure_height, height]), + ) else: self.min_figure_height = self.height mainplot_height = self.height @@ -710,9 +716,9 @@ def make_figure(self): if self.has_var_groups: # add some space in case 'brackets' want to be plotted on top of the image if self.are_axes_swapped: - var_groups_height = category_height + var_groups_height = self.category_height else: - var_groups_height = category_height / 2 + var_groups_height = self.category_height / 2 else: var_groups_height = 0 @@ -1122,7 +1128,6 @@ def _update_var_groups(self) -> None: self.var_names = _var_names self.var_group_labels = var_group_labels self.var_group_positions = var_group_positions - self.has_var_groups = True elif isinstance(self.var_names, str): self.var_names = [self.var_names] diff --git a/src/scanpy/plotting/_dotplot.py b/src/scanpy/plotting/_dotplot.py index e2ae434db6..051e16a812 100644 --- a/src/scanpy/plotting/_dotplot.py +++ b/src/scanpy/plotting/_dotplot.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from dataclasses import KW_ONLY, dataclass +from typing import TYPE_CHECKING, ClassVar import numpy as np from matplotlib import pyplot as plt @@ -12,7 +13,7 @@ from ._baseplot_class import BasePlot, doc_common_groupby_plot_args from ._docs import doc_common_plot_args, doc_show_save_ax, doc_vboundnorm from ._utils import ( - _dk, + DefaultProxy, check_colornorm, fix_kwds, make_grid_spec, @@ -34,6 +35,7 @@ @_doc_params(common_plot_args=doc_common_plot_args) +@dataclass class DotPlot(BasePlot): """\ Allows the visualization of two values that are encoded as @@ -94,143 +96,116 @@ class DotPlot(BasePlot): """ - DEFAULT_SAVE_PREFIX = "dotplot_" + DEFAULT_SAVE_PREFIX: ClassVar[str] = "dotplot_" + + _: KW_ONLY + categories_order: Sequence[str] | None = None + expression_cutoff: float = 0.0 + mean_only_expressed: bool = False + standard_scale: Literal["var", "group"] | None = None + dot_color_df: pd.DataFrame | None = None + dot_size_df: pd.DataFrame | None = None + # default style parameters - DEFAULT_COLORMAP = "Reds" - DEFAULT_COLOR_ON = "dot" - DEFAULT_DOT_MAX = None - DEFAULT_DOT_MIN = None - DEFAULT_SMALLEST_DOT = 0.0 - DEFAULT_LARGEST_DOT = 200.0 - DEFAULT_DOT_EDGECOLOR = "black" - DEFAULT_DOT_EDGELW = 0.2 - DEFAULT_SIZE_EXPONENT = 1.5 + cmap: Colormap | str | None = "Reds" # override BasePlot default + color_on: Literal["dot", "square"] = "dot" + dot_max: float | None = None + dot_min: float | None = None + smallest_dot: float = 0.0 + largest_dot: float = 200 + dot_edge_color: ColorLike | None = "black" + dot_edge_lw: float | None = 0.2 + size_exponent: float = 1.5 + grid: bool = False + # a unit is the distance between two x-axis ticks + plot_x_padding: float = 0.8 + # a unit is the distance between two y-axis ticks + plot_y_padding: float = 1.0 # default legend parameters - DEFAULT_SIZE_LEGEND_TITLE = "Fraction of cells\nin group (%)" - DEFAULT_COLOR_LEGEND_TITLE = "Mean expression\nin group" - DEFAULT_LEGENDS_WIDTH = 1.5 # inches - DEFAULT_PLOT_X_PADDING = 0.8 # a unit is the distance between two x-axis ticks - DEFAULT_PLOT_Y_PADDING = 1.0 # a unit is the distance between two y-axis ticks - - @old_positionals( - "use_raw", - "log", - "num_categories", - "categories_order", - "title", - "figsize", - "gene_symbols", - "var_group_positions", - "var_group_labels", - "var_group_rotation", - "layer", - "expression_cutoff", - "mean_only_expressed", - "standard_scale", - "dot_color_df", - "dot_size_df", - "ax", - "vmin", - "vmax", - "vcenter", - "norm", + size_title: str = "Fraction of cells\nin group (%)" + color_legend_title: str = "Mean expression\nin group" + legends_width: float = 1.5 # inches + show_size_legend: bool = True + show_colorbar: bool = True + + # deprecated default class variables + DEFAULT_COLOR_ON: ClassVar[DefaultProxy[Literal["dot", "square"]]] = DefaultProxy( + "color_on" ) - def __init__( - self, - adata: AnnData, - var_names: _VarNames | Mapping[str, _VarNames], - groupby: str | Sequence[str], - *, - use_raw: bool | None = None, - log: bool = False, - num_categories: int = 7, - categories_order: Sequence[str] | None = None, - title: str | None = None, - figsize: tuple[float, float] | None = None, - gene_symbols: str | None = None, - var_group_positions: Sequence[tuple[int, int]] | None = None, - var_group_labels: Sequence[str] | None = None, - var_group_rotation: float | None = None, - layer: str | None = None, - expression_cutoff: float = 0.0, - mean_only_expressed: bool = False, - standard_scale: Literal["var", "group"] | None = None, - dot_color_df: pd.DataFrame | None = None, - dot_size_df: pd.DataFrame | None = None, - ax: _AxesSubplot | None = None, - vmin: float | None = None, - vmax: float | None = None, - vcenter: float | None = None, - norm: Normalize | None = None, - **kwds, + DEFAULT_DOT_MAX: ClassVar[DefaultProxy[float | None]] = DefaultProxy("dot_max") + DEFAULT_DOT_MIN: ClassVar[DefaultProxy[float | None]] = DefaultProxy("dot_min") + DEFAULT_SMALLEST_DOT: ClassVar[DefaultProxy[float]] = DefaultProxy("smallest_dot") + DEFAULT_LARGEST_DOT: ClassVar[DefaultProxy[float]] = DefaultProxy("largest_dot") + DEFAULT_DOT_EDGECOLOR: ClassVar[DefaultProxy[ColorLike | None]] = DefaultProxy( + "dot_edge_color" + ) + DEFAULT_DOT_EDGELW: ClassVar[DefaultProxy[float | None]] = DefaultProxy( + "dot_edge_lw" + ) + DEFAULT_SIZE_EXPONENT: ClassVar[DefaultProxy[float]] = DefaultProxy("size_exponent") + DEFAULT_PLOT_X_PADDING: ClassVar[DefaultProxy[float]] = DefaultProxy( + "plot_x_padding" + ) + DEFAULT_PLOT_Y_PADDING: ClassVar[DefaultProxy[float]] = DefaultProxy( + "plot_y_padding" + ) + DEFAULT_SIZE_LEGEND_TITLE: ClassVar[DefaultProxy[str]] = DefaultProxy("size_title") + + def __post_init__( + self, dendrogram: bool | str | None, with_swapped_axes: bool ) -> None: - BasePlot.__init__( - self, - adata, - var_names, - groupby, - use_raw=use_raw, - log=log, - num_categories=num_categories, - categories_order=categories_order, - title=title, - figsize=figsize, - gene_symbols=gene_symbols, - var_group_positions=var_group_positions, - var_group_labels=var_group_labels, - var_group_rotation=var_group_rotation, - layer=layer, - ax=ax, - vmin=vmin, - vmax=vmax, - vcenter=vcenter, - norm=norm, - **kwds, + super().__post_init__( + dendrogram=dendrogram, with_swapped_axes=with_swapped_axes ) - # for if category defined by groupby (if any) compute for each var_name # 1. the fraction of cells in the category having a value >expression_cutoff # 2. the mean value over the category # 1. compute fraction of cells having value > expression_cutoff # transform obs_tidy into boolean matrix using the expression_cutoff - obs_bool = self.obs_tidy > expression_cutoff + obs_bool = self.obs_tidy > self.expression_cutoff # compute the sum per group which in the boolean matrix this is the number # of values >expression_cutoff, and divide the result by the total number of # values in the group (given by `count()`) - if dot_size_df is None: - dot_size_df = ( + if self.dot_size_df is None: + self.dot_size_df = ( obs_bool.groupby(level=0, observed=True).sum() / obs_bool.groupby(level=0, observed=True).count() ) - if dot_color_df is None: + if self.dot_color_df is None: # 2. compute mean expression value value - if mean_only_expressed: - dot_color_df = ( + if self.mean_only_expressed: + self.dot_color_df = ( self.obs_tidy.mask(~obs_bool) .groupby(level=0, observed=True) .mean() .fillna(0) ) else: - dot_color_df = self.obs_tidy.groupby(level=0, observed=True).mean() - - if standard_scale == "group": - dot_color_df = dot_color_df.sub(dot_color_df.min(1), axis=0) - dot_color_df = dot_color_df.div(dot_color_df.max(1), axis=0).fillna(0) - elif standard_scale == "var": - dot_color_df -= dot_color_df.min(0) - dot_color_df = (dot_color_df / dot_color_df.max(0)).fillna(0) - elif standard_scale is None: + self.dot_color_df = self.obs_tidy.groupby(level=0, observed=True).mean() + + if self.standard_scale == "group": + self.dot_color_df = self.dot_color_df.sub( + self.dot_color_df.min(1), axis=0 + ) + self.dot_color_df = self.dot_color_df.div( + self.dot_color_df.max(1), axis=0 + ).fillna(0) + elif self.standard_scale == "var": + self.dot_color_df -= self.dot_color_df.min(0) + self.dot_color_df = ( + self.dot_color_df / self.dot_color_df.max(0) + ).fillna(0) + elif self.standard_scale is None: pass else: logg.warning("Unknown type for standard_scale, ignored") else: # check that both matrices have the same shape - if dot_color_df.shape != dot_size_df.shape: + if self.dot_color_df.shape != self.dot_size_df.shape: logg.error( "the given dot_color_df data frame has a different shape than " "the data frame used for the dot size. Both data frames need " @@ -246,45 +221,26 @@ def __init__( # ['a', 'a', 'a', 'a', 'b'] unique_var_names, unique_idx = np.unique( - dot_color_df.columns, return_index=True + self.dot_color_df.columns, return_index=True ) # remove duplicate columns if len(unique_var_names) != len(self.var_names): - dot_color_df = dot_color_df.iloc[:, unique_idx] + self.dot_color_df = self.dot_color_df.iloc[:, unique_idx] # get the same order for rows and columns in the dot_color_df # using the order from the doc_size_df - dot_color_df = dot_color_df.loc[dot_size_df.index][dot_size_df.columns] + self.dot_color_df = self.dot_color_df.loc[self.dot_size_df.index][ + self.dot_size_df.columns + ] self.dot_color_df, self.dot_size_df = ( df.loc[ - categories_order if categories_order is not None else self.categories + self.categories_order + if self.categories_order is not None + else self.categories ] - for df in (dot_color_df, dot_size_df) + for df in (self.dot_color_df, self.dot_size_df) ) - self.standard_scale = standard_scale - - # Set default style parameters - self.cmap = self.DEFAULT_COLORMAP - self.dot_max = self.DEFAULT_DOT_MAX - self.dot_min = self.DEFAULT_DOT_MIN - self.smallest_dot = self.DEFAULT_SMALLEST_DOT - self.largest_dot = self.DEFAULT_LARGEST_DOT - self.color_on = self.DEFAULT_COLOR_ON - self.size_exponent = self.DEFAULT_SIZE_EXPONENT - self.grid = False - self.plot_x_padding = self.DEFAULT_PLOT_X_PADDING - self.plot_y_padding = self.DEFAULT_PLOT_Y_PADDING - - self.dot_edge_color = self.DEFAULT_DOT_EDGECOLOR - self.dot_edge_lw = self.DEFAULT_DOT_EDGELW - - # set legend defaults - self.color_legend_title = self.DEFAULT_COLOR_LEGEND_TITLE - self.size_title = self.DEFAULT_SIZE_LEGEND_TITLE - self.legends_width = self.DEFAULT_LEGENDS_WIDTH - self.show_size_legend = True - self.show_colorbar = True @old_positionals( "cmap", @@ -429,8 +385,8 @@ def legend( show_size_legend: bool | None = True, show_colorbar: bool | None = True, size_title: str | None = DEFAULT_SIZE_LEGEND_TITLE, - colorbar_title: str | None = DEFAULT_COLOR_LEGEND_TITLE, - width: float | None = DEFAULT_LEGENDS_WIDTH, + colorbar_title: str | Empty = _empty, + width: float | Empty = _empty, ) -> Self: """\ Configures dot size and the colorbar legends @@ -473,9 +429,11 @@ def legend( # turn of legends by setting width to 0 self.legends_width = 0 else: - self.color_legend_title = colorbar_title + if colorbar_title is not _empty: + self.color_legend_title = colorbar_title self.size_title = size_title - self.legends_width = width + if width is not _empty: + self.legends_width = width self.show_size_legend = show_size_legend self.show_colorbar = show_colorbar @@ -589,7 +547,6 @@ def _mainplot(self, ax: Axes): if self.are_axes_swapped: _size_df = _size_df.T _color_df = _color_df.T - self.cmap = self.kwds.pop("cmap", self.cmap) normalize, dot_min, dot_max = self._dotplot( _size_df, @@ -858,16 +815,16 @@ def dotplot( mean_only_expressed: bool = False, standard_scale: Literal["var", "group"] | None = None, title: str | None = None, - colorbar_title: str | None = DotPlot.DEFAULT_COLOR_LEGEND_TITLE, - size_title: str | None = DotPlot.DEFAULT_SIZE_LEGEND_TITLE, + colorbar_title: str | None = DotPlot.color_legend_title, + size_title: str | None = DotPlot.size_title, figsize: tuple[float, float] | None = None, - dendrogram: bool | str = False, + dendrogram: bool | str | None = None, gene_symbols: str | None = None, var_group_positions: Sequence[tuple[int, int]] | None = None, var_group_labels: Sequence[str] | None = None, var_group_rotation: float | None = None, layer: str | None = None, - swap_axes: bool | None = False, + swap_axes: bool = False, dot_color_df: pd.DataFrame | None = None, show: bool | None = None, save: str | bool | None = None, @@ -878,10 +835,10 @@ def dotplot( vcenter: float | None = None, norm: Normalize | None = None, # Style parameters - cmap: Colormap | str | None = DotPlot.DEFAULT_COLORMAP, - dot_max: float | None = DotPlot.DEFAULT_DOT_MAX, - dot_min: float | None = DotPlot.DEFAULT_DOT_MIN, - smallest_dot: float = DotPlot.DEFAULT_SMALLEST_DOT, + cmap: Colormap | str | None = DotPlot.cmap, + dot_max: float | None = DotPlot.dot_max, + dot_min: float | None = DotPlot.dot_min, + smallest_dot: float = DotPlot.smallest_dot, **kwds, ) -> DotPlot | dict | None: """\ @@ -1012,14 +969,11 @@ def dotplot( vmax=vmax, vcenter=vcenter, norm=norm, - **kwds, + dendrogram=dendrogram, + with_swapped_axes=swap_axes, + kwds=kwds, ) - if dendrogram: - dp.add_dendrogram(dendrogram_key=_dk(dendrogram)) - if swap_axes: - dp.swap_axes() - dp = dp.style( cmap=cmap, dot_max=dot_max, diff --git a/src/scanpy/plotting/_matrixplot.py b/src/scanpy/plotting/_matrixplot.py index 9184f2455b..1dff269d3e 100644 --- a/src/scanpy/plotting/_matrixplot.py +++ b/src/scanpy/plotting/_matrixplot.py @@ -1,10 +1,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from dataclasses import KW_ONLY, InitVar, dataclass +from typing import TYPE_CHECKING, ClassVar, cast import numpy as np +import pandas as pd from matplotlib import pyplot as plt -from matplotlib import rcParams from .. import logging as logg from .._compat import old_positionals @@ -16,13 +17,12 @@ doc_show_save_ax, doc_vboundnorm, ) -from ._utils import _dk, check_colornorm, fix_kwds, savefig_or_show +from ._utils import DefaultProxy, check_colornorm, fix_kwds, savefig_or_show if TYPE_CHECKING: from collections.abc import Mapping, Sequence from typing import Literal, Self - import pandas as pd from anndata import AnnData from matplotlib.axes import Axes from matplotlib.colors import Colormap, Normalize @@ -33,6 +33,7 @@ @_doc_params(common_plot_args=doc_common_plot_args) +@dataclass class MatrixPlot(BasePlot): """\ Allows the visualization of values using a color map. @@ -90,112 +91,53 @@ class MatrixPlot(BasePlot): sc.pl.MatrixPlot(adata, markers, groupby='bulk_labels').show() """ - DEFAULT_SAVE_PREFIX = "matrixplot_" - DEFAULT_COLOR_LEGEND_TITLE = "Mean expression\nin group" + DEFAULT_SAVE_PREFIX: ClassVar[str] = "matrixplot_" + _: KW_ONLY + colorbar_title: str = "Mean expression\nin group" # default style parameters - DEFAULT_COLORMAP = rcParams["image.cmap"] - DEFAULT_EDGE_COLOR = "gray" - DEFAULT_EDGE_LW = 0.1 - - @old_positionals( - "use_raw", - "log", - "num_categories", - "categories_order", - "title", - "figsize", - "gene_symbols", - "var_group_positions", - "var_group_labels", - "var_group_rotation", - "layer", - "standard_scale", - "ax", - "values_df", - "vmin", - "vmax", - "vcenter", - "norm", - ) - def __init__( + cmap = None # aka: rcParams["image.cmap"] + values_df: pd.DataFrame | None = None + standard_scale: InitVar[Literal["var", "group"] | None] = None + edge_color: ColorLike | None = "gray" + edge_lw: float | None = 0.1 + + # deprecated default class variables + DEFAULT_EDGE_COLOR: ClassVar[DefaultProxy[ColorLike]] = DefaultProxy("edge_color") + DEFAULT_EDGE_LW: ClassVar[DefaultProxy[float]] = DefaultProxy("edge_lw") + + def __post_init__( self, - adata: AnnData, - var_names: _VarNames | Mapping[str, _VarNames], - groupby: str | Sequence[str], - *, - use_raw: bool | None = None, - log: bool = False, - num_categories: int = 7, - categories_order: Sequence[str] | None = None, - title: str | None = None, - figsize: tuple[float, float] | None = None, - gene_symbols: str | None = None, - var_group_positions: Sequence[tuple[int, int]] | None = None, - var_group_labels: Sequence[str] | None = None, - var_group_rotation: float | None = None, - layer: str | None = None, - standard_scale: Literal["var", "group"] | None = None, - ax: _AxesSubplot | None = None, - values_df: pd.DataFrame | None = None, - vmin: float | None = None, - vmax: float | None = None, - vcenter: float | None = None, - norm: Normalize | None = None, - **kwds, + dendrogram: bool | str | None, + with_swapped_axes: bool, + standard_scale: Literal["var", "group"] | None, ): - BasePlot.__init__( - self, - adata, - var_names, - groupby, - use_raw=use_raw, - log=log, - num_categories=num_categories, - categories_order=categories_order, - title=title, - figsize=figsize, - gene_symbols=gene_symbols, - var_group_positions=var_group_positions, - var_group_labels=var_group_labels, - var_group_rotation=var_group_rotation, - layer=layer, - ax=ax, - vmin=vmin, - vmax=vmax, - vcenter=vcenter, - norm=norm, - **kwds, + super().__post_init__( + dendrogram=dendrogram, with_swapped_axes=with_swapped_axes + ) + if self.values_df is not None: + return + + # compute mean value + self.values_df = cast( + pd.DataFrame, + self.obs_tidy.groupby(level=0, observed=True) + .mean() + .loc[ + self.categories_order + if self.categories_order is not None + else self.categories + ], ) - if values_df is None: - # compute mean value - values_df = ( - self.obs_tidy.groupby(level=0, observed=True) - .mean() - .loc[ - self.categories_order - if self.categories_order is not None - else self.categories - ] - ) - - if standard_scale == "group": - values_df = values_df.sub(values_df.min(1), axis=0) - values_df = values_df.div(values_df.max(1), axis=0).fillna(0) - elif standard_scale == "var": - values_df -= values_df.min(0) - values_df = (values_df / values_df.max(0)).fillna(0) - elif standard_scale is None: - pass - else: - logg.warning("Unknown type for standard_scale, ignored") - - self.values_df = values_df - - self.cmap = self.DEFAULT_COLORMAP - self.edge_color = self.DEFAULT_EDGE_COLOR - self.edge_lw = self.DEFAULT_EDGE_LW + if standard_scale == "group": + self.values_df = self.values_df.sub(self.values_df.min(1), axis=0) + self.values_df = self.values_df.div(self.values_df.max(1), axis=0).fillna(0) + elif standard_scale == "var": + self.values_df -= self.values_df.min(0) + self.values_df = (self.values_df / self.values_df.max(0)).fillna(0) + elif standard_scale is not None: + logg.warning("Unknown type for standard_scale, ignored") def style( self, @@ -347,7 +289,7 @@ def matrixplot( num_categories: int = 7, categories_order: Sequence[str] | None = None, figsize: tuple[float, float] | None = None, - dendrogram: bool | str = False, + dendrogram: bool | str | None = None, title: str | None = None, cmap: Colormap | str | None = MatrixPlot.DEFAULT_COLORMAP, colorbar_title: str | None = MatrixPlot.DEFAULT_COLOR_LEGEND_TITLE, @@ -454,14 +396,11 @@ def matrixplot( vmax=vmax, vcenter=vcenter, norm=norm, - **kwds, + dendrogram=dendrogram, + with_swapped_axes=swap_axes, + kwds=kwds, ) - if dendrogram: - mp.add_dendrogram(dendrogram_key=_dk(dendrogram)) - if swap_axes: - mp.swap_axes() - mp = mp.style(cmap=cmap).legend(title=colorbar_title) if return_fig: return mp diff --git a/src/scanpy/plotting/_stacked_violin.py b/src/scanpy/plotting/_stacked_violin.py index 691dd863d0..18496fa214 100644 --- a/src/scanpy/plotting/_stacked_violin.py +++ b/src/scanpy/plotting/_stacked_violin.py @@ -1,7 +1,8 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING +from dataclasses import KW_ONLY, InitVar, dataclass +from typing import TYPE_CHECKING, ClassVar import numpy as np import pandas as pd @@ -16,8 +17,8 @@ from ._baseplot_class import BasePlot, doc_common_groupby_plot_args from ._docs import doc_common_plot_args, doc_show_save_ax, doc_vboundnorm from ._utils import ( + DefaultProxy, _deprecated_scale, - _dk, check_colornorm, make_grid_spec, savefig_or_show, @@ -37,6 +38,7 @@ @_doc_params(common_plot_args=doc_common_plot_args) +@dataclass class StackedViolin(BasePlot): """\ Stacked violin plots. @@ -102,29 +104,61 @@ class StackedViolin(BasePlot): >>> adata = sc.datasets.pbmc68k_reduced() >>> markers = ['C1QA', 'PSAP', 'CD79A', 'CD79B', 'CST3', 'LYZ'] >>> sc.pl.StackedViolin(adata, markers, groupby='bulk_labels', dendrogram=True) # doctest: +ELLIPSIS - + StackedViolin(adata=AnnData...) Using var_names as dict: >>> markers = {{'T-cell': 'CD3D', 'B-cell': 'CD79A', 'myeloid': 'CST3'}} >>> sc.pl.StackedViolin(adata, markers, groupby='bulk_labels', dendrogram=True) # doctest: +ELLIPSIS - + StackedViolin(adata=AnnData...) """ - DEFAULT_SAVE_PREFIX = "stacked_violin_" - DEFAULT_COLOR_LEGEND_TITLE = "Median expression\nin group" + DEFAULT_SAVE_PREFIX: ClassVar[str] = "stacked_violin_" + + _: KW_ONLY + standard_scale: InitVar[Literal["var", "group"] | None] = None + + # overrides + color_legend_title: str = "Median expression\nin group" + cmap: Colormap | str | None = "Blues" + + # style parameters + row_palette: str | None = None + stripplot: bool = False + jitter: float | bool = False + jitter_size: int | float = 1 + plot_yticklabels: bool = False + ylim: tuple[float, float] | None = None + # a unit is the distance between two x-axis ticks + plot_x_padding: float = 0.5 + # a unit is the distance between two y-axis ticks + plot_y_padding: float = 0.5 + + # deprecated default class variables + DEFAULT_ROW_PALETTE: ClassVar[DefaultProxy[str | None]] = DefaultProxy( + "row_palette" + ) + DEFAULT_STRIPPLOT: ClassVar[DefaultProxy[bool]] = DefaultProxy("stripplot") + DEFAULT_JITTER: ClassVar[DefaultProxy[float | bool]] = DefaultProxy("jitter") + DEFAULT_JITTER_SIZE: ClassVar[DefaultProxy[int | float]] = DefaultProxy( + "jitter_size" + ) + DEFAULT_PLOT_YTICKLABELS: ClassVar[DefaultProxy[bool]] = DefaultProxy( + "plot_yticklabels" + ) + DEFAULT_YLIM: ClassVar[DefaultProxy[tuple[float, float] | None]] = DefaultProxy( + "ylim" + ) + DEFAULT_PLOT_X_PADDING: ClassVar[DefaultProxy[float]] = DefaultProxy( + "plot_x_padding" + ) + DEFAULT_PLOT_Y_PADDING: ClassVar[DefaultProxy[float]] = DefaultProxy( + "plot_y_padding" + ) - DEFAULT_COLORMAP = "Blues" - DEFAULT_STRIPPLOT = False - DEFAULT_JITTER = False - DEFAULT_JITTER_SIZE = 1 + # kwds defaults: TODO: make work with proxys DEFAULT_LINE_WIDTH = 0.2 - DEFAULT_ROW_PALETTE = None DEFAULT_DENSITY_NORM: DensityNorm = "width" - DEFAULT_PLOT_YTICKLABELS = False - DEFAULT_YLIM = None - DEFAULT_PLOT_X_PADDING = 0.5 # a unit is the distance between two x-axis ticks - DEFAULT_PLOT_Y_PADDING = 0.5 # a unit is the distance between two y-axis ticks # set by default the violin plot cut=0 to limit the extend # of the violin plot as this produces better plots that wont extend @@ -143,86 +177,15 @@ class StackedViolin(BasePlot): # None will draw unadorned violins. DEFAULT_INNER = None - def __getattribute__(self, name: str) -> object: - """Called unconditionally when accessing an instance attribute""" - # If the user has set the deprecated version on the class, - # and our code accesses the new version from the instance, - # return the user-specified version instead and warn. - # This is done because class properties are hard to do. - if name == "DEFAULT_DENSITY_NORM" and hasattr(self, "DEFAULT_SCALE"): - msg = "Don’t set DEFAULT_SCALE, use DEFAULT_DENSITY_NORM instead" - warnings.warn(msg, FutureWarning) - return object.__getattribute__(self, "DEFAULT_SCALE") - return object.__getattribute__(self, name) - - @old_positionals( - "use_raw", - "log", - "num_categories", - "categories_order", - "title", - "figsize", - "gene_symbols", - "var_group_positions", - "var_group_labels", - "var_group_rotation", - "layer", - "standard_scale", - "ax", - "vmin", - "vmax", - "vcenter", - "norm", - ) - def __init__( + def __post_init__( self, - adata: AnnData, - var_names: _VarNames | Mapping[str, _VarNames], - groupby: str | Sequence[str], - *, - use_raw: bool | None = None, - log: bool = False, - num_categories: int = 7, - categories_order: Sequence[str] | None = None, - title: str | None = None, - figsize: tuple[float, float] | None = None, - gene_symbols: str | None = None, - var_group_positions: Sequence[tuple[int, int]] | None = None, - var_group_labels: Sequence[str] | None = None, - var_group_rotation: float | None = None, - layer: str | None = None, - standard_scale: Literal["var", "group"] | None = None, - ax: _AxesSubplot | None = None, - vmin: float | None = None, - vmax: float | None = None, - vcenter: float | None = None, - norm: Normalize | None = None, - **kwds, + dendrogram: bool | str | None, + with_swapped_axes: bool, + standard_scale: Literal["var", "group"] | None, ): - BasePlot.__init__( - self, - adata, - var_names, - groupby, - use_raw=use_raw, - log=log, - num_categories=num_categories, - categories_order=categories_order, - title=title, - figsize=figsize, - gene_symbols=gene_symbols, - var_group_positions=var_group_positions, - var_group_labels=var_group_labels, - var_group_rotation=var_group_rotation, - layer=layer, - ax=ax, - vmin=vmin, - vmax=vmax, - vcenter=vcenter, - norm=norm, - **kwds, + super().__post_init__( + dendrogram=dendrogram, with_swapped_axes=with_swapped_axes ) - if standard_scale == "obs": standard_scale = "group" msg = "`standard_scale='obs'` is deprecated, use `standard_scale='group'` instead" @@ -233,22 +196,10 @@ def __init__( elif standard_scale == "var": self.obs_tidy -= self.obs_tidy.min(0) self.obs_tidy = (self.obs_tidy / self.obs_tidy.max(0)).fillna(0) - elif standard_scale is None: - pass - else: + elif standard_scale is not None: logg.warning("Unknown type for standard_scale, ignored") - # Set default style parameters - self.cmap = self.DEFAULT_COLORMAP - self.row_palette = self.DEFAULT_ROW_PALETTE - self.stripplot = self.DEFAULT_STRIPPLOT - self.jitter = self.DEFAULT_JITTER - self.jitter_size = self.DEFAULT_JITTER_SIZE - self.plot_yticklabels = self.DEFAULT_PLOT_YTICKLABELS - self.ylim = self.DEFAULT_YLIM - self.plot_x_padding = self.DEFAULT_PLOT_X_PADDING - self.plot_y_padding = self.DEFAULT_PLOT_Y_PADDING - + self.kwds = dict(self.kwds) self.kwds.setdefault("cut", self.DEFAULT_CUT) self.kwds.setdefault("inner", self.DEFAULT_INNER) self.kwds.setdefault("linewidth", self.DEFAULT_LINE_WIDTH) @@ -676,9 +627,9 @@ def stacked_violin( use_raw: bool | None = None, num_categories: int = 7, title: str | None = None, - colorbar_title: str | None = StackedViolin.DEFAULT_COLOR_LEGEND_TITLE, + colorbar_title: str | None = StackedViolin.color_legend_title, figsize: tuple[float, float] | None = None, - dendrogram: bool | str = False, + dendrogram: bool | str | None = None, gene_symbols: str | None = None, var_group_positions: Sequence[tuple[int, int]] | None = None, var_group_labels: Sequence[str] | None = None, @@ -832,13 +783,11 @@ def stacked_violin( vmax=vmax, vcenter=vcenter, norm=norm, - **kwds, + dendrogram=dendrogram, + with_swapped_axes=swap_axes, + kwds=kwds, ) - if dendrogram: - vp.add_dendrogram(dendrogram_key=_dk(dendrogram)) - if swap_axes: - vp.swap_axes() vp = vp.style( cmap=cmap, stripplot=stripplot, diff --git a/src/scanpy/plotting/_utils.py b/src/scanpy/plotting/_utils.py index ea6aa0cb10..101e075eac 100644 --- a/src/scanpy/plotting/_utils.py +++ b/src/scanpy/plotting/_utils.py @@ -2,7 +2,17 @@ import warnings from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Callable, Literal, TypedDict, Union, overload +from dataclasses import MISSING, Field, dataclass +from typing import ( + TYPE_CHECKING, + Callable, + Generic, + Literal, + TypedDict, + TypeVar, + Union, + overload, +) import matplotlib as mpl import numpy as np @@ -35,6 +45,10 @@ # TODO: more DensityNorm = Literal["area", "count", "width"] + O = TypeVar("O") + + +T = TypeVar("T") # These are needed by _wraps_plot_scatter VBound = Union[str, float, Callable[[Sequence[float]], float]] @@ -66,6 +80,59 @@ class _AxesSubplot(Axes, axes.SubplotBase): """Intersection between Axes and SubplotBase: Has methods of both""" +class ClassDescriptorEnabled(type): + """Metaclass to allow descriptors’ `__set__` to be called when updating a class attribute. + + `DefaultProxy` below relies on that. + """ + + def __setattr__(cls, name: str, value: object) -> None: + desc = cls.__dict__.get(name) + if desc is not None and hasattr(type(desc), "__set__"): + return desc.__set__(None, value) + return super().__setattr__(name, value) + + +@dataclass +class DefaultProxy(Generic[T]): + attr: str + cls: type = object # O, set automatically by __set_name__ + name: str = "" # ditto + + def __set_name__(self, owner: type, name: str) -> None: + self.cls = owner + self.name = name + + def __get__(self, obj: O | None, objtype: type[O] | None = None) -> T: + if objtype is None: + if obj is None: + msg = f"Weird access to {self}" + raise AttributeError(msg) + objtype = type(obj) + + v = getattr(objtype, self.attr) + if isinstance(v, Field): + if v.default is not MISSING: + v = v.default + elif v.default_factory is not MISSING: + v = v.default_factory() + else: + raise AttributeError( + f"Field {self.attr} of class {objtype} has no default value" + ) + return v + + def __set__(self, obj: object | None, value: T) -> None: + if obj is None: # This is enabled by `ClassDescriptorEnabled` above + msg = ( + f"Subclass {self.cls.__name__} or " + f"use `functools.partial` to override {self.attr}." + ) + warnings.warn(msg, FutureWarning) + obj = self.cls + setattr(obj, self.attr, value) + + # ------------------------------------------------------------------------------- # Simple plotting functions # ------------------------------------------------------------------------------- diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 92eb61c252..ae889a2c4d 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -374,7 +374,7 @@ def test_dotplot_style_no_reset(): pbmc = pbmc68k_reduced() plot = sc.pl.dotplot(pbmc, "CD79A", "bulk_labels", return_fig=True) assert isinstance(plot, sc.pl.DotPlot) - assert plot.cmap == sc.pl.DotPlot.DEFAULT_COLORMAP + assert plot.cmap == sc.pl.DotPlot.cmap plot.style(cmap="winter") assert plot.cmap == "winter" plot.style(color_on="square") @@ -1733,10 +1733,3 @@ def test_string_mask(tmp_path, check_same_image): plt.close() check_same_image(p1, p2, tol=1) - - -def test_violin_scale_warning(monkeypatch): - adata = pbmc3k_processed() - monkeypatch.setattr(sc.pl.StackedViolin, "DEFAULT_SCALE", "count", raising=False) - with pytest.warns(FutureWarning, match="Don’t set DEFAULT_SCALE"): - sc.pl.StackedViolin(adata, adata.var_names[:3], groupby="louvain") diff --git a/tests/test_plotting_utils.py b/tests/test_plotting_utils.py index 6b53cd5b50..6e79a26c8d 100644 --- a/tests/test_plotting_utils.py +++ b/tests/test_plotting_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import cast +from dataclasses import dataclass, field +from typing import ClassVar, cast import numpy as np import pytest @@ -8,7 +9,11 @@ from matplotlib import colormaps from matplotlib.colors import ListedColormap -from scanpy.plotting._utils import _validate_palette +from scanpy.plotting._utils import ( + ClassDescriptorEnabled, + DefaultProxy, + _validate_palette, +) viridis = cast(ListedColormap, colormaps["viridis"]) @@ -27,3 +32,53 @@ def test_validate_palette_no_mod(palette, typ): adata = AnnData(uns=dict(test_colors=palette)) _validate_palette(adata, "test") assert palette is adata.uns["test_colors"], "Palette should not be modified" + + +@pytest.mark.parametrize( + "param", + [ + pytest.param(1, id="direct"), + pytest.param(field(default=1), id="default"), + pytest.param( + field(default_factory=lambda: 1), + marks=[ + pytest.mark.xfail( + reason="Tries to call factory while class not fully constructed" + ) + ], + id="default_factory", + ), + ], +) +@pytest.mark.parametrize("set_", ["instance", "field_", "DEFAULT"]) +def test_default_proxy(param, set_: str): + @dataclass + class Test(metaclass=ClassDescriptorEnabled): + field_: int = param + DEFAULT: ClassVar[DefaultProxy[int]] = DefaultProxy("field_") + + instance = Test(2) + assert instance.field_ == 2 + # instantiating doesn’t update the class + assert instance.DEFAULT == Test().field_ == 1 + + instance.field_ = 3 + # updating the instance doesn’t update the class + assert Test.field_ == Test.DEFAULT == 1 + + if set_ == "instance": + v = 1 + elif set_ == "field_": + Test.field_ = v = 4 + elif set_ == "DEFAULT": + with pytest.warns(FutureWarning): + Test.DEFAULT = v = 5 + else: + pytest.fail(f"Unknown {set_=}") + + # updating anything doesn’t update existing instances + assert instance.field_ == 3 + # setting the fields updates the class, but … + assert Test.field_ == Test.DEFAULT == v + # … sadly doesn’t update the __init__ method + assert Test().field_ == 1