Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clustering feature in napari deeplabcut #38

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ Suggested workflows, depending on the image folder contents:
and start drawing rectangles over the images. Masks and rectangle vertices are saved as described in [Save Layers](#save-layers).
Note that masks can be reloaded and edited at a later stage by dropping the `vertices.csv` file onto the canvas.

5. **Detect Outliers to Refine Labels**
Open napari as described in [Usage](#usage) and open the `CollectedData_<ScorerName>.h5` file. Click on the button cluster and wait a few seconds. It will show a new layer with the cluster. You can click on a point and see the image on the right with the keypoints. If you decided to refine that frame, click show img and refine them. You can go back to the cluster layer by clicking on close img and refine another image. When you're done, you need to do ctl s to save it. And now you can retrain the network!


### Workflow flowchart

```mermaid
Expand Down
1 change: 1 addition & 0 deletions src/napari_deeplabcut/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
get_config_reader,
)
from ._writer import write_hdf, write_masks

3 changes: 2 additions & 1 deletion src/napari_deeplabcut/_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,12 @@ def read_hdf(filename: str) -> List[LayerData]:
if isinstance(temp.index, pd.MultiIndex):
temp.index = [os.path.join(*row) for row in temp.index]
df = (
temp.stack(["individuals", "bodyparts"])
temp.stack(["individuals", "bodyparts"])#, dropna=False)
.reindex(header.individuals, level="individuals")
.reindex(header.bodyparts, level="bodyparts")
.reset_index()
)
#df.fillna(0, inplace=True)
nrows = df.shape[0]
data = np.empty((nrows, 3))
image_paths = df["level_0"]
Expand Down
146 changes: 145 additions & 1 deletion src/napari_deeplabcut/_widgets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
import os
from collections import defaultdict
from functools import partial
import numpy as np
import pandas as pd
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from types import MethodType
from typing import Optional, Sequence, Union
from napari.layers import Image, Points
from collections import defaultdict, namedtuple
from copy import deepcopy
from datetime import datetime
Expand Down Expand Up @@ -45,9 +54,38 @@
QStyleOption,
QVBoxLayout,
QWidget,
QPushButton,
)

from napari_deeplabcut.kmeans import cluster_data
from napari_deeplabcut import keypoints
from napari_deeplabcut.misc import to_os_dir_sep, find_project_name


class Worker(QtCore.QObject):
started = QtCore.Signal()
finished = QtCore.Signal()
value = QtCore.Signal(object)

def __init__(self, func):
super().__init__()
self.func = func

def run(self):
out = self.func()
self.value.emit(out)
self.finished.emit()


def move_to_separate_thread(func):
thread = QtCore.QThread()
worker = Worker(func)
worker.moveToThread(thread)
thread.started.connect(worker.run)
worker.finished.connect(thread.quit)
worker.finished.connect(worker.deleteLater)
worker.finished.connect(thread.deleteLater)
return worker, thread
from napari_deeplabcut._reader import _load_config
from napari_deeplabcut._writer import _write_config, _write_image, _form_df
from napari_deeplabcut.misc import (
Expand Down Expand Up @@ -538,7 +576,6 @@ def __init__(self, napari_viewer):
self.viewer = napari_viewer
self.viewer.layers.events.inserted.connect(self.on_insert)
self.viewer.layers.events.removed.connect(self.on_remove)

self.viewer.window.qt_viewer._get_and_try_preferred_reader = MethodType(
_get_and_try_preferred_reader,
self.viewer.window.qt_viewer,
Expand Down Expand Up @@ -779,6 +816,108 @@ def _store_crop_coordinates(self, *args):
_write_config(config_path, cfg)
break

self.add_clustering_buttons()

# Initialize an empty canvas onto which to draw the images
self.fig = Figure(tight_layout=True, dpi=100)
self.fig.patch.set_facecolor("None")
self.ax = self.fig.add_subplot(111)
self.ax.invert_yaxis()
self.ax.set_axis_off()
self._im = None
self._scatter = self.ax.scatter([], [])
self.canvas = FigureCanvas(self.fig)

self.show()

def add_clustering_buttons(self):
layout = QHBoxLayout()
btn_cluster = QPushButton('cluster pose', self)
btn_cluster.clicked.connect(self.on_click)
btn_show = QPushButton('show img', self)
btn_show.clicked.connect(self.on_click_show_img)
btn_close = QPushButton('close img', self)
btn_close.clicked.connect(self.on_click_close_img)
layout.addWidget(btn_cluster)
layout.addWidget(btn_show)
layout.addWidget(btn_close)
self._layout.addLayout(layout)

def _show_clusters(self, input_):
points, names = input_
points[:, [0, 1]] = points[:, [1, 0]]
colors = points[:, 2] + 1

dict_prop_points = {'colorn': colors, 'frame': names}
clust_layer = self.viewer.add_points(
points[:, :2],
size=8,
name='cluster',
features=dict_prop_points,
face_color='colorn',
face_colormap='plasma',
)
clust_layer.mode = 'select'

self.viewer.window.add_dock_widget(self.canvas, name='frames')
self.viewer.layers[0].visible = False

self._df = pd.read_hdf(self.viewer.layers[0].source.path)
self._df.index = ['/'.join(row) for row in list(self._df.index)]

root = self.viewer.layers[0].metadata['root']
filenames = list(self.viewer.layers[0].metadata['paths'])
project_name = find_project_name(root)
project_path = os.path.join(root.split(project_name)[0], project_name)

@clust_layer.mouse_drag_callbacks.append
def get_event(clust_layer, event):
inds = list(clust_layer.selected_data)
if not inds:
return

if len(inds) > 1:
self.viewer.status = 'Please select only one data point.'
return

ind = inds[0]
filename = clust_layer.properties['frame'][ind]
bpts = self._df.loc[filename].to_numpy().reshape((-1, 2))
self.step = filenames.index(filename)

with Image_.open(os.path.join(project_path, filename)) as img:
im = np.asarray(img)
if self._im is None:
self._im = self.ax.imshow(im)
else:
self._im.set_data(im)
self._scatter.set_offsets(bpts)
self.canvas.draw()

def on_click(self):
layer = self.viewer.layers.selection.active
if not isinstance(layer, Points):
print("Only Points layers can be clustered.")
return

func = partial(cluster_data, layer)
self.worker, self.thread = move_to_separate_thread(func)
self.worker.value.connect(self._show_clusters)
self.thread.start()

def on_click_show_img(self):
self.viewer.layers[0].visible = True
self.viewer.layers[1].visible = False
self.viewer.dims.set_current_step(0, self.step)
self.viewer.add_image(self._im.get_array(), name='image refine label')
self.viewer.layers.move_selected(0, 2)

def on_click_close_img(self):
self.viewer.layers.remove('image refine label')
self.viewer.layers.move_selected(0, 1)
self.viewer.layers[0].visible = False
self.viewer.layers[1].visible = True

def _form_dropdown_menus(self, store):
menu = KeypointsDropdownMenu(store)
self.viewer.dims.events.current_step.connect(
Expand Down Expand Up @@ -871,6 +1010,10 @@ def _remap_frame_indices(self, layer):

def on_insert(self, event):
layer = event.source[-1]
# FIXME Is the following necessary?
if any(s in str(layer) for s in ('cluster', 'refine')):
return

if isinstance(layer, Image):
paths = layer.metadata.get("paths")
if paths is None: # Then it's a video file
Expand Down Expand Up @@ -1091,6 +1234,7 @@ def __init__(
):
super().__init__(parent)
self.store = store

self.store.layer.events.current_properties.connect(self.update_menus)
self._locked = False

Expand Down
2 changes: 1 addition & 1 deletion src/napari_deeplabcut/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,4 @@ def write_masks(foldername, data, metadata):
output_path = filename.format(os.path.splitext(image_name)[0], shape_inds[n])
_write_image(mask, output_path)
napari_write_shapes(os.path.join(folder, "vertices.csv"), data, metadata)
return folder
return folder
45 changes: 45 additions & 0 deletions src/napari_deeplabcut/kmeans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist
from sklearn.cluster import DBSCAN
from sklearn.decomposition import PCA
from napari_deeplabcut._writer import _conv_layer_to_df
from napari_deeplabcut.misc import DLCHeader


def _cluster(data):
pca = PCA(n_components=2)
principalComponents = pca.fit_transform(data)

# putting components in a dataframe for later
PCA_components = pd.DataFrame(principalComponents)

dbscan=DBSCAN(eps=9.7, min_samples=20, algorithm='ball_tree', metric='minkowski', leaf_size=90, p=2)

# fit - perform DBSCAN clustering from features, or distance matrix.
dbscan = dbscan.fit(PCA_components)
cluster1 = dbscan.labels_

return PCA_components, cluster1


def cluster_data(points_layer):
df = _conv_layer_to_df(
points_layer.data, points_layer.metadata, points_layer.properties
)
try:
df = df.drop('single', axis=1, level='individuals')
except KeyError:
pass
df.dropna(inplace=True)
header = DLCHeader(df.columns)
try:
df = df.stack('individuals').droplevel('individuals')
except KeyError:
pass
df.index = ['/'.join(row) for row in df.index]
xy = df.to_numpy().reshape((-1, len(header.bodyparts), 2))
# TODO Normalize dists by longest length?
dists = np.vstack([pdist(data, "euclidean") for data in xy])
points = np.c_[_cluster(dists)] # x, y, label
return points, list(df.index)
9 changes: 9 additions & 0 deletions src/napari_deeplabcut/misc.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
from __future__ import annotations

import os
import re
from enum import Enum, EnumMeta
from itertools import cycle
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
from napari.utils import colormaps


def find_project_name(s):
pat = re.compile('.+-.+-\d{4}-\d{1,2}-\d{1,2}')
for part in Path(s).parts[::-1]:
if pat.search(part):
return part


def unsorted_unique(array: Sequence) -> np.ndarray:
"""Return the unsorted unique elements of an array."""
_, inds = np.unique(array, return_index=True)
Expand Down
9 changes: 9 additions & 0 deletions src/napari_deeplabcut/napari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ contributions:
- id: napari-deeplabcut.make_keypoint_controls
python_name: napari_deeplabcut._widgets:KeypointControls
title: Make keypoint controls

readers:
- command: napari-deeplabcut.get_hdf_reader
accepts_directories: false
Expand All @@ -42,6 +43,7 @@ contributions:
- command: napari-deeplabcut.get_folder_parser
accepts_directories: true
filename_patterns: ['*']

writers:
- command: napari-deeplabcut.write_hdf
layer_types: ["points{1}"]
Expand All @@ -52,3 +54,10 @@ contributions:
widgets:
- command: napari-deeplabcut.make_keypoint_controls
display_name: Keypoint controls
kmeans:
- command: napari-deeplabcut.get_hdf_reader1
accepts_directories: false
filename_patterns: ['*.h5']
- command: napari-deeplabcut.get_folder_parser1
accepts_directories: true
filename_patterns: ['*']