Unify configuration and document it #69

Merged (1 commit, Apr 21, 2024)
13 changes: 13 additions & 0 deletions README.md
@@ -19,6 +19,19 @@ Sam uses OpenAI's assistant API to fine-tune ChatGPT to:

to provide a work-colleague like experience.

## Features

Like any good co-worker, Sam can

* search the web,
* browse websites,
* search your company's products,
* read internal documents,
* send emails,
* create GitHub issues,

and soon spend half the day in meetings and the other half in the kitchen.

## Sneak peek of Sam in action

![screenshot.png](screenshot.png)
68 changes: 68 additions & 0 deletions app.json
@@ -27,6 +27,74 @@
"OPENAI_ASSISTANT_ID": {
"description": "Your OpenAI assistant ID, starting with 'asst_'",
"required": true
},
"TTS_VOICE": {
"description": "The voice to use for text-to-speech. Defaults to 'alloy'.",
"required": false
},
"TTS_MODEL": {
"description": "The model to use for text-to-speech. Defaults to 'tts-1-hd'.",
"required": false
},
"MAX_PROMPT_TOKENS": {
"description": "The maximum number of tokens to use in a prompt.",
"required": false
},
"MAX_COMPLETION_TOKENS": {
"description": "The maximum number of tokens to use in a response.",
"required": false
},
"RANDOM_RUN_RATIO": {
"description": "How often the bot randomly responds in a group channel. Defaults to 0 (never).",
"required": false
},
"TIMEZONE": {
"description": "The timezone to use for scheduling. Defaults to 'UTC'.",
"required": false
},
"EMAIL_URL": {
"description": "The URL to send emails to. Disabled if not set.",
"required": false
},
"FROM_EMAIL": {
"description": "The email address to send emails from. Disabled if not set.",
"required": false
},
"EMAIL_WHITE_LIST": {
"description": "A comma-separated list of email addresses to allow sending emails to. Disabled if not set.",
"required": false
},
"BRAVE_SEARCH_API_KEY": {
"description": "Your Brave Search API key. Disabled if not set.",
"required": false
},
"BRAVE_SEARCH_LONGITUDE": {
"description": "The longitude to use for Brave Search. Disabled if not set.",
"required": false
},
"BRAVE_SEARCH_LATITUDE": {
"description": "The latitude to use for Brave Search. Disabled if not set.",
"required": false
},
"SENTRY_DSN": {
"description": "Your Sentry DSN. Disabled if not set.",
"required": false
},
"GITHUB_REPOS": {
"description": "A comma-separated list of GitHub the bot can post to. Disabled if not set.",
"required": false
},
"ALGOLIA_SEARCH_API_KEY": {
"description": "Your Algolia search API key. Disabled if not set.",
"required": false
},
"ALGOLIA_APPLICATION_ID": {
"description": "Your Algolia application ID. Disabled if not set.",
"required": false
},
"ALGOLIA_SEARCH_INDEX": {
"description": "The Algolia search index to use. Disabled if not set.",
"required": false
}
},
"formation": {
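
Most of these settings are optional and simply switch a feature on or off. As an illustration of how one of them is consumed, here is a minimal sketch (not part of this PR, and the function name is hypothetical) of how `RANDOM_RUN_RATIO` from the unified `sam.config` module could gate whether the bot replies to a message it was not mentioned in:

```python
import random

from sam import config  # the unified config module introduced in this PR


def should_reply_randomly() -> bool:
    """Decide whether to reply unprompted in a group channel.

    RANDOM_RUN_RATIO defaults to 0, i.e. the bot never replies unless addressed;
    a value of 0.1 would make it chime in on roughly every tenth message.
    """
    return random.random() < config.RANDOM_RUN_RATIO
```
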
25 changes: 25 additions & 0 deletions pyproject.toml
@@ -89,6 +89,7 @@ testpaths = ["tests"]
GITHUB_REPOS = 'voiio/sam'
BRAVE_SEARCH_LATITUDE = '37.7749'
BRAVE_SEARCH_LONGITUDE = '-122.4194'
FROM_EMAIL = "Sam <[email protected]>"

[tool.coverage.run]
source = ["sam"]
@@ -109,3 +110,27 @@ combine_as_imports = true

[tool.pydocstyle]
add_ignore = "D1"


[tool.sam.tools.send_email]
path = "sam.tools:send_email"
additional_instructions = "You may ask for confirmation before sending an email."

[tool.sam.tools.fetch_website]
path = "sam.tools:fetch_website"
additional_instructions = "You may fetch multiple websites and browse the sites."

[tool.sam.tools.fetch_coworker_contacts]
path = "sam.slack:fetch_coworker_contacts"

[tool.sam.tools.web_search]
path = "sam.contrib.brave.tools:search"
additional_instructions = "You MUST ALWAYS always fetch a website and read it."

[tool.sam.tools.create_github_issue]
path = "sam.contrib.github.tools:create_github_issue"
additional_instructions = "You MUST ALWAYS write the issue in English."

[tool.sam.tools.platform_search]
path = "sam.contrib.algolia.tools:search"
additional_instructions = "The voiio platform should be searched before searching the web."
54 changes: 28 additions & 26 deletions sam/bot.py
@@ -1,14 +1,13 @@
from __future__ import annotations

import datetime
import json
import logging
from pathlib import Path

import openai
from redis import asyncio as redis

from . import config, tools, utils
from . import config, utils
from .typing import AUDIO_FORMATS, Roles, RunStatus

logger = logging.getLogger(__name__)
@@ -22,7 +21,7 @@ async def complete_run(run_id: str, thread_id: str, *, retry: int = 0, **context

Raises:
RecursionError: If the run status is not "completed" after 10 retries.
IOError: If the run status is not "completed" or "requires_action".
OSError: If the run status is not "completed" or "requires_action".
ValueError: If the run requires tools but none are provided.
"""
client: openai.AsyncOpenAI = openai.AsyncOpenAI()
@@ -42,15 +41,15 @@ async def complete_run(run_id: str, thread_id: str, *, retry: int = 0, **context
case RunStatus.COMPLETED:
return
case _:
raise IOError(f"Run {run.id} failed with status {run.status}")
raise OSError(f"Run {run.id} failed with status {run.status}")


async def call_tools(run: openai.types.beta.threads.Run, **context) -> None:
"""
Call the tools required by the run.

Raises:
IOError: If a tool is not found.
OSError: If a tool is not found.
ValueError: If the run does not require any tools.
"""
client: openai.AsyncOpenAI = openai.AsyncOpenAI()
@@ -59,12 +58,12 @@ async def call_tools(run: openai.types.beta.threads.Run, **context) -> None:
tool_outputs = []
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
try:
fn = getattr(tools, tool_call.function.name)
except AttributeError as e:
fn = config.TOOLS[tool_call.function.name]
except KeyError as e:
await client.beta.threads.runs.cancel(
run_id=run.id, thread_id=run.thread_id
)
raise IOError(
raise OSError(
f"Tool {tool_call.function.name} not found, cancelling run {run.id}"
) from e
try:
@@ -73,18 +72,20 @@ async def call_tools(run: openai.types.beta.threads.Run, **context) -> None:
await client.beta.threads.runs.cancel(
run_id=run.id, thread_id=run.thread_id
)
raise IOError(
raise OSError(
f"Invalid arguments for tool {tool_call.function.name}, cancelling run {run.id}"
) from e
logger.info("Running tool %s", tool_call.function.name)
logger.debug("Tool %s arguments: %r", tool_call.function.name, kwargs)
try:
output = await fn(**kwargs, **context)
output = fn(**kwargs, _context=context)
except Exception as e:
await client.beta.threads.runs.cancel(
run_id=run.id, thread_id=run.thread_id
)
raise e
raise OSError(
f"Tool {tool_call.function.name} failed, cancelling run {run.id}"
) from e
tool_outputs.append(
{
"tool_call_id": tool_call.id,
@@ -116,24 +117,29 @@ async def execute_run(
logger.debug("Additional instructions: %r", additional_instructions)
logger.debug("Context: %r", context)
client: openai.AsyncOpenAI = openai.AsyncOpenAI()
tools = [
*config.TOOLS.keys(),
{"type": "file_search"},
]
logger.debug("Tools: %r", tools)
run = await client.beta.threads.runs.create(
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
tools=[
utils.func_to_tool(tools.send_email),
utils.func_to_tool(tools.web_search),
utils.func_to_tool(tools.platform_search),
utils.func_to_tool(tools.fetch_website),
utils.func_to_tool(tools.fetch_coworker_emails),
utils.func_to_tool(tools.create_github_issue),
{"type": "file_search"},
],
# OpenAI suggests a limit of 20,000 tokens for the prompt using the file search tool.
# See also: https://platform.openai.com/docs/assistants/how-it-works/max-completion-and-max-prompt-tokens
max_prompt_tokens=(
min(20000, config.MAX_PROMPT_TOKENS)
if file_search
else config.MAX_PROMPT_TOKENS
),
max_completion_tokens=config.MAX_COMPLETION_TOKENS,
tools=tools,
tool_choice={"type": "file_search"} if file_search else "auto",
)
try:
await complete_run(run.id, thread_id, **context)
except (RecursionError, IOError, ValueError):
except (RecursionError, OSError, ValueError):
logger.exception("Run %s failed", run.id)
return "🤯"

@@ -267,10 +273,6 @@ async def get_thread_id(slack_id) -> str:
thread = await openai.AsyncOpenAI().beta.threads.create()
thread_id = thread.id

midnight = datetime.datetime.combine(
datetime.date.today(), datetime.time.max, tzinfo=config.TIMEZONE
)

await redis_client.set(slack_id, thread_id, exat=int(midnight.timestamp()))
await redis_client.set(slack_id, thread_id)

return thread_id
95 changes: 63 additions & 32 deletions sam/config.py
@@ -1,50 +1,81 @@
from __future__ import annotations

import enum
import os
import re
import tomllib
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import Iterable
from zoneinfo import ZoneInfo

SLACK_BOT_TOKEN = os.getenv("SLACK_BOT_TOKEN")
SLACK_APP_TOKEN = os.getenv("SLACK_APP_TOKEN")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_ASSISTANT_ID = os.getenv("OPENAI_ASSISTANT_ID")
TTS_VOICE = os.getenv("TTS_VOICE", "alloy")
TTS_MODEL = os.getenv("TTS_MODEL", "tts-1-hd")
MAX_PROMPT_TOKENS = int(os.getenv("MAX_PROMPT_TOKENS", "20480"))
REDIS_URL = os.getenv("REDIS_URL", "redis:///")
RANDOM_RUN_RATIO = float(os.getenv("RANDOM_RUN_RATIO", "0"))
TIMEZONE = ZoneInfo(os.getenv("TIMEZONE", "UTC"))
BRAVE_SEARCH_API_KEY = os.getenv("BRAVE_SEARCH_API_KEY")
BRAVE_SEARCH_LONGITUDE = os.getenv("BRAVE_SEARCH_LONGITUDE")
BRAVE_SEARCH_LATITUDE = os.getenv("BRAVE_SEARCH_LATITUDE")
SENTRY_DSN = os.getenv("SENTRY_DSN")
GITHUB_REPOS = enum.StrEnum(
"GITHUB_REPOS",
{repo: repo for repo in os.getenv("GITHUB_REPOS", "").split(",") if repo},
from sam.utils import AssistantConfig, Tool

# General
#: The URL of the Redis database server.
REDIS_URL: str = os.getenv("REDIS_URL", "redis:///")
#: How often the bot randomly responds in a group channel.
RANDOM_RUN_RATIO: float = float(os.getenv("RANDOM_RUN_RATIO", "0"))
#: The timezone the bot "lives" in.
TIMEZONE: ZoneInfo = ZoneInfo(os.getenv("TIMEZONE", "UTC"))

# OpenAI
#: The OpenAI API key.
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY")
#: The OpenAI assistant ID used for the bot.
OPENAI_ASSISTANT_ID: str = os.getenv("OPENAI_ASSISTANT_ID")
#: The OpenAI voice used for text-to-speech.
TTS_VOICE: str = os.getenv("TTS_VOICE", "alloy")
#: The OpenAI model used for text-to-speech.
TTS_MODEL: str = os.getenv("TTS_MODEL", "tts-1-hd")
#: The maximum number of tokens allowed in a prompt.
MAX_PROMPT_TOKENS: int | None = (
int(os.getenv("MAX_PROMPT_TOKENS")) if "MAX_PROMPT_TOKENS" in os.environ else None
)
#: The maximum number of tokens allowed in a completion.
MAX_COMPLETION_TOKENS: int | None = (
int(os.getenv("MAX_COMPLETION_TOKENS"))
if "MAX_COMPLETION_TOKENS" in os.environ
else None
)

# Slack
#: The Slack bot token, prefixed with `xoxb-`.
SLACK_BOT_TOKEN: str = os.getenv("SLACK_BOT_TOKEN")
#: The Slack app token, prefixed with `xapp-`.
SLACK_APP_TOKEN: str = os.getenv("SLACK_APP_TOKEN")

# Email
#: The email server URL used for outgoing email.
EMAIL_URL: str | None = os.getenv("EMAIL_URL")
#: The email address the bot sends emails from.
FROM_EMAIL: str | None = os.getenv("FROM_EMAIL")
EMAIL_WHITELIST_PATTERN: re.Pattern | None = (
re.compile(os.getenv("EMAIL_WHITELIST_PATTERN"))
if "EMAIL_WHITELIST_PATTERN" in os.environ
else None
)

# Sentry
#: The Sentry DSN for Sentry based error reporting.
SENTRY_DSN: str = os.getenv("SENTRY_DSN")

@dataclass
class AssistantConfig:
name: str
assistant_id: str
instructions: list[str]
project: str

@cached_property
def system_prompt(self):
return "\n\n".join(
Path(instruction).read_text() for instruction in self.instructions
)
def load_tools() -> Iterable[tuple[str, Tool]]:
with Path("pyproject.toml").open("rb") as fs:
for fn_id, config in (
tomllib.load(fs).get("tool", {}).get("sam", {}).get("tools", {}).items()
):
yield fn_id, Tool(fn_id, **config)


TOOLS = dict(load_tools())

def load_assistants():

def load_assistants() -> Iterable[AssistantConfig]:
with Path("pyproject.toml").open("rb") as fs:
for assistant in (
tomllib.load(fs).get("tool", {}).get("sam", {}).get("assistants", [])
):
yield AssistantConfig(**assistant)


ASSISTANTS = list(load_assistants())
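
The `Tool` and `AssistantConfig` helpers are now imported from `sam.utils`, which is not part of this diff. Purely as a sketch of the shape such a wrapper could take, assuming it resolves the `module:function` path lazily and is callable with the parsed tool arguments:

```python
# Hypothetical sketch of a Tool wrapper similar to the one imported from
# sam.utils above; the real implementation is not shown in this diff.
from dataclasses import dataclass
from importlib import import_module
from typing import Any, Callable


@dataclass
class Tool:
    name: str
    path: str  # dotted path in "package.module:function" form
    additional_instructions: str = ""

    def _resolve(self) -> Callable[..., Any]:
        module_name, _, attr = self.path.partition(":")
        return getattr(import_module(module_name), attr)

    def __call__(self, **kwargs: Any) -> Any:
        # Import lazily so optional integrations (Brave, GitHub, Algolia)
        # are only loaded when the corresponding tool is actually called.
        return self._resolve()(**kwargs)
```
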