diff --git a/prody/dynamics/plotting.py b/prody/dynamics/plotting.py index 22a02149c..206d3e173 100644 --- a/prody/dynamics/plotting.py +++ b/prody/dynamics/plotting.py @@ -260,6 +260,8 @@ def showProjection(ensemble, modes, *args, **kwargs): use_weights = kwargs.pop('use_weights', False) weights = kwargs.pop('weights', ensemble.getData('size')) + num = projection.shape[0] + use_labels = kwargs.pop('use_labels', True) labels = kwargs.pop('label', None) if labels is None: @@ -271,11 +273,12 @@ def showProjection(ensemble, modes, *args, **kwargs): labels = modes.getModel()._labels.tolist() LOGGER.info('using labels from LDA model') - if labels is not None and len(labels) != projection.shape[0]: + if labels is not None and len(labels) != num: raise ValueError('label should have the same length as ensemble') c = kwargs.pop('c', 'b') colors = kwargs.pop('color', c) + colors_dict = {} if isinstance(colors, np.ndarray): colors = tuple(colors) if isinstance(colors, (str, tuple)) or colors is None: @@ -291,6 +294,11 @@ def showProjection(ensemble, modes, *args, **kwargs): else: raise TypeError('color must be a string or a list or a dict if labels are provided') + if labels is not None and len(colors_dict) == 0: + cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] + for i, label in enumerate(set(labels)): + colors_dict[label] = cycle_colors[i % len(cycle_colors)] + if projection.ndim == 1 or projection.shape[1] == 1: by_time = not kwargs.pop('show_density', True) by_time = kwargs.pop('by_time', by_time) @@ -329,8 +337,6 @@ def showProjection(ensemble, modes, *args, **kwargs): raise ValueError('Projection onto up to 3 modes can be shown. ' 'You have given {0} mode.'.format(len(modes))) - num = projection.shape[0] - markers = kwargs.pop('marker', 'o') if isinstance(markers, str) or markers is None: markers = [markers] * num