Skip to content

Commit

Permalink
enh : implement a new feature to perform hierarchical clustering on t…
Browse files Browse the repository at this point in the history
…he correlation matrix before plotting it
  • Loading branch information
celprov committed Mar 24, 2023
1 parent f133a87 commit a373e37
Showing 1 changed file with 30 additions and 2 deletions.
32 changes: 30 additions & 2 deletions mriqc_learn/viz/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def plot_corrmat(
cbarlabel="",
symmetric=True,
figsize=None,
sort=False,
**kwargs,
):
"""
Expand All @@ -204,14 +205,40 @@ def plot_corrmat(
A dictionary with arguments to `matplotlib.Figure.colorbar`. Optional.
cbarlabel
The label for the colorbar. Optional.
sort
Flag to perform hierachical clustering on the correlation plot
**kwargs
All other arguments are forwarded to `imshow`.
"""
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Cluster rows (if arguments enabled)
if sort:
import pandas as pd
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster

Z = linkage(data, 'complete', optimal_ordering=True)

dendrogram(Z, labels=data.columns, no_plot=True)

# Clusterize the data
threshold = 0.1
labels = fcluster(Z, threshold, criterion='distance')
# Keep the indices to sort labels
labels_order = np.argsort(labels)

# Build a new dataframe with the sorted columns
for idx, i in enumerate(data.columns[labels_order]):
if idx == 0:
clustered = pd.DataFrame(data[i])
else:
df_to_append = pd.DataFrame(data[i])
clustered = pd.concat([clustered, df_to_append], axis=1)
data = clustered

if hasattr(data, "columns"):
col_labels = data.columns.tolist()
col_labels = data.columns
data = data.values

if figsize is not None:
Expand All @@ -220,6 +247,7 @@ def plot_corrmat(
if not ax:
ax = plt.gca()

# If matrix is symmetric, keep only lower triangle
if symmetric:
data[np.triu(np.ones(data.shape, dtype=bool))] = np.nan

Expand Down Expand Up @@ -252,7 +280,7 @@ def plot_corrmat(
ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)

# Rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=90, ha="right", rotation_mode="anchor")
plt.setp(ax.get_xticklabels(), rotation=90, ha="right", va="center", rotation_mode="anchor")

# Turn spines off and create white grid.
ax.spines[:].set_visible(False)
Expand Down

0 comments on commit a373e37

Please sign in to comment.