From 8830a6d3134b26064375e5af4c17daab47c9afe4 Mon Sep 17 00:00:00 2001
From: Armin Buescher
Date: Fri, 17 Jan 2025 11:40:18 +0100
Subject: [PATCH] Add option for array and list prompt input

---
 sillm/core/llm.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/sillm/core/llm.py b/sillm/core/llm.py
index 6a43131..5168560 100644
--- a/sillm/core/llm.py
+++ b/sillm/core/llm.py
@@ -445,7 +445,7 @@ def completion(self, *args, **kwargs) -> str:
 
 def generate(model,
              tokenizer: Tokenizer,
-             prompt: str,
+             prompt: str | list | mx.array,
              cache: KVCache = None,
              max_tokens: int = 2048,
              temperature: float = 0.0,
@@ -463,8 +463,15 @@ def generate(model,
              ):
     start = time.perf_counter()
 
-    # Tokenize inputs
-    inputs = mx.array(tokenizer.encode(prompt))
+    # Pre-process inputs
+    if isinstance(prompt, str):
+        inputs = mx.array(tokenizer.encode(prompt))
+    elif isinstance(prompt, list):
+        inputs = mx.array(prompt)
+    elif isinstance(prompt, mx.array):
+        inputs = prompt
+    else:
+        raise ValueError("Prompt must be a string, list of tokens, or MX array")
 
     # Initialize metadata
     timing = {
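
For reference, a minimal usage sketch of the widened prompt parameter follows. It assumes `model` and `tokenizer` are already-loaded SiLLM objects and that `generate` is importable from `sillm.core.llm` (where the patched function lives); the example prompt text and `max_tokens` value are illustrative and not taken from the patch, and the return type of `generate` is not asserted here.

    import mlx.core as mx

    from sillm.core.llm import generate

    # Assumed to exist already: a loaded `model` and its `tokenizer`.
    prompt_text = "What is the capital of France?"      # plain string: tokenized internally
    prompt_ids = tokenizer.encode(prompt_text)           # list of token ids: wrapped with mx.array()
    prompt_array = mx.array(prompt_ids)                  # pre-built mx.array: passed through unchanged

    # After this patch, all three input forms are accepted by generate().
    for prompt in (prompt_text, prompt_ids, prompt_array):
        result = generate(model, tokenizer, prompt, max_tokens=128)

Passing pre-tokenized input is useful when the caller has already applied a chat template or wants to reuse cached token ids without re-encoding the prompt string.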