Skip to content

Commit

Permalink
Add option for array and list prompt input
Browse files Browse the repository at this point in the history
  • Loading branch information
armbues committed Jan 17, 2025
1 parent 9d837b8 commit 8830a6d
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions sillm/core/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def completion(self, *args, **kwargs) -> str:

def generate(model,
tokenizer: Tokenizer,
prompt: str,
prompt: str | list | mx.array,
cache: KVCache = None,
max_tokens: int = 2048,
temperature: float = 0.0,
Expand All @@ -463,8 +463,15 @@ def generate(model,
):
start = time.perf_counter()

# Tokenize inputs
inputs = mx.array(tokenizer.encode(prompt))
# Pre-process inputs
if isinstance(prompt, str):
inputs = mx.array(tokenizer.encode(prompt))
elif isinstance(prompt, list):
inputs = mx.array(prompt)
elif isinstance(prompt, mx.array):
inputs = prompt
else:
raise ValueError("Prompt must be a string, list of tokens, or MX array")

# Initialize metadata
timing = {
Expand Down

0 comments on commit 8830a6d

Please sign in to comment.