-
Notifications
You must be signed in to change notification settings - Fork 158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
bug fix on color handling for showProjection #1070
base: main
Are you sure you want to change the base?
Changes from all commits
492c3ad
8325bb4
f1f296f
784aa3f
06842d3
06980e7
9e5eec4
6708a3b
5bac96c
85e7c5f
a9ccb56
43bc8e3
6e8af98
ff6384a
dc10bca
fdfb2ae
3453d21
8a49ba4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,11 +7,12 @@ | |
and keyword arguments are passed to the Matplotlib functions.""" | ||
|
||
from collections import defaultdict | ||
|
||
from numbers import Number | ||
import numpy as np | ||
|
||
from prody import LOGGER, SETTINGS, PY3K | ||
from prody.utilities import showFigure, addEnds, showMatrix | ||
from prody.utilities import showFigure, addEnds, showMatrix, isListLike | ||
from prody.atomic import AtomGroup, Selection, Atomic, sliceAtoms, sliceAtomicData | ||
|
||
from .nma import NMA | ||
|
@@ -218,13 +219,20 @@ def showProjection(ensemble=None, modes=None, projection=None, *args, **kwargs): | |
Default is to use ensemble.getData('size') | ||
:type weights: int, list, :class:`~numpy.ndarray` | ||
|
||
:keyword color: a color name or a list of color names or values, | ||
:keyword color: a color name or value or a list of length ensemble.numConfs() or projection.shape[0] of these, | ||
or a dictionary with these with keys corresponding to labels provided by keyword label | ||
default is ``'blue'`` | ||
Color values can have 1 element to be mapped with cmap or 3 as RGB or 4 as RGBA. | ||
See https://matplotlib.org/stable/users/explain/colors/colors.html#colors-def | ||
:type color: str, list | ||
|
||
:keyword label: label or a list of labels | ||
:type label: str, list | ||
|
||
:keyword use_labels: whether to use labels for coloring subsets. | ||
These can also be taken from an LDA or LRA model. | ||
:type use_labels: bool | ||
|
||
:keyword marker: a marker or a list of markers, default is ``'o'`` | ||
:type marker: str, list | ||
|
||
|
@@ -278,31 +286,17 @@ def showProjection(ensemble=None, modes=None, projection=None, *args, **kwargs): | |
if labels is None and use_labels and modes is not None: | ||
if isinstance(modes, (LDA, LRA)): | ||
labels = modes._labels.tolist() | ||
LOGGER.info('using labels from LDA modes') | ||
LOGGER.info('using labels from {0} modes'.format(type(modes))) | ||
elif isinstance(modes.getModel(), (LDA, LRA)): | ||
labels = modes.getModel()._labels.tolist() | ||
LOGGER.info('using labels from LDA model') | ||
LOGGER.info('using labels from {0} modes'.format(type(modes.getModel()))) | ||
|
||
if labels is not None and len(labels) != num: | ||
raise ValueError('label should have the same length as ensemble') | ||
|
||
c = kwargs.pop('c', 'b') | ||
colors = kwargs.pop('color', c) | ||
colors_dict = {} | ||
if isinstance(colors, np.ndarray): | ||
colors = tuple(colors) | ||
if isinstance(colors, (str, tuple)) or colors is None: | ||
colors = [colors] * num | ||
elif isinstance(colors, list): | ||
if len(colors) != num: | ||
raise ValueError('length of color must be {0}'.format(num)) | ||
elif isinstance(colors, dict): | ||
if labels is None: | ||
raise TypeError('color must be a string or a list unless labels are provided') | ||
colors_dict = colors | ||
colors = [colors_dict[label] for label in labels] | ||
else: | ||
raise TypeError('color must be a string or a list or a dict if labels are provided') | ||
colors, colors_dict = checkColors(colors, num, labels, allowNumbers=True) | ||
|
||
if labels is not None and len(colors_dict) == 0: | ||
cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] | ||
|
@@ -507,7 +501,8 @@ def showCrossProjection(ensemble, mode_x, mode_y, scale=None, *args, **kwargs): | |
:keyword scalar: scalar factor for projection onto selected mode | ||
:type scalar: float | ||
|
||
:keyword color: a color name or a list of color name, default is ``'blue'`` | ||
:keyword color: a color spec or a list of color specs, default is ``'blue'`` | ||
See https://matplotlib.org/stable/users/explain/colors/colors.html#colors-def | ||
:type color: str, list | ||
|
||
:keyword label: label or a list of labels | ||
|
@@ -556,13 +551,6 @@ def showCrossProjection(ensemble, mode_x, mode_y, scale=None, *args, **kwargs): | |
raise TypeError('marker must be a string or a list') | ||
|
||
colors = kwargs.pop('color', 'blue') | ||
if isinstance(colors, str) or colors is None: | ||
colors = [colors] * num | ||
elif isinstance(colors, list): | ||
if len(colors) != num: | ||
raise ValueError('length of color must be {0}'.format(num)) | ||
else: | ||
raise TypeError('color must be a string or a list') | ||
|
||
labels = kwargs.pop('label', None) | ||
if isinstance(labels, str) or labels is None: | ||
|
@@ -575,21 +563,7 @@ def showCrossProjection(ensemble, mode_x, mode_y, scale=None, *args, **kwargs): | |
|
||
kwargs['ls'] = kwargs.pop('linestyle', None) or kwargs.pop('ls', 'None') | ||
|
||
colors_dict = {} | ||
if isinstance(colors, np.ndarray): | ||
colors = tuple(colors) | ||
if isinstance(colors, (str, tuple)) or colors is None: | ||
colors = [colors] * num | ||
elif isinstance(colors, list): | ||
if len(colors) != num: | ||
raise ValueError('length of color must be {0}'.format(num)) | ||
elif isinstance(colors, dict): | ||
if labels is None: | ||
raise TypeError('color must be a string or a list unless labels are provided') | ||
colors_dict = colors | ||
colors = [colors_dict[label] for label in labels] | ||
else: | ||
raise TypeError('color must be a string or a list or a dict if labels are provided') | ||
colors, colors_dict = checkColors(colors, num, labels) | ||
|
||
if labels is not None and len(colors_dict) == 0: | ||
cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] | ||
|
@@ -2381,3 +2355,39 @@ def showTree_networkx(tree, node_size=20, node_color='red', node_shape='o', | |
showFigure() | ||
|
||
return mpl.gca() | ||
|
||
|
||
def checkColors(colors, num, labels, allowNumbers=False): | ||
"""Check colors and process them if needed""" | ||
|
||
from matplotlib.colors import is_color_like | ||
|
||
colors_dict = {} | ||
|
||
if is_color_like(colors) or colors is None or isinstance(colors, Number): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, it would fail so I added it down there too, because we do want to allow it |
||
colors = [colors] * num | ||
elif isListLike(colors): | ||
colors = list(colors) | ||
elif isinstance(colors, dict): | ||
if labels is None: | ||
raise TypeError('color cannot be a dict unless labels are provided') | ||
colors_dict = colors | ||
colors = [colors_dict[label] for label in labels] | ||
|
||
if isinstance(colors, list): | ||
if len(colors) != num: | ||
raise ValueError('colors should have the length of the set to be colored or satisfy matplotlib color rules') | ||
|
||
for color in colors: | ||
if not is_color_like(color) and color is not None: | ||
if not allowNumbers: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think if
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So the entire loop may look like this:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think there's a conversion of the numbers with the color cycle somewhere else, but yes, we could maybe move it here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, we use matplotlib.colors.Normalize instead to link up with the cmap. This is on lines 356 and 428 This could somehow be incorporated too though |
||
raise ValueError('each element of colors should satisfy matplotlib color rules') | ||
elif not isinstance(color, Number): | ||
raise ValueError('each element of colors should be a number or satisfy matplotlib color rules') | ||
|
||
if not isinstance(color, type(colors[0])): | ||
raise TypeError('each element of colors should have the same type') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this check necessary? Couldn't the matplotlib function handle colors defined in different ways in the same list? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh, I guess it probably could |
||
else: | ||
raise TypeError('colors should be a color spec or convertible to a list of color specs') | ||
|
||
return colors, colors_dict |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this section can go away and
checkColors
doesn't need to returncolor_dict
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a case where we need colors_dict on line 317 where we make a line graph. We'll have to find a way to adjust that.