Skip to content

Commit

Permalink
wrapper for invoke_tools that acts on raw token ids
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Apr 6, 2023
1 parent feb4819 commit 86fb620
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'toolformer-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.9',
version = '0.0.10',
license='MIT',
description = 'Toolformer - Pytorch',
author = 'Phil Wang',
Expand All @@ -21,7 +21,8 @@
'beartype',
'einops>=0.4',
'torch>=1.6',
'tqdm'
'tqdm',
'x-clip'
],
classifiers=[
'Development Status :: 4 - Beta',
Expand Down
12 changes: 12 additions & 0 deletions toolformer_pytorch/toolformer_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,18 @@ def invoke_tools(
replace_ = partial(replace_fn, registry, delimiter = delimiter)
return re.sub(find_functions_regex, replace_, text)

def invoke_tools_on_batch_sequences(
registry: dict[str, Callable],
token_ids: torch.Tensor,
*,
encode: Callable,
decode: Callable,
delimiter: str = '→'
) -> torch.Tensor:
all_texts = [decode(one_seq_token_ids) for one_seq_token_ids in token_ids]
all_texts_with_api_calls = [invoke_tools(registry, text, delimiter) for text in all_texts]
return encode(all_texts_with_api_calls)

# sampling api related functions
# they do greedy sampling, but encourage sampling api calls by auto-selecting <api> when that token is in the top k = 10

Expand Down

0 comments on commit 86fb620

Please sign in to comment.