Adding pyright #7

Merged
merged 6 commits on Jan 19, 2025
Changes from all commits

33 changes: 33 additions & 0 deletions .github/workflows/pyright.yml
@@ -0,0 +1,33 @@
+name: Pyright
+
+on:
+  pull_request:
+    branches: [main]
+  push:
+    branches: [main]
+
+env:
+  WORKING_DIRECTORY: "."
+  PYRIGHT_OUTPUT_FILENAME: "pyright.log"
+
+jobs:
+  Pyright:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-latest]
+        python-version: ["3.9"]
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+      - name: Install uv
+        uses: astral-sh/setup-uv@v4
+        with:
+          enable-cache: true
+      - name: Set up Python
+        run: uv python install ${{ matrix.python-version }}
+      - name: Install the project
+        run: uv sync --all-extras --dev
+      - name: Run pyright
+        run: uv run pyright
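
Note: the same check should be reproducible locally with the two commands from the workflow, uv sync --all-extras --dev followed by uv run pyright, assuming uv is installed.
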
6 changes: 4 additions & 2 deletions kura/base_classes/cluster.py
@@ -1,10 +1,12 @@
 from abc import ABC, abstractmethod
-from kura.types import Cluster, ConversationSummary
+from kura.types import ConversationSummary, Cluster


 class BaseClusterModel(ABC):
     @abstractmethod
-    def cluster_summaries(self, summaries: list[ConversationSummary]) -> list[Cluster]:
+    async def cluster_summaries(
+        self, summaries: list[ConversationSummary]
+    ) -> list[Cluster]:
         pass

     # TODO : Add abstract method for hooks here once we start supporting it

6 changes: 4 additions & 2 deletions kura/base_classes/dimensionality.py
@@ -1,11 +1,13 @@
 from abc import ABC, abstractmethod

-from kura.types import Cluster
+from kura.types import Cluster, ProjectedCluster


 class BaseDimensionalityReduction(ABC):
     @abstractmethod
-    async def reduce_dimensionality(self, clusters: list[Cluster]) -> list[Cluster]:
+    async def reduce_dimensionality(
+        self, clusters: list[Cluster]
+    ) -> list[ProjectedCluster]:
         """
         This reduces the dimensionality of the individual clusters that we've created so we can visualise them in a lower dimension
         """

2 changes: 1 addition & 1 deletion kura/base_classes/embedding.py
@@ -4,6 +4,6 @@

 class BaseEmbeddingModel(ABC):
     @abstractmethod
-    async def embed(self, text: str) -> list[float]:
+    async def embed(self, text: str, sem: Semaphore) -> list[float]:
         """Embed a single text into a list of floats"""
         pass
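
Note: an implementation of this revised interface would acquire the semaphore around the underlying API call so callers can cap concurrency. A minimal sketch, assuming an AsyncOpenAI client; the class name, model name, and import path below are illustrative, not part of this PR:

    from asyncio import Semaphore

    from openai import AsyncOpenAI

    from kura.base_classes import BaseEmbeddingModel  # assumed import path


    class ExampleEmbeddingModel(BaseEmbeddingModel):
        def __init__(self, model: str = "text-embedding-3-small"):
            self.client = AsyncOpenAI()
            self.model = model

        async def embed(self, text: str, sem: Semaphore) -> list[float]:
            # The shared semaphore caps how many embedding requests run at once.
            async with sem:
                resp = await self.client.embeddings.create(
                    model=self.model, input=text
                )
                return resp.data[0].embedding
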
2 changes: 1 addition & 1 deletion kura/base_classes/meta_cluster.py
@@ -4,5 +4,5 @@

 class BaseMetaClusterModel(ABC):
     @abstractmethod
-    def reduce_clusters(self, clusters: list[Cluster]) -> list[Cluster]:
+    async def reduce_clusters(self, clusters: list[Cluster]) -> list[Cluster]:
         pass

4 changes: 3 additions & 1 deletion kura/cli/server.py
@@ -68,7 +68,9 @@ async def analyse_conversations(conversation_data: ConversationData):

     if not clusters_file.exists():
         kura = Kura(
-            checkpoint_dir=Path(os.path.abspath(os.environ["KURA_CHECKPOINT_DIR"])),
+            checkpoint_dir=str(
+                Path(os.path.abspath(os.environ["KURA_CHECKPOINT_DIR"]))
+            ),
             conversations=conversations,
         )
         await kura.cluster_conversations()
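
Note: the str(...) wrapper suggests Kura's checkpoint_dir parameter is typed as str, so passing a Path here was presumably a pyright complaint rather than a runtime failure.
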
18 changes: 9 additions & 9 deletions kura/cli/visualisation.py
@@ -33,7 +33,7 @@ def generate_cumulative_chart_data(conversations: List[Conversation]) -> dict:
         for x, y in zip(
             weekly_df["week_start"].tolist(), weekly_df["cumulative_words"].tolist()
         )
-    ]
+    ]  # pyright: ignore
@@ -52,9 +52,9 @@ def generate_messages_per_chat_data(conversations: List[Conversation]) -> dict:
     df = pd.DataFrame(messages_data)
     df["week_start"] = df["datetime"].dt.to_period("W-MON").dt.start_time

-    weekly_messages = df.groupby("week_start").size().reset_index(name="message_count")
+    weekly_messages = df.groupby("week_start").size().reset_index(name="message_count")  # pyright: ignore
     weekly_chats = (
-        df.groupby("week_start")["chat_id"].nunique().reset_index(name="chat_count")
+        df.groupby("week_start")["chat_id"].nunique().reset_index(name="chat_count")  # pyright: ignore
     )

     weekly_df = pd.merge(weekly_messages, weekly_chats, on="week_start")
@@ -66,7 +66,7 @@ def generate_messages_per_chat_data(conversations: List[Conversation]) -> dict:
         for x, y in zip(
             weekly_df["week_start"].tolist(), weekly_df["avg_messages"].tolist()
         )
-    ]
+    ]  # pyright: ignore
@@ -85,18 +85,18 @@ def generate_messages_per_week_data(conversations: List[Conversation]) -> dict:
     df = pd.DataFrame(messages_data)
     df["week_start"] = df["datetime"].dt.to_period("W-MON").dt.start_time

-    weekly_messages = df.groupby("week_start").size().reset_index(name="message_count")
+    weekly_messages = df.groupby("week_start").size().reset_index(name="message_count")  # pyright: ignore
     weekly_messages["week_start"] = weekly_messages["week_start"].dt.strftime(
         "%Y-%m-%d"
-    )
+    )  # pyright: ignore

     return [
         {"x": x, "y": y}
         for x, y in zip(
             weekly_messages["week_start"].tolist(),
             weekly_messages["message_count"].tolist(),
         )
-    ]
+    ]  # pyright: ignore
@@ -113,7 +113,7 @@ def generate_new_chats_per_week_data(conversations: List[Conversation]) -> dict:
         chat_starts["datetime"].dt.to_period("W-MON").dt.start_time
     )
     weekly_chats = (
-        chat_starts.groupby("week_start").size().reset_index(name="chat_count")
+        chat_starts.groupby("week_start").size().reset_index(name="chat_count")  # pyright: ignore
     )
     weekly_chats["week_start"] = weekly_chats["week_start"].dt.strftime("%Y-%m-%d")
@@ -122,4 +122,4 @@ def generate_new_chats_per_week_data(conversations: List[Conversation]) -> dict:
         for x, y in zip(
             weekly_chats["week_start"].tolist(), weekly_chats["chat_count"].tolist()
         )
-    ]
+    ]  # pyright: ignore
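
Note: several of these functions are annotated -> dict yet return a list of {x, y} points, which is presumably what the trailing # pyright: ignore on each closing ] suppresses; the ignores on the groupby/strftime lines look like gaps in the pandas type stubs rather than runtime issues.
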
11 changes: 7 additions & 4 deletions kura/cluster.py
@@ -127,14 +127,17 @@ async def generate_cluster(
         )

     async def cluster_summaries(
-        self, items: list[ConversationSummary]
-    ) -> dict[int, list[ConversationSummary]]:
+        self, summaries: list[ConversationSummary]
+    ) -> list[Cluster]:
         if os.path.exists(os.path.join(self.checkpoint_dir, self.checkpoint_file)):
             return self.load_checkpoint()

         sem = Semaphore(self.max_concurrent_requests)
         embeddings: list[list[float]] = await tqdm_asyncio.gather(
-            *[self.embedding_model.embed(item.summary, sem) for item in items],
+            *[
+                self.embedding_model.embed(text=item.summary, sem=sem)
+                for item in summaries
+            ],
             desc="Embedding Summaries",
         )
         cluster_id_to_summaries = self.clustering_method.cluster(
@@ -143,7 +146,7 @@ async def cluster_summaries(
                     "item": item,
                     "embedding": embedding,
                 }
-                for item, embedding in zip(items, embeddings)
+                for item, embedding in zip(summaries, embeddings)
             ]
         )
         clusters: list[Cluster] = await tqdm_asyncio.gather(
6 changes: 4 additions & 2 deletions kura/dimensionality.py
@@ -5,6 +5,8 @@
 import numpy as np
 import asyncio
 import os
+from numpy.typing import NDArray
+from numpy import float64


@@ -62,8 +64,8 @@ async def reduce_dimensionality(
                 description=cluster.description,
                 chat_ids=cluster.chat_ids,
                 parent_id=cluster.parent_id,
-                x_coord=float(reduced_embeddings[i][0]),
-                y_coord=float(reduced_embeddings[i][1]),
+                x_coord=float(reduced_embeddings[i][0]),  # pyright: ignore
+                y_coord=float(reduced_embeddings[i][1]),  # pyright: ignore
                 level=0 if cluster.parent_id is None else 1,
             )
             res.append(projected)
10 changes: 8 additions & 2 deletions kura/k_means.py
@@ -19,9 +19,15 @@ def cluster(self, items: list[T]) -> dict[int, list[T]]:

         - its relevant embedding stored in the "embedding" key.
         - the item itself stored in the "item" key.
+
+        {
+            "embedding": list[float],
+            "item": any,
+        }
         """
-        embeddings = [item["embedding"] for item in items]
-        data: list[T] = [item["item"] for item in items]
+
+        embeddings = [item["embedding"] for item in items]  # pyright: ignore
+        data: list[T] = [item["item"] for item in items]  # pyright: ignore
         n_clusters = math.ceil(len(data) / self.clusters_per_group)

         X = np.array(embeddings)
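
Note: to make the documented shape concrete, a hypothetical call would look like this (the vectors and item values are made up):

    # Each dict pairs an embedding with the item it belongs to.
    items = [
        {"embedding": [0.12, -0.53, 0.98], "item": "summary-1"},
        {"embedding": [0.11, -0.50, 0.95], "item": "summary-2"},
    ]
    # cluster() runs k-means over the "embedding" vectors and returns the
    # "item" values grouped by cluster id, e.g. {0: ["summary-1", "summary-2"]}.
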
12 changes: 6 additions & 6 deletions kura/kura.py
@@ -1,5 +1,5 @@
 from kura.dimensionality import HDBUMAP
-from kura.types import Conversation, Message, Cluster
+from kura.types import Conversation, Cluster
 from kura.embedding import OpenAIEmbeddingModel
 from kura.summarisation import SummaryModel
 from kura.meta_cluster import MetaClusterModel
@@ -29,11 +29,11 @@ def __init__(
         cluster_checkpoint_name: str = "clusters.json",
         meta_cluster_checkpoint_name: str = "meta_clusters.json",
     ):
-        # Override checkpoint dirs so that they're the same for the models
-        summarisation_model.checkpoint_dir = checkpoint_dir
-        cluster_model.checkpoint_dir = checkpoint_dir
-        meta_cluster_model.checkpoint_dir = checkpoint_dir
-        dimensionality_reduction.checkpoint_dir = checkpoint_dir
+        # TODO: Manage Checkpoints within Kura class itself so we can directly disable checkpointing easily
+        summarisation_model.checkpoint_dir = checkpoint_dir  # pyright: ignore
+        cluster_model.checkpoint_dir = checkpoint_dir  # pyright: ignore
+        meta_cluster_model.checkpoint_dir = checkpoint_dir  # pyright: ignore
+        dimensionality_reduction.checkpoint_dir = checkpoint_dir  # pyright: ignore

         self.embedding_model = embedding_model
         self.embedding_model = embedding_model
7 changes: 4 additions & 3 deletions kura/meta_cluster.py
@@ -36,13 +36,14 @@ class ClusterLabel(BaseModel):

     @field_validator("higher_level_cluster")
     def validate_higher_level_cluster(cls, v: str, info: ValidationInfo) -> str:
-        if v not in info.context["candidate_clusters"]:
+        candidate_clusters = info.context["candidate_clusters"]  # pyright: ignore
+        if v not in candidate_clusters:
             raise ValueError(
                 f"""
                 Invalid higher-level cluster: |{v}|

                 Valid clusters are:
-                {", ".join(f"|{c}|" for c in info.context["candidate_clusters"])}
+                {", ".join(f"|{c}|" for c in candidate_clusters)}
                 """
             )
         return v
@@ -65,7 +66,7 @@ def __init__(

     async def generate_candidate_clusters(
         self, clusters: list[Cluster], sem: Semaphore
-    ) -> list[Cluster]:
+    ) -> list[str]:
         async with sem:
             resp = await self.client.chat.completions.create(
                 messages=[
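
Note: the validator reads its allowed labels from pydantic's validation context, so callers are expected to supply candidate_clusters via context=. A hypothetical call (the label values and import path are made up for illustration):

    # Pydantic v2 passes `context` through to ValidationInfo.context.
    from kura.meta_cluster import ClusterLabel  # assumed import path

    label = ClusterLabel.model_validate(
        {"higher_level_cluster": "Write and debug Python code"},
        context={"candidate_clusters": ["Write and debug Python code", "Other"]},
    )
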
19 changes: 19 additions & 0 deletions pyproject.toml
@@ -32,7 +32,26 @@ docs = [
     "mkdocstrings>=0.27.0",
     "mkdocstrings-python>=1.13.0",
 ]
+dev = [
+    "pyright>=1.1.392.post0",
+]


 [project.scripts]
 kura = "kura.cli.cli:app"
+
+[tool.pyright]
+include = ["kura"]
+exclude = [
+    "**/node_modules",
+    "**/__pycache__",
+    "src/experimental",
+    "src/typestubs",
+    "**/tests/**",
+]
+
+reportMissingImports = "error"
+reportMissingTypeStubs = false
+
+pythonVersion = "3.9"
+pythonPlatform = "Linux"
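
Note: with this configuration pyright checks only the kura package (tests and the listed directories are excluded), treats missing imports as errors while tolerating missing type stubs, and checks against Python 3.9 on Linux, matching the CI matrix above.
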
26 changes: 26 additions & 0 deletions uv.lock

Some generated files are not rendered by default.