diff --git a/Untitled12.ipynb b/Untitled12.ipynb new file mode 100644 index 0000000..6bae582 --- /dev/null +++ b/Untitled12.ipynb @@ -0,0 +1,263 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 11, + "id": "a17c690c-5a86-43ec-88f9-73e493ca5aa2", + "metadata": {}, + "outputs": [], + "source": [ + "# %pip install sentence-transformers umap-learn embetter" + ] + }, + { + "cell_type": "markdown", + "id": "2184ee7a-18a5-4edd-934f-6d3c16fe9c09", + "metadata": {}, + "source": [ + "# Intro to bulk from a notebook\n", + "\n", + "In an attempt to come to a quick demo, we ran some code beforehand that does some encoding. \n", + "\n", + "
\n", + " Show me the code.\n", + "\n", + "```python \n", + "import pandas as pd\n", + "from umap import UMAP\n", + "from sklearn.pipeline import make_pipeline \n", + "\n", + "# pip install \"embetter[text]\"\n", + "from embetter.text import SentenceEncoder\n", + "\n", + "# Build a sentence encoder pipeline with UMAP at the end.\n", + "enc = SentenceEncoder('all-MiniLM-L6-v2')\n", + "umap = UMAP()\n", + "\n", + "text_emb_pipeline = make_pipeline(\n", + " enc, umap\n", + ")\n", + "\n", + "# Load sentences\n", + "sentences = list(pd.read_csv(\"tests/data/text.csv\")['text'])\n", + "\n", + "# Calculate embeddings \n", + "X_tfm = text_emb_pipeline.fit_transform(sentences)\n", + "\n", + "# Write to disk. Note! Text column must be named \"text\"\n", + "df = pd.DataFrame({\"text\": sentences})\n", + "df['x'] = X_tfm[:, 0]\n", + "df['y'] = X_tfm[:, 1]\n", + "\n", + "X = enc.transform(sentences)\n", + "```\n", + "\n", + "This gives us a dataframe `df` that contains sentences, but also contains 2D UMAP representations of sentence embeddings. We also have a sentence encoder `enc` loaded and we also have access to our original embeddings `X`. Computing these can take a while on a CPU so we will store these on disk.\n", + "\n", + "```python\n", + "import numpy as np\n", + "\n", + "np.save(\"utils/X\", X)\n", + "np.save(\"utils/X_tfm\", X_tfm)\n", + "df.to_csv(\"utils/df.csv\", index=False)\n", + "```\n", + "\n", + "
\n", + "\n", + "\n", + "For now, all you need to know is that we have some files on disk with some precomputed embeddings and representations that we can use to make a nice interactive plot. " + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7e5ba2db-883f-4184-86dc-95d3560b2c24", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "markdown", + "id": "0abb4012-9c2b-4c68-ba72-91f5f61ddd0b", + "metadata": {}, + "source": [ + "So let's load some data first." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "ae6a77ec-33ec-446d-8db9-147677c93492", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "X = np.load(\"utils/X.npy\")\n", + "X_tfm = np.load(\"utils/X_tfm.npy\")\n", + "df = pd.read_csv(\"utils/df.csv\")" + ] + }, + { + "cell_type": "markdown", + "id": "33f196ed-7e72-415b-b749-c5f8be20d99c", + "metadata": {}, + "source": [ + "Next, let's use these variables to conjure up a basic text explorer. This will allow us to quickly explore the clusters that appear in our data. You can hold the mouse cursor to go into selection mode and when you select items you will see a random subset appear on the right. You can resample from your selection by clicking the resample button."
+ ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "daaf355b-ac0c-4399-be72-647eea085a38", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "adf74d2814ba41e093e3c926e07b4556", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HBox(children=(VBox(children=(Button(button_style='primary', icon='arrows', layout=Layout(width…" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from bulk.widgets import BaseTextExplorer\n", + "\n", + "widget = BaseTextExplorer(df)\n", + "widget.show()" + ] + }, + { + "cell_type": "markdown", + "id": "3301a2b2-1bec-400f-bc0d-09885e251178", + "metadata": {}, + "source": [ + "Being able to explore these clusters is neat, but it feels like we might more easily explore everything if we have some more tools at our disposal. In particular, we want to have an encoder around so that we may use queries in our embedded space. \n", + "\n", + "The UI below will allow us to explore much more interactively." + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "9c48e600-59f8-44fc-acc6-a813b1547e45", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/vincent/Development/bulk/venv/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. 
For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from embetter.text import SentenceEncoder\n", + "\n", + "enc = SentenceEncoder('all-MiniLM-L6-v2')" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "c07e82b7-0735-4c12-bbd0-591efca09436", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e2cc738d482b4feb98a9eb7f3b45b28b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(VBox(children=(Text(value='', description='String:', placeholder='Type something'), HBox(childr…" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Pay attention here! The rows in df need to align with the rows in X!\n", + "widget = BaseTextExplorer(df, X=X, encoder=enc)\n", + "widget.show()" + ] + }, + { + "cell_type": "markdown", + "id": "cc7abf74-5dd1-4d14-9d54-3b2422389472", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "Thanks to tools like ipywidgets and anywidget, we can really start building some tools to make the notebook hard to beat as the go-to place for your data needs. My primary interest is to work on tools for data quality and being able to select datapoints in bulk feels like a great place to start. Maybe you can find an interesting subset to annotate first, maybe you get surprised when you see two distinct clusters that should be one. All that good stuff can happen in the notebook.\n", + "\n", + "More UI will follow, but this `BaseTextExplorer` feels like a nice place to start!"
class BaseTextExplorer:
    """
    Interface for basic text exploration in embedded space.

    Shows a jscatter scatter plot of ``dataf`` next to a panel that lists the
    texts of the currently selected points. When both ``X`` and ``encoder``
    are given, a text box is added above the plot; typing a query recolours
    and resizes every point by the cosine similarity between the encoded
    query and the rows of ``X``.

    Arguments:
        dataf: dataframe with the columns "x" and "y" (plot coordinates) and
            "text" (the texts shown in the side panel).
        X: optional embedding matrix. Its rows must align with the rows of
            ``dataf``, because the similarities computed from it are applied
            to the plotted points in order.
        encoder: optional object whose ``transform`` method turns a list of
            strings into embeddings that can be compared with ``X``.
    """

    # Upper bound on how many selected texts are rendered at once; larger
    # selections are randomly subsampled down to this many rows.
    max_shown = 10

    def __init__(self, dataf, X=None, encoder=None):
        self.dataf = dataf
        self.X = X
        self.encoder = encoder
        self.scatter = jscatter.Scatter(data=self.dataf, x="x", y="y", width=500, height=500)
        self.html = HTML(layout=Layout(width='600px', overflow_y='scroll', height='400px'))
        self.sample_btn = Button(description='resample')

        # Build the plot view once and wrap it only when a query box is needed.
        left = self.scatter.show()
        # Explicit None checks: an encoder object is not guaranteed to be truthy.
        if self.encoder is not None and self.X is not None:
            self.text_input = Text(value='', placeholder='Type something', description='String:')
            # Only react to changes of the typed text; without ``names`` the
            # handler would also run for every other trait change.
            self.text_input.observe(self._update_query, names='value')
            left = VBox([self.text_input, left])
        self.elem = HBox([left, VBox([self.sample_btn, self.html])])

        self.scatter.widget.observe(lambda change: self.update(), ['selection'])
        self.sample_btn.on_click(lambda btn: self.update())

    def _update_query(self, change):
        """Recolour and resize the points by similarity to the typed query."""
        query_emb = self.encoder.transform([self.text_input.value])
        dists = cosine_similarity(self.X, query_emb).reshape(1, -1)
        # Kept on the instance (shape ``(1, n_rows)``) so it can be inspected.
        self.dists = dists
        # Min-max style scaling: the 0.01 offset keeps the least similar point
        # above zero and the 0.1 in the denominator avoids a division by zero
        # when all similarities are equal.
        norm_dists = 0.01 + (dists - dists.min())/(0.1 + dists.max() - dists.min())
        self.scatter.color(by=norm_dists[0])
        self.scatter.size(by=norm_dists[0])

    def show(self):
        """Return the composed widget so a notebook cell can display it."""
        return self.elem

    def update(self):
        """Render the selected texts, or a random sample of them, in the side panel."""
        # Local import keeps the module's import block untouched.
        from html import escape

        selected = self.dataf.iloc[self.scatter.selection()]
        if len(selected) > self.max_shown:
            selected = selected.sample(self.max_shown)
        # Escape the texts: they are data, not markup, and a stray "<" would
        # otherwise be interpreted by the HTML widget.
        self.html.value = ''.join(f'<p>{escape(str(t))}</p>' for t in selected["text"])

    def observe(self, func):
        """Register ``func`` to be called whenever the plot selection changes."""
        self.scatter.widget.observe(func, ['selection'])

    @property
    def selected_idx(self):
        """Positional indices of the currently selected points."""
        return self.scatter.selection()

    @property
    def selected_texts(self):
        """Texts of the currently selected rows, as a list."""
        return list(self.dataf.iloc[self.selected_idx]["text"])

    @property
    def selected_dataframe(self):
        """Rows of ``dataf`` that are currently selected."""
        return self.dataf.iloc[self.selected_idx]

    def _ipython_display_(self):
        """Display the widget when the explorer is the last value in a cell."""
        display(self.elem)

    def _repr_html_(self):
        # Kept for callers that invoke it directly. It displays the widget as
        # a side effect and returns whatever ``display`` returns.
        return display(self.elem)