diff --git a/guidance/llms/_transformers.py b/guidance/llms/_transformers.py index 5eb3356f1..f59ff1460 100644 --- a/guidance/llms/_transformers.py +++ b/guidance/llms/_transformers.py @@ -278,7 +278,7 @@ async def __call__(self, prompt, stop=None, stop_regex=None, temperature=None, n # trim the cache to what we can use if prefix_match_len < len(self._prefix_cache): # prefix_match_len > 0 and - self._past_key_values = tuple((key[:,:,:prefix_match_len,:],value[:,:,:prefix_match_len,:]) for key,value in self._past_key_values) # TODO: this is specific to the GPT2 tensor layout + self._past_key_values = tuple(tuple(key_value[:,:,:prefix_match_len,:] for key_value in key_values) for key_values in self._past_key_values) # TODO: this works for GPT2 and Long Llama tensor layout. self._prefix_cache = self._prefix_cache[:prefix_match_len] # add support for pattern guidance