Skip to content

Commit

Permalink
🔧 pre-commit format
Browse files Browse the repository at this point in the history
  • Loading branch information
shroominic committed Nov 20, 2023
1 parent 8f8bc47 commit a52e626
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 31 deletions.
4 changes: 2 additions & 2 deletions examples/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def ask(question: str) -> str:
def chat_loop() -> None:
while True:
query = input("> ")

if query == "exit":
break

Expand All @@ -29,7 +29,7 @@ def chat_loop() -> None:
history.clear()
print("\033c")
continue

with stream_to(print):
ask(query)

Expand Down
34 changes: 20 additions & 14 deletions examples/console_log_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,28 @@ def __init__(self, renderer: "Renderer", name: str) -> None:
self.name = name
self.renderer = renderer
self.renderer.add_chain(self)

def render_stream(self, token: str) -> None:
self.renderer.render_stream(token, self)

def close(self) -> None:
self.renderer.remove(self)


class Renderer:
def __init__(self, column_height: int = 3) -> None:
self.column_height = column_height
self.console = Console(height=3)
self.layout = Layout()
self.live = Live(console=self.console, auto_refresh=True, refresh_per_second=30)
self.chains: list[RenderChain] = []

def add_chain(self, chain: RenderChain) -> None:
if not self.live.is_started:
self.live.start()
self.console.height = (len(self.layout.children) + 1) * self.column_height
self.layout.split_column(
*self.layout.children,
Layout(name=chain.id, size=self.column_height)
*self.layout.children, Layout(name=chain.id, size=self.column_height)
)
self.chains.append(chain)

Expand All @@ -50,21 +50,25 @@ def render_stream(self, token: str, chain: RenderChain) -> None:
tokens: int = 0
max_width: int = self.console.width
content_width: int = 0
if isinstance(panel := self.layout[chain.id]._renderable, Panel) and isinstance(panel.renderable, str):
if isinstance(panel := self.layout[chain.id]._renderable, Panel) and isinstance(
panel.renderable, str
):
content_width = self.console.measure(panel.renderable).maximum
if isinstance(panel.title, str) and " " in panel.title:
tokens = int(panel.title.split(" ")[1])
tokens += count_tokens(token)
prev = panel.renderable.replace("\n", " ")
if (max_width - content_width - 5) < 1:
prev = prev[len(token):] + token
prev = prev[len(token) :] + token
else:
prev += token
else:
prev += token
self.layout[chain.id].update(Panel(prev, title=f"({chain.name}) {tokens} tokens"))
self.layout[chain.id].update(
Panel(prev, title=f"({chain.name}) {tokens} tokens")
)
self.live.update(self.layout)

def remove(self, chain: RenderChain) -> None:
self.chains.remove(chain)
self.layout.split_column(
Expand All @@ -75,7 +79,7 @@ def remove(self, chain: RenderChain) -> None:
if not self.chains:
self.live.update(self.layout)
self.live.stop()

def __del__(self) -> None:
self.live.stop()

Expand All @@ -86,14 +90,15 @@ async def generate_poem_async(topic: str) -> str:
"""
return await achain()


@asynccontextmanager
async def log_stream(renderer: Renderer, name: str) -> AsyncGenerator:
render_chain = RenderChain(renderer, name)
async with astream_to(render_chain.render_stream):
yield render_chain

render_chain.close()


async def stream_poem_async(renderer: Renderer, topic: str) -> None:
async with log_stream(renderer, topic):
Expand All @@ -103,13 +108,14 @@ async def stream_poem_async(renderer: Renderer, topic: str) -> None:
async def main() -> None:
topics = ["goldfish", "spacex", "samurai", "python", "javascript", "ai"]
renderer = Renderer()

for topic in topics:
task = asyncio.create_task(stream_poem_async(renderer, topic))
await asyncio.sleep(1)

while not task.done():
await asyncio.sleep(1)


asyncio.run(main())
print("done")
print("done")
6 changes: 3 additions & 3 deletions examples/gather_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Task(BaseModel):
def gather_infos(task_input: str) -> Task:
"""
Task: {task_input}
Based on this task input, gather all task infos.
"""
return chain()
Expand All @@ -22,7 +22,7 @@ def gather_infos(task_input: str) -> Task:
def plan_task(task: Task) -> str:
"""
Task: { task.name }
Based on the task infos, plan the task.
"""
return chain()
Expand All @@ -37,7 +37,7 @@ def main() -> None:
print("description:", task.description)
print("difficulty:", task.difficulty)
print("keywords:", task.keywords)

plan = plan_task(task)
print("\nplan:", plan)

Expand Down
1 change: 1 addition & 0 deletions examples/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ def generate_story_of(topic: str) -> str:
"""
return chain()


with stream_to(print):
generate_story_of("a space cat")
28 changes: 18 additions & 10 deletions funcchain/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@

class AsyncStreamHandler(AsyncCallbackHandler):
"""Async callback handler that can be used to handle callbacks from langchain."""

def __init__(self, fn: Callable[[str], Awaitable[None] | None], default_kwargs: dict) -> None:

def __init__(
self, fn: Callable[[str], Awaitable[None] | None], default_kwargs: dict
) -> None:
self.fn = fn
self.default_kwargs = default_kwargs
self.cost: float = 0.0
self.tokens: int = 0

async def on_chat_model_start(
self,
serialized: dict[str, Any],
Expand All @@ -37,7 +39,7 @@ async def on_chat_model_start(
# print("token_counting", message.content)
# # self.tokens += count_tokens(message)
pass

async def on_llm_new_token(
self,
token: str,
Expand All @@ -52,7 +54,7 @@ async def on_llm_new_token(
await self.fn(token, **self.default_kwargs)
else:
self.fn(token, **self.default_kwargs)

async def on_llm_end(
self,
response: LLMResult,
Expand All @@ -66,11 +68,15 @@ async def on_llm_end(
print("\n")


stream_handler: ContextVar[AsyncStreamHandler | None] = ContextVar("stream_handler", default=None)
stream_handler: ContextVar[AsyncStreamHandler | None] = ContextVar(
"stream_handler", default=None
)


@contextmanager
def stream_to(fn: Callable[[str], None], **kwargs: Any) -> Generator[AsyncStreamHandler, None, None]:
def stream_to(
fn: Callable[[str], None], **kwargs: Any
) -> Generator[AsyncStreamHandler, None, None]:
"""
Stream the llm tokens to a given function.
Expand All @@ -87,10 +93,12 @@ def stream_to(fn: Callable[[str], None], **kwargs: Any) -> Generator[AsyncStream


@asynccontextmanager
async def astream_to(fn: Callable[[str], Awaitable[None] | None], **kwargs: Any) -> AsyncGenerator[AsyncStreamHandler, None]:
async def astream_to(
fn: Callable[[str], Awaitable[None] | None], **kwargs: Any
) -> AsyncGenerator[AsyncStreamHandler, None]:
"""
Asynchronously stream the llm tokens to a given function.
Example:
>>> async with astream_to(print):
... # your chain calls here
Expand All @@ -100,4 +108,4 @@ async def astream_to(fn: Callable[[str], Awaitable[None] | None], **kwargs: Any)
cb = AsyncStreamHandler(fn, kwargs)
stream_handler.set(cb)
yield cb
stream_handler.set(None)
stream_handler.set(None)
9 changes: 7 additions & 2 deletions funcchain/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from rich import print

from .function_frame import get_parent_frame


def retry_parse(fn: Any) -> Any:
"""
Expand All @@ -19,9 +19,11 @@ def retry_parse(fn: Any) -> Any:
- OutputParserException: If the output cannot be parsed.
"""
from ..settings import settings

retry = settings.RETRY_PARSE

if iscoroutinefunction(fn):

@wraps(fn)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
for r in range(retry):
Expand All @@ -35,6 +37,7 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
return async_wrapper

else:

@wraps(fn)
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
for r in range(retry):
Expand All @@ -50,6 +53,7 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any:

def log_openai_callback(fn: Any) -> Any:
if not iscoroutinefunction(fn):

@wraps(fn)
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
if (chain := args[0]) and isinstance(chain, RunnableSequence):
Expand All @@ -61,6 +65,7 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
return sync_wrapper

else:

@wraps(fn)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
if (chain := args[0]) and isinstance(chain, RunnableSequence):
Expand Down

0 comments on commit a52e626

Please sign in to comment.