Skip to content

Commit

Permalink
Merge pull request #69 from koaning/notebookui
Browse files Browse the repository at this point in the history
v1 with notebook support
  • Loading branch information
koaning authored Aug 23, 2024
2 parents f372531 + 0044b19 commit d595fcd
Show file tree
Hide file tree
Showing 8 changed files with 391 additions and 2 deletions.
263 changes: 263 additions & 0 deletions Untitled12.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"id": "a17c690c-5a86-43ec-88f9-73e493ca5aa2",
"metadata": {},
"outputs": [],
"source": [
"# %pip install sentence-transformers umap-learn embetter"
]
},
{
"cell_type": "markdown",
"id": "2184ee7a-18a5-4edd-934f-6d3c16fe9c09",
"metadata": {},
"source": [
"# Intro to bulk from a notebook\n",
"\n",
"In an attempt to come to a quick demo, we ran some code beforehand that does some encoding. \n",
"\n",
"<details>\n",
" <summary><b>Show me the code.</b></summary>\n",
"\n",
"```python \n",
"import pandas as pd\n",
"from umap import UMAP\n",
"from sklearn.pipeline import make_pipeline \n",
"\n",
"# pip install \"embetter[text]\"\n",
"from embetter.text import SentenceEncoder\n",
"\n",
"# Build a sentence encoder pipeline with UMAP at the end.\n",
"enc = SentenceEncoder('all-MiniLM-L6-v2')\n",
"umap = UMAP()\n",
"\n",
"text_emb_pipeline = make_pipeline(\n",
" enc, umap\n",
")\n",
"\n",
"# Load sentences\n",
"sentences = list(pd.read_csv(\"tests/data/text.csv\")['text'])\n",
"\n",
"# Calculate embeddings \n",
"X_tfm = text_emb_pipeline.fit_transform(sentences)\n",
"\n",
"# Write to disk. Note! Text column must be named \"text\"\n",
"df = pd.DataFrame({\"text\": sentences})\n",
"df['x'] = X_tfm[:, 0]\n",
"df['y'] = X_tfm[:, 1]\n",
"\n",
"X = enc.transform(sentences)\n",
"```\n",
"\n",
"This gives us a dataframe `df` that contains sentences, but also contains 2D UMAP representations of sentence embeddings. We also have a sentence encoder `enc` loaded and we also have access to our original embeddings `X`. Computing these can take a while on a CPU so we will store these on disk.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"np.save(\"utils/X\", X)\n",
"np.save(\"utils/X_tfm\", X_tfm)\n",
"df.to_csv(\"utils/df.csv\", index=False)\n",
"```\n",
"\n",
"</details>\n",
"\n",
"\n",
"For now, all you need to know is that we have some files on disk with some precomputer embeddings and representations that we can use to make a nice interactive plot. "
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "7e5ba2db-883f-4184-86dc-95d3560b2c24",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"id": "0abb4012-9c2b-4c68-ba72-91f5f61ddd0b",
"metadata": {},
"source": [
"So lets load some data first."
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "ae6a77ec-33ec-446d-8db9-147677c93492",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"X = np.load(\"utils/X.npy\")\n",
"X_tfm = np.load(\"utils/X_tfm.npy\")\n",
"df = pd.read_csv(\"utils/df.csv\")"
]
},
{
"cell_type": "markdown",
"id": "33f196ed-7e72-415b-b749-c5f8be20d99c",
"metadata": {},
"source": [
"Next, lets use these variables to conjure up a basic text explorer. This will allow us to quickly explore the clusters that appear in our data. You can hold the mouse cursor to go into selection mode and when you select items you will see a random subset appear on the right. You can resample from your selection by clicking the resample button."
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "daaf355b-ac0c-4399-be72-647eea085a38",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "adf74d2814ba41e093e3c926e07b4556",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(HBox(children=(VBox(children=(Button(button_style='primary', icon='arrows', layout=Layout(width…"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from bulk.widgets import BaseTextExplorer\n",
"\n",
"widget = BaseTextExplorer(df)\n",
"widget.show()"
]
},
{
"cell_type": "markdown",
"id": "3301a2b2-1bec-400f-bc0d-09885e251178",
"metadata": {},
"source": [
"Being able to explore these clusters is neat, but it feels like we might more easily explore everything if we have some more tools at our disposal. In particular, we want to have an encoder around so that we may use queries in our embedded space. \n",
"\n",
"The UI below will allow us to explore much more interactively."
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "9c48e600-59f8-44fc-acc6-a813b1547e45",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/vincent/Development/bulk/venv/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
" warnings.warn(\n"
]
}
],
"source": [
"from embetter.text import SentenceEncoder\n",
"\n",
"enc = SentenceEncoder('all-MiniLM-L6-v2')"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "c07e82b7-0735-4c12-bbd0-591efca09436",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e2cc738d482b4feb98a9eb7f3b45b28b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(VBox(children=(Text(value='', description='String:', placeholder='Type something'), HBox(childr…"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Pay attention here! The rows in df needs to align with the rows in X!\n",
"widget = BaseTextExplorer(df, X=X, encoder=enc)\n",
"widget.show()"
]
},
{
"cell_type": "markdown",
"id": "cc7abf74-5dd1-4d14-9d54-3b2422389472",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"Thanks to tools like ipywidget and anywidget, we can really start building some tools to make the notebook hard to beat as the go-to place for your data needs. My primary interest is to work on tools for data quality and being able to select datapoints in bulk feels like a great place to start. Maybe you can find an interesting subset to annotate first, maybe you get suprised when you see two distinct clusters that should be one. All that good stuff can happen in the notebook.\n",
"\n",
"More UI will follow, but this `BaseTextExplorer` feels like a nice place to start!"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
62 changes: 62 additions & 0 deletions bulk/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import jscatter
from ipywidgets import HBox, VBox, HTML, Layout, Button, Text
from IPython.display import display
from sklearn.metrics.pairwise import cosine_similarity

class BaseTextExplorer:
"""
Early preview of Jupyter Widget explorer.
"""
def __init__(self, dataf, X=None, encoder=None):
self.dataf = dataf
self.scatter = jscatter.Scatter(data=self.dataf, x="x", y="y", width=500, height=500)
self.html = HTML(layout=Layout(width='600px', overflow_y='scroll', height='400px'))
self.sample_btn = Button(description='resample')
self.elem = HBox([self.scatter.show(), VBox([self.sample_btn, self.html])])
self.X = X
self.encoder = encoder

if self.encoder and (self.X is not None):
self.text_input = Text(value='', placeholder='Type something', description='String:')
self.elem = HBox([VBox([self.text_input, self.scatter.show()]), VBox([self.sample_btn, self.html])])

def update_text(change):
X_tfm = encoder.transform([self.text_input.value])
dists = cosine_similarity(X, X_tfm).reshape(1, -1)
self.dists = dists
norm_dists = 0.01 + (dists - dists.min())/(0.1 + dists.max() - dists.min())
self.scatter.color(by=norm_dists[0])
self.scatter.size(by=norm_dists[0])

self.text_input.observe(update_text)

self.scatter.widget.observe(lambda d: self.update(), ['selection'])
self.sample_btn.on_click(lambda d: self.update())

def show(self):
return self.elem

def update(self):
if len(self.scatter.selection()) > 10:
texts = self.dataf.iloc[self.scatter.selection()].sample(10)["text"]
else:
texts = self.dataf.iloc[self.scatter.selection()]["text"]
self.html.value = ''.join([f'<p style="margin: 0px">{t}</p>' for t in texts])

def observe(self, func):
self.scatter.widget.observe(func, ['selection'])

@property
def selected_idx(self):
return self.scatter.selection()

@property
def selected_texts(self):
return list(self.dataf.iloc[self.selection_idx]["text"])

@property
def selected_dataframe(self):
return self.dataf.iloc[self.selection_idx]

def _repr_html_(self):
return display(self.elem)
64 changes: 64 additions & 0 deletions bulk/widgets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import jscatter
import numpy as np
import pandas as pd
from ipywidgets import HBox, VBox, HTML, Layout, Button, Text
from IPython.display import display
from sklearn.metrics.pairwise import cosine_similarity

class BaseTextExplorer:
"""
Interface for basic text exploration in embedded space.
"""
def __init__(self, dataf, X=None, encoder=None):
self.dataf = dataf
self.scatter = jscatter.Scatter(data=self.dataf, x="x", y="y", width=500, height=500)
self.html = HTML(layout=Layout(width='600px', overflow_y='scroll', height='400px'))
self.sample_btn = Button(description='resample')
self.elem = HBox([self.scatter.show(), VBox([self.sample_btn, self.html])])
self.X = X
self.encoder = encoder

if self.encoder and (self.X is not None):
self.text_input = Text(value='', placeholder='Type something', description='String:')
self.elem = HBox([VBox([self.text_input, self.scatter.show()]), VBox([self.sample_btn, self.html])])

def update_text(change):
X_tfm = encoder.transform([self.text_input.value])
dists = cosine_similarity(X, X_tfm).reshape(1, -1)
self.dists = dists
norm_dists = 0.01 + (dists - dists.min())/(0.1 + dists.max() - dists.min())
self.scatter.color(by=norm_dists[0])
self.scatter.size(by=norm_dists[0])

self.text_input.observe(update_text)

self.scatter.widget.observe(lambda d: self.update(), ['selection'])
self.sample_btn.on_click(lambda d: self.update())

def show(self):
return self.elem

def update(self):
if len(self.scatter.selection()) > 10:
texts = self.dataf.iloc[self.scatter.selection()].sample(10)["text"]
else:
texts = self.dataf.iloc[self.scatter.selection()]["text"]
self.html.value = ''.join([f'<p style="margin: 0px">{t}</p>' for t in texts])

def observe(self, func):
self.scatter.widget.observe(func, ['selection'])

@property
def selected_idx(self):
return self.scatter.selection()

@property
def selected_texts(self):
return list(self.dataf.iloc[self.selection_idx]["text"])

@property
def selected_dataframe(self):
return self.dataf.iloc[self.selection_idx]

def _repr_html_(self):
return display(self.elem)
Loading

0 comments on commit d595fcd

Please sign in to comment.