diff --git a/prody/dynamics/plotting.py b/prody/dynamics/plotting.py index 3a87ee4dc..70f017491 100644 --- a/prody/dynamics/plotting.py +++ b/prody/dynamics/plotting.py @@ -7,11 +7,12 @@ and keyword arguments are passed to the Matplotlib functions.""" from collections import defaultdict + from numbers import Number import numpy as np from prody import LOGGER, SETTINGS, PY3K -from prody.utilities import showFigure, addEnds, showMatrix +from prody.utilities import showFigure, addEnds, showMatrix, isListLike from prody.atomic import AtomGroup, Selection, Atomic, sliceAtoms, sliceAtomicData from .nma import NMA @@ -218,13 +219,20 @@ def showProjection(ensemble=None, modes=None, projection=None, *args, **kwargs): Default is to use ensemble.getData('size') :type weights: int, list, :class:`~numpy.ndarray` - :keyword color: a color name or a list of color names or values, + :keyword color: a color name or value or a list of length ensemble.numConfs() or projection.shape[0] of these, + or a dictionary with these with keys corresponding to labels provided by keyword label default is ``'blue'`` + Color values can have 1 element to be mapped with cmap or 3 as RGB or 4 as RGBA. + See https://matplotlib.org/stable/users/explain/colors/colors.html#colors-def :type color: str, list :keyword label: label or a list of labels :type label: str, list + :keyword use_labels: whether to use labels for coloring subsets. + These can also be taken from an LDA or LRA model. + :type use_labels: bool + :keyword marker: a marker or a list of markers, default is ``'o'`` :type marker: str, list @@ -278,31 +286,17 @@ def showProjection(ensemble=None, modes=None, projection=None, *args, **kwargs): if labels is None and use_labels and modes is not None: if isinstance(modes, (LDA, LRA)): labels = modes._labels.tolist() - LOGGER.info('using labels from LDA modes') + LOGGER.info('using labels from {0} modes'.format(type(modes))) elif isinstance(modes.getModel(), (LDA, LRA)): labels = modes.getModel()._labels.tolist() - LOGGER.info('using labels from LDA model') + LOGGER.info('using labels from {0} modes'.format(type(modes.getModel()))) if labels is not None and len(labels) != num: raise ValueError('label should have the same length as ensemble') c = kwargs.pop('c', 'b') colors = kwargs.pop('color', c) - colors_dict = {} - if isinstance(colors, np.ndarray): - colors = tuple(colors) - if isinstance(colors, (str, tuple)) or colors is None: - colors = [colors] * num - elif isinstance(colors, list): - if len(colors) != num: - raise ValueError('length of color must be {0}'.format(num)) - elif isinstance(colors, dict): - if labels is None: - raise TypeError('color must be a string or a list unless labels are provided') - colors_dict = colors - colors = [colors_dict[label] for label in labels] - else: - raise TypeError('color must be a string or a list or a dict if labels are provided') + colors, colors_dict = checkColors(colors, num, labels, allowNumbers=True) if labels is not None and len(colors_dict) == 0: cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] @@ -507,7 +501,8 @@ def showCrossProjection(ensemble, mode_x, mode_y, scale=None, *args, **kwargs): :keyword scalar: scalar factor for projection onto selected mode :type scalar: float - :keyword color: a color name or a list of color name, default is ``'blue'`` + :keyword color: a color spec or a list of color specs, default is ``'blue'`` + See https://matplotlib.org/stable/users/explain/colors/colors.html#colors-def :type color: str, list :keyword label: label or a list of labels @@ -556,13 +551,6 @@ def showCrossProjection(ensemble, mode_x, mode_y, scale=None, *args, **kwargs): raise TypeError('marker must be a string or a list') colors = kwargs.pop('color', 'blue') - if isinstance(colors, str) or colors is None: - colors = [colors] * num - elif isinstance(colors, list): - if len(colors) != num: - raise ValueError('length of color must be {0}'.format(num)) - else: - raise TypeError('color must be a string or a list') labels = kwargs.pop('label', None) if isinstance(labels, str) or labels is None: @@ -575,21 +563,7 @@ def showCrossProjection(ensemble, mode_x, mode_y, scale=None, *args, **kwargs): kwargs['ls'] = kwargs.pop('linestyle', None) or kwargs.pop('ls', 'None') - colors_dict = {} - if isinstance(colors, np.ndarray): - colors = tuple(colors) - if isinstance(colors, (str, tuple)) or colors is None: - colors = [colors] * num - elif isinstance(colors, list): - if len(colors) != num: - raise ValueError('length of color must be {0}'.format(num)) - elif isinstance(colors, dict): - if labels is None: - raise TypeError('color must be a string or a list unless labels are provided') - colors_dict = colors - colors = [colors_dict[label] for label in labels] - else: - raise TypeError('color must be a string or a list or a dict if labels are provided') + colors, colors_dict = checkColors(colors, num, labels) if labels is not None and len(colors_dict) == 0: cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] @@ -2381,3 +2355,39 @@ def showTree_networkx(tree, node_size=20, node_color='red', node_shape='o', showFigure() return mpl.gca() + + +def checkColors(colors, num, labels, allowNumbers=False): + """Check colors and process them if needed""" + + from matplotlib.colors import is_color_like + + colors_dict = {} + + if is_color_like(colors) or colors is None or isinstance(colors, Number): + colors = [colors] * num + elif isListLike(colors): + colors = list(colors) + elif isinstance(colors, dict): + if labels is None: + raise TypeError('color cannot be a dict unless labels are provided') + colors_dict = colors + colors = [colors_dict[label] for label in labels] + + if isinstance(colors, list): + if len(colors) != num: + raise ValueError('colors should have the length of the set to be colored or satisfy matplotlib color rules') + + for color in colors: + if not is_color_like(color) and color is not None: + if not allowNumbers: + raise ValueError('each element of colors should satisfy matplotlib color rules') + elif not isinstance(color, Number): + raise ValueError('each element of colors should be a number or satisfy matplotlib color rules') + + if not isinstance(color, type(colors[0])): + raise TypeError('each element of colors should have the same type') + else: + raise TypeError('colors should be a color spec or convertible to a list of color specs') + + return colors, colors_dict