Merge pull request #2 from leftmove/allow-interruptions

Allow interruptions

leftmove authored May 2, 2024
2 parents 30d6952 + 46e2475 commit c862676
Showing 6 changed files with 193 additions and 3 deletions.
64 changes: 64 additions & 0 deletions README.md
@@ -32,6 +32,9 @@ Cria is a library for programmatically running Large Language Models through Python.
- [Follow-Up](#follow-up)
- [Clear Message History](#clear-message-history)
- [Passing In Custom Context](#passing-in-custom-context)
- [Interrupting](#interrupting)
- [With Message History](#with-message-history)
- [Without Message History](#without-message-history)
- [Multiple Models and Parallel Conversations](#multiple-models-and-parallel-conversations)
- [Models](#models)
- [With](#with-model)
@@ -53,9 +56,11 @@ prompt = "Who is the CEO of OpenAI?"
for chunk in ai.chat(prompt):
print(chunk, end="")
```

```
>>> The CEO of OpenAI is Sam Altman!
```

or, you can run this more configurable example.

```python
@@ -66,6 +71,7 @@ with cria.Model() as ai:
response = ai.chat(prompt, stream=False)
print(response)
```

```
>>> The CEO of OpenAI is Sam Altman!
```
@@ -194,6 +200,64 @@ The available roles for messages are:

The `prompt` parameter will always be appended to messages under the `user` role. To override this, you can pass in nothing for `prompt`.
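
For example, a minimal sketch of overriding the default role, assuming `chat` accepts the `messages` keyword described in [Passing In Custom Context](#passing-in-custom-context):

```python
messages = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Who is the CEO of OpenAI?"},
]

# No `prompt` is passed, so nothing extra is appended under the `user` role.
for chunk in ai.chat(messages=messages):
    print(chunk, end="")
```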

### Interrupting

#### With Message History

If you are streaming messages with Cria, you can interrupt the response midway through.

```python
response = ""
max_token_length = 5

prompt = "Who is the CEO of OpenAI?"
for i, chunk in enumerate(ai.chat(prompt)):

if i >= max_token_length:
ai.stop()

print(chunk, end="") # The CEO of OpenAI is
```

The same pattern also works with `generate`:

```python
response = ""
max_token_length = 5

prompt = "Who is the CEO of OpenAI?"
for i, chunk in enumerate(ai.generate(prompt)):

if i >= max_token_length:
ai.stop()

print(chunk, end="") # The CEO of OpenAI is
```

In the examples, after the AI generates five tokens (units of text that are usually a couple of characters long), text generation is stopped via the `stop` method. After `stop` is called, you can safely `break` out of the `for` loop.
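
Because interruptions are allowed by default, the partial response is still saved to message history, so a follow-up prompt keeps its context. A minimal sketch continuing the examples above:

```python
max_token_length = 5

prompt = "Who is the CEO of OpenAI?"
for i, chunk in enumerate(ai.chat(prompt)):
    if i >= max_token_length:
        ai.stop()  # the stream ends shortly after `stop` is called

    print(chunk, end="")  # The CEO of OpenAI is

# The interrupted response was saved to message history, so the follow-up
# prompt still has the earlier exchange as context.
prompt = "Tell me more about him."
for chunk in ai.chat(prompt):
    print(chunk, end="")
```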

#### Without Message History

By default, Cria automatically saves responses in message history, even if the stream is interrupted. To prevent this behaviour, pass in `allow_interruption=False`.

```python
ai = cria.Cria(allow_interruption=False)

max_token_length = 5

prompt = "Who is the CEO of OpenAI?"
for i, chunk in enumerate(ai.chat(prompt)):
    if i >= max_token_length:
        ai.stop()
        break

    print(chunk, end="")  # The CEO of OpenAI is

prompt = "Tell me more about him."
for chunk in ai.chat(prompt):
    print(chunk, end="")  # I apologize, but I don't have any information about "him" because the conversation just started...
```
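
Because `allow_interruption` is set to `False` here, the interrupted response is never written to message history, which is why the follow-up prompt starts from an empty conversation.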

### Multiple Models and Parallel Conversations

#### Models
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "cria"
version = "1.6.2"
version = "1.6.5"
authors = [{ name = "leftmove", email = "[email protected]" }]
description = "Run AI locally with as little friction as possible"
readme = "README.md"
@@ -17,7 +17,7 @@ Issues = "https://github.com/leftmove/cria/issues"

[tool.poetry]
name = "cria"
version = "1.6.2"
version = "1.6.5"
description = "Run AI locally with as little friction as possible."
authors = ["leftmove"]
license = "MIT"
36 changes: 35 additions & 1 deletion src/cria.py
@@ -23,15 +23,31 @@ def chat_stream(self, messages, **kwargs):
ai = ollama

response = ""
self.running = True

for chunk in ai.chat(model=model, messages=messages, stream=True, **kwargs):
if self.stop_stream:
if self.allow_interruption:
messages.append({"role": "assistant", "content": response})
self.running = False
return
content = chunk["message"]["content"]
response += content
yield content

self.running = False

messages.append({"role": "assistant", "content": response})
self.messages = messages

stop_stream = False

def stop(self):
if self.running:
self.stop_stream = True
else:
raise ValueError("No active chat stream to stop.")

def chat(
self,
prompt: Optional[str] = None,
@@ -70,10 +86,19 @@ def generate_stream(self, prompt, **kwargs):
model = self.model
ai = ollama

response = ""
self.running = True

for chunk in ai.generate(model=model, prompt=prompt, stream=True, **kwargs):
if self.stop_stream:
self.running = False
return
content = chunk["response"]
response += content
yield content

self.running = False

def generate(self, prompt: str, stream: Optional[bool] = True, **kwargs) -> str:
model = self.model
ai = ollama
@@ -156,13 +181,15 @@ def __init__(
standalone: Optional[bool] = False,
run_subprocess: Optional[bool] = False,
capture_output: Optional[bool] = False,
allow_interruption: Optional[bool] = True,
silence_output: Optional[bool] = False,
close_on_exit: Optional[bool] = True,
) -> None:
self.run_subprocess = run_subprocess
self.capture_output = capture_output
self.silence_output = silence_output
self.close_on_exit = close_on_exit
self.allow_interruption = allow_interruption

ollama_process = find_process(["ollama", "serve"])
self.ollama_process = ollama_process
@@ -202,8 +229,11 @@ def __init__(
if not standalone:
self.llm = find_process(["ollama", "run", self.model])

if self.llm and run_subprocess:
if run_subprocess and self.llm:
self.llm.kill()
self.llm = None

if not self.llm:
self.llm = subprocess.Popen(
["ollama", "run", self.model],
stdout=subprocess.DEVNULL,
@@ -216,6 +246,8 @@ def __init__(
if close_on_exit and not standalone:
atexit.register(lambda: self.llm.kill())

messages = DEFAULT_MESSAGE_HISTORY

def output(self):
ollama_subprocess = self.ollama_subrprocess
if not ollama_subprocess:
@@ -240,6 +272,7 @@ def __init__(
model: Optional[str] = DEFAULT_MODEL,
run_attached: Optional[bool] = False,
run_subprocess: Optional[bool] = False,
allow_interruption: Optional[bool] = True,
capture_output: Optional[bool] = False,
silence_output: Optional[bool] = False,
close_on_exit: Optional[bool] = True,
@@ -253,6 +286,7 @@
)

self.capture_output = capture_output
self.allow_interruption = allow_interruption
self.silence_output = silence_output
self.close_on_exit = close_on_exit

20 changes: 20 additions & 0 deletions test/test_current.py
@@ -0,0 +1,20 @@
from src import cria

ai = cria.Model(allow_interruption=False)

response = ""
max_token_length = 5

prompt = "Who is the CEO of OpenAI?"
for i, chunk in enumerate(ai.chat(prompt)):
if i >= max_token_length:
ai.stop()
break

print(chunk, end="") # The CEO of OpenAI is

prompt = "Tell me more about him."
for chunk in ai.chat(prompt):
print(
chunk, end=""
) # I apologize, but I don't have any information about "him" because the conversation just started...
36 changes: 36 additions & 0 deletions test/test_interrupt_chat.py
@@ -0,0 +1,36 @@
import unittest

from src import cria

ai = cria.Cria()
interruption_length = 5

prompt = "Who is the CEO of OpenAI?"
response = ""
chunks = []

for i, chunk in enumerate(ai.chat(prompt)):
    chunks.append(chunk)  # collect whole chunks rather than individual characters
    response += chunk
    print(chunk, end="")

    if i >= interruption_length:
        ai.stop()


class TestChat(unittest.TestCase):
    def test_response(self):
        self.assertNotEqual(response, "")

    def test_chunks(self):
        self.assertNotEqual(chunks, [])

    def test_interruption_length(self):
        self.assertNotEqual(len(ai.messages[-1]["content"]), 0)

    def test_response_length(self):
        self.assertEqual(i, interruption_length)


if __name__ == "__main__":
    unittest.main()
36 changes: 36 additions & 0 deletions test/test_interrupt_generate.py
@@ -0,0 +1,36 @@
import unittest

from src import cria

ai = cria.Cria()
interruption_length = 5

prompt = "Who is the CEO of OpenAI?"
response = ""
chunks = []

for i, chunk in enumerate(ai.generate(prompt)):
    chunks.append(chunk)  # collect whole chunks rather than individual characters
    response += chunk
    print(chunk, end="")

    if i >= interruption_length:
        ai.stop()


class TestGenerate(unittest.TestCase):
    def test_response(self):
        self.assertNotEqual(response, "")

    def test_chunks(self):
        self.assertNotEqual(chunks, [])

    def test_interruption_length(self):
        # `generate` does not write to message history, so check that the stream
        # itself stopped at the interruption point.
        self.assertEqual(len(chunks), interruption_length + 1)

    def test_response_length(self):
        self.assertEqual(i, interruption_length)


if __name__ == "__main__":
    unittest.main()
