-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #69 from koaning/notebookui
v1 with notebook support
- Loading branch information
Showing
8 changed files
with
391 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,263 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"id": "a17c690c-5a86-43ec-88f9-73e493ca5aa2", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# %pip install sentence-transformers umap-learn embetter" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "2184ee7a-18a5-4edd-934f-6d3c16fe9c09", | ||
"metadata": {}, | ||
"source": [ | ||
"# Intro to bulk from a notebook\n", | ||
"\n", | ||
"In an attempt to come to a quick demo, we ran some code beforehand that does some encoding. \n", | ||
"\n", | ||
"<details>\n", | ||
" <summary><b>Show me the code.</b></summary>\n", | ||
"\n", | ||
"```python \n", | ||
"import pandas as pd\n", | ||
"from umap import UMAP\n", | ||
"from sklearn.pipeline import make_pipeline \n", | ||
"\n", | ||
"# pip install \"embetter[text]\"\n", | ||
"from embetter.text import SentenceEncoder\n", | ||
"\n", | ||
"# Build a sentence encoder pipeline with UMAP at the end.\n", | ||
"enc = SentenceEncoder('all-MiniLM-L6-v2')\n", | ||
"umap = UMAP()\n", | ||
"\n", | ||
"text_emb_pipeline = make_pipeline(\n", | ||
" enc, umap\n", | ||
")\n", | ||
"\n", | ||
"# Load sentences\n", | ||
"sentences = list(pd.read_csv(\"tests/data/text.csv\")['text'])\n", | ||
"\n", | ||
"# Calculate embeddings \n", | ||
"X_tfm = text_emb_pipeline.fit_transform(sentences)\n", | ||
"\n", | ||
"# Write to disk. Note! Text column must be named \"text\"\n", | ||
"df = pd.DataFrame({\"text\": sentences})\n", | ||
"df['x'] = X_tfm[:, 0]\n", | ||
"df['y'] = X_tfm[:, 1]\n", | ||
"\n", | ||
"X = enc.transform(sentences)\n", | ||
"```\n", | ||
"\n", | ||
"This gives us a dataframe `df` that contains sentences, but also contains 2D UMAP representations of sentence embeddings. We also have a sentence encoder `enc` loaded and we also have access to our original embeddings `X`. Computing these can take a while on a CPU so we will store these on disk.\n", | ||
"\n", | ||
"```python\n", | ||
"import numpy as np\n", | ||
"\n", | ||
"np.save(\"utils/X\", X)\n", | ||
"np.save(\"utils/X_tfm\", X_tfm)\n", | ||
"df.to_csv(\"utils/df.csv\", index=False)\n", | ||
"```\n", | ||
"\n", | ||
"</details>\n", | ||
"\n", | ||
"\n", | ||
"For now, all you need to know is that we have some files on disk with some precomputer embeddings and representations that we can use to make a nice interactive plot. " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "7e5ba2db-883f-4184-86dc-95d3560b2c24", | ||
"metadata": { | ||
"editable": true, | ||
"slideshow": { | ||
"slide_type": "" | ||
}, | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"%load_ext autoreload\n", | ||
"%autoreload 2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "0abb4012-9c2b-4c68-ba72-91f5f61ddd0b", | ||
"metadata": {}, | ||
"source": [ | ||
"So lets load some data first." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 26, | ||
"id": "ae6a77ec-33ec-446d-8db9-147677c93492", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import pandas as pd\n", | ||
"\n", | ||
"X = np.load(\"utils/X.npy\")\n", | ||
"X_tfm = np.load(\"utils/X_tfm.npy\")\n", | ||
"df = pd.read_csv(\"utils/df.csv\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "33f196ed-7e72-415b-b749-c5f8be20d99c", | ||
"metadata": {}, | ||
"source": [ | ||
"Next, lets use these variables to conjure up a basic text explorer. This will allow us to quickly explore the clusters that appear in our data. You can hold the mouse cursor to go into selection mode and when you select items you will see a random subset appear on the right. You can resample from your selection by clicking the resample button." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 27, | ||
"id": "daaf355b-ac0c-4399-be72-647eea085a38", | ||
"metadata": { | ||
"editable": true, | ||
"slideshow": { | ||
"slide_type": "" | ||
}, | ||
"tags": [] | ||
}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "adf74d2814ba41e093e3c926e07b4556", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"HBox(children=(HBox(children=(VBox(children=(Button(button_style='primary', icon='arrows', layout=Layout(width…" | ||
] | ||
}, | ||
"execution_count": 27, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"from bulk.widgets import BaseTextExplorer\n", | ||
"\n", | ||
"widget = BaseTextExplorer(df)\n", | ||
"widget.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "3301a2b2-1bec-400f-bc0d-09885e251178", | ||
"metadata": {}, | ||
"source": [ | ||
"Being able to explore these clusters is neat, but it feels like we might more easily explore everything if we have some more tools at our disposal. In particular, we want to have an encoder around so that we may use queries in our embedded space. \n", | ||
"\n", | ||
"The UI below will allow us to explore much more interactively." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 32, | ||
"id": "9c48e600-59f8-44fc-acc6-a813b1547e45", | ||
"metadata": { | ||
"editable": true, | ||
"slideshow": { | ||
"slide_type": "" | ||
}, | ||
"tags": [] | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/Users/vincent/Development/bulk/venv/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n", | ||
" warnings.warn(\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from embetter.text import SentenceEncoder\n", | ||
"\n", | ||
"enc = SentenceEncoder('all-MiniLM-L6-v2')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 33, | ||
"id": "c07e82b7-0735-4c12-bbd0-591efca09436", | ||
"metadata": { | ||
"editable": true, | ||
"slideshow": { | ||
"slide_type": "" | ||
}, | ||
"tags": [] | ||
}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "e2cc738d482b4feb98a9eb7f3b45b28b", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"HBox(children=(VBox(children=(Text(value='', description='String:', placeholder='Type something'), HBox(childr…" | ||
] | ||
}, | ||
"execution_count": 33, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"# Pay attention here! The rows in df needs to align with the rows in X!\n", | ||
"widget = BaseTextExplorer(df, X=X, encoder=enc)\n", | ||
"widget.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "cc7abf74-5dd1-4d14-9d54-3b2422389472", | ||
"metadata": { | ||
"editable": true, | ||
"slideshow": { | ||
"slide_type": "" | ||
}, | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"Thanks to tools like ipywidget and anywidget, we can really start building some tools to make the notebook hard to beat as the go-to place for your data needs. My primary interest is to work on tools for data quality and being able to select datapoints in bulk feels like a great place to start. Maybe you can find an interesting subset to annotate first, maybe you get suprised when you see two distinct clusters that should be one. All that good stuff can happen in the notebook.\n", | ||
"\n", | ||
"More UI will follow, but this `BaseTextExplorer` feels like a nice place to start!" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import jscatter | ||
from ipywidgets import HBox, VBox, HTML, Layout, Button, Text | ||
from IPython.display import display | ||
from sklearn.metrics.pairwise import cosine_similarity | ||
|
||
class BaseTextExplorer: | ||
""" | ||
Early preview of Jupyter Widget explorer. | ||
""" | ||
def __init__(self, dataf, X=None, encoder=None): | ||
self.dataf = dataf | ||
self.scatter = jscatter.Scatter(data=self.dataf, x="x", y="y", width=500, height=500) | ||
self.html = HTML(layout=Layout(width='600px', overflow_y='scroll', height='400px')) | ||
self.sample_btn = Button(description='resample') | ||
self.elem = HBox([self.scatter.show(), VBox([self.sample_btn, self.html])]) | ||
self.X = X | ||
self.encoder = encoder | ||
|
||
if self.encoder and (self.X is not None): | ||
self.text_input = Text(value='', placeholder='Type something', description='String:') | ||
self.elem = HBox([VBox([self.text_input, self.scatter.show()]), VBox([self.sample_btn, self.html])]) | ||
|
||
def update_text(change): | ||
X_tfm = encoder.transform([self.text_input.value]) | ||
dists = cosine_similarity(X, X_tfm).reshape(1, -1) | ||
self.dists = dists | ||
norm_dists = 0.01 + (dists - dists.min())/(0.1 + dists.max() - dists.min()) | ||
self.scatter.color(by=norm_dists[0]) | ||
self.scatter.size(by=norm_dists[0]) | ||
|
||
self.text_input.observe(update_text) | ||
|
||
self.scatter.widget.observe(lambda d: self.update(), ['selection']) | ||
self.sample_btn.on_click(lambda d: self.update()) | ||
|
||
def show(self): | ||
return self.elem | ||
|
||
def update(self): | ||
if len(self.scatter.selection()) > 10: | ||
texts = self.dataf.iloc[self.scatter.selection()].sample(10)["text"] | ||
else: | ||
texts = self.dataf.iloc[self.scatter.selection()]["text"] | ||
self.html.value = ''.join([f'<p style="margin: 0px">{t}</p>' for t in texts]) | ||
|
||
def observe(self, func): | ||
self.scatter.widget.observe(func, ['selection']) | ||
|
||
@property | ||
def selected_idx(self): | ||
return self.scatter.selection() | ||
|
||
@property | ||
def selected_texts(self): | ||
return list(self.dataf.iloc[self.selection_idx]["text"]) | ||
|
||
@property | ||
def selected_dataframe(self): | ||
return self.dataf.iloc[self.selection_idx] | ||
|
||
def _repr_html_(self): | ||
return display(self.elem) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import jscatter | ||
import numpy as np | ||
import pandas as pd | ||
from ipywidgets import HBox, VBox, HTML, Layout, Button, Text | ||
from IPython.display import display | ||
from sklearn.metrics.pairwise import cosine_similarity | ||
|
||
class BaseTextExplorer: | ||
""" | ||
Interface for basic text exploration in embedded space. | ||
""" | ||
def __init__(self, dataf, X=None, encoder=None): | ||
self.dataf = dataf | ||
self.scatter = jscatter.Scatter(data=self.dataf, x="x", y="y", width=500, height=500) | ||
self.html = HTML(layout=Layout(width='600px', overflow_y='scroll', height='400px')) | ||
self.sample_btn = Button(description='resample') | ||
self.elem = HBox([self.scatter.show(), VBox([self.sample_btn, self.html])]) | ||
self.X = X | ||
self.encoder = encoder | ||
|
||
if self.encoder and (self.X is not None): | ||
self.text_input = Text(value='', placeholder='Type something', description='String:') | ||
self.elem = HBox([VBox([self.text_input, self.scatter.show()]), VBox([self.sample_btn, self.html])]) | ||
|
||
def update_text(change): | ||
X_tfm = encoder.transform([self.text_input.value]) | ||
dists = cosine_similarity(X, X_tfm).reshape(1, -1) | ||
self.dists = dists | ||
norm_dists = 0.01 + (dists - dists.min())/(0.1 + dists.max() - dists.min()) | ||
self.scatter.color(by=norm_dists[0]) | ||
self.scatter.size(by=norm_dists[0]) | ||
|
||
self.text_input.observe(update_text) | ||
|
||
self.scatter.widget.observe(lambda d: self.update(), ['selection']) | ||
self.sample_btn.on_click(lambda d: self.update()) | ||
|
||
def show(self): | ||
return self.elem | ||
|
||
def update(self): | ||
if len(self.scatter.selection()) > 10: | ||
texts = self.dataf.iloc[self.scatter.selection()].sample(10)["text"] | ||
else: | ||
texts = self.dataf.iloc[self.scatter.selection()]["text"] | ||
self.html.value = ''.join([f'<p style="margin: 0px">{t}</p>' for t in texts]) | ||
|
||
def observe(self, func): | ||
self.scatter.widget.observe(func, ['selection']) | ||
|
||
@property | ||
def selected_idx(self): | ||
return self.scatter.selection() | ||
|
||
@property | ||
def selected_texts(self): | ||
return list(self.dataf.iloc[self.selection_idx]["text"]) | ||
|
||
@property | ||
def selected_dataframe(self): | ||
return self.dataf.iloc[self.selection_idx] | ||
|
||
def _repr_html_(self): | ||
return display(self.elem) |
Oops, something went wrong.