diff --git a/README.md b/README.md
index bbe6981..c81264f 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,10 @@ Implementation of Toolformer, Lan
- Stability.ai for the generous sponsorship to work and open source cutting edge artificial intelligence research
+- Enrico for getting the ball rolling with the initial commit of different tools!
+
+- Thanks goes out to ChatGPT for doing all the regular expressions in this repository for parsing the functions and parameters for the API calls. I am terrible at regular expressions, so this was enormous help from the AI (with no hitches, it was perfect).
+
## Install
```bash
@@ -16,6 +20,8 @@ $ pip install toolformer-pytorch
## Usage
+The main novelty of the paper is a way to filter out what is sampled from a bootstrapped transformer, a manual natural selection of what is sampled for further finetining.
+
```python
import torch
@@ -65,10 +71,34 @@ filtered_results = filter_tokens_with_api_response(
)
```
+To invoke the tools on a string generated by the language model, use `invoke_tools`
+
+```python
+from toolformer_pytorch import invoke_tools
+
+def inc(i):
+ return i + 1
+
+def dec(i):
+ return i - 1
+
+function_registry = dict(
+ inc = inc,
+ dec = dec
+)
+
+text = 'make the following api calls: [inc(1)] and [dec(2)] and [ignored(3)]'
+
+invoke_tools(function_registry, text)
+
+# make the following api calls: [inc(1) → 2] and [dec(2) → 1] and [ignored(3)]
+```
+
## Todo
-- [ ] create custom generate function for palm that can do external API calls
- - [ ] allow for generating tokens at different cursor indices
+- [x] create custom generate function for palm that can do external API calls
+ - [x] allow for generating tokens at different cursor indices
+ - [ ] allow for customizing how to fine handling errors in function name, parameters, or execution and output
- [ ] do end-to-end training in `Toolformer`
- [ ] hook up gpt-j
- [ ] test for a simple calculator eval dataset
diff --git a/setup.py b/setup.py
index 042d519..a24b9a0 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
setup(
name = 'toolformer-pytorch',
packages = find_packages(exclude=[]),
- version = '0.0.6',
+ version = '0.0.7',
license='MIT',
description = 'Toolformer - Pytorch',
author = 'Phil Wang',
diff --git a/toolformer_pytorch/__init__.py b/toolformer_pytorch/__init__.py
index 97d61be..0348c0b 100644
--- a/toolformer_pytorch/__init__.py
+++ b/toolformer_pytorch/__init__.py
@@ -4,5 +4,6 @@
Toolformer,
filter_tokens_with_api_response,
sample,
- sample_with_api_call
+ sample_with_api_call,
+ invoke_tools
)
diff --git a/toolformer_pytorch/toolformer_pytorch.py b/toolformer_pytorch/toolformer_pytorch.py
index 31d6729..36b010f 100644
--- a/toolformer_pytorch/toolformer_pytorch.py
+++ b/toolformer_pytorch/toolformer_pytorch.py
@@ -1,3 +1,5 @@
+import re
+
from functools import partial, wraps
from collections import namedtuple
@@ -9,7 +11,7 @@
from toolformer_pytorch.palm import PaLM
from beartype import beartype
-from beartype.typing import Callable, Optional
+from beartype.typing import Callable, Optional, Union
from tqdm import tqdm
@@ -58,7 +60,72 @@ def find_indices_of(t: torch.Tensor, token_id: int, occurrence = 1):
# invoking api call functions
+def is_valid_string(s):
+ return exists(re.fullmatch(r"'[^']*'|\"[^\"]*\"", s))
+
+def is_valid_integer(s):
+ return exists(re.fullmatch(r"[+-]?\d+", s))
+
+def is_valid_float(s):
+ return exists(re.fullmatch(r"[+-]?\d+(\.\d+)?", s))
+
+def parse_param(s: str) -> Optional[Union[int, float, str]]:
+ if is_valid_string(s):
+ return str(s)
+ elif is_valid_integer(s):
+ return int(s)
+ elif is_valid_float(s):
+ return float(s)
+
+ return None
+
+@beartype
+def replace_fn(
+ registry: dict[str, Callable],
+ m,
+ delimiter = '→'
+):
+ orig_text = m.group(0)
+
+ function_name = m.group(1)
+
+ # unable to find function in registry
+
+ if function_name not in registry:
+ return orig_text
+
+ fn = registry[function_name]
+
+ params = m.group(2).split(',')
+ params = list(map(lambda s: s.strip(), params))
+ params = list(map(parse_param, params))
+
+ # if any of the parameters are not parseable, return
+
+ if any([(not exists(p)) for p in params]):
+ return orig_text
+
+ # just return original text if there is some error with the function
+
+ try:
+ out = fn(*params)
+ except:
+ return orig_text
+
+ # return original text with the output delimiter and the stringified output
+
+ return f'{orig_text[:-1]} {delimiter} {str(out)}]'
+
+# main function, which takes a registry of functions, the text in question, and makes all the appropriate api calls and append the output
+def invoke_tools(
+ registry: dict[str, Callable],
+ text: str,
+ delimiter: str = '→'
+) -> str:
+ find_functions_regex = r'\[(\w+)\(([^)]*)\)\]'
+ replace_ = partial(replace_fn, registry, delimiter = delimiter)
+ return re.sub(find_functions_regex, replace_, text)
# sampling api related functions
# they do greedy sampling, but encourage sampling api calls by auto-selecting when that token is in the top k = 10