Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

restore one label and markersize to showProjection #2007

Merged
merged 2 commits into from
Nov 30, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions prody/dynamics/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,8 @@ def showProjection(ensemble=None, modes=None, projection=None, *args, **kwargs):
weights = kwargs.pop('weights', None)
weights = None

markersize = kwargs.pop('markersize', None)

num = projection.shape[0]

use_labels = kwargs.pop('use_labels', True)
Expand All @@ -283,8 +285,14 @@ def showProjection(ensemble=None, modes=None, projection=None, *args, **kwargs):
labels = modes.getModel()._labels.tolist()
LOGGER.info('using labels from LDA model')

if labels is not None and len(labels) != num:
raise ValueError('label should have the same length as ensemble')
one_label = False
if labels is not None:
if len(labels) == 1 or np.isscalar(labels):
one_label = True
kwargs['label'] = labels

elif len(labels) != num:
raise ValueError('label should have the same length as ensemble')

c = kwargs.pop('c', 'b')
colors = kwargs.pop('color', c)
Expand All @@ -297,14 +305,14 @@ def showProjection(ensemble=None, modes=None, projection=None, *args, **kwargs):
if len(colors) != num:
raise ValueError('length of color must be {0}'.format(num))
elif isinstance(colors, dict):
if labels is None:
if labels is None or one_label:
raise TypeError('color must be a string or a list unless labels are provided')
colors_dict = colors
colors = [colors_dict[label] for label in labels]
else:
raise TypeError('color must be a string or a list or a dict if labels are provided')

if labels is not None and len(colors_dict) == 0:
if labels is not None and not one_label and len(colors_dict) == 0:
cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
for i, label in enumerate(set(labels)):
colors_dict[label] = cycle_colors[i % len(cycle_colors)]
Expand All @@ -318,6 +326,8 @@ def showProjection(ensemble=None, modes=None, projection=None, *args, **kwargs):
show = plt.plot(range(len(projection)), projection.flatten(), *args, **kwargs)
if use_weights:
kwargs['s'] = weights
elif markersize is not None:
kwargs['s'] = markersize
if labels is not None and use_labels:
for label in set(labels):
kwargs['c'] = colors_dict[label]
Expand Down Expand Up @@ -444,6 +454,8 @@ def showProjection(ensemble=None, modes=None, projection=None, *args, **kwargs):
kwargs['c'] = color
if weights is not None and use_weights:
kwargs['s'] = weights
elif markersize is not None:
kwargs['s'] = markersize
plot(*(list(projection[indices].T) + args), **kwargs)
else:
kwargs['color'] = color
Expand Down
Loading