Commit

another big part needed for the repository to work, also give appreciation to @conceptofmind as well as the current best AI
lucidrains committed Apr 6, 2023
1 parent 2bb4b77 commit ef3871e
Showing 4 changed files with 103 additions and 5 deletions.
34 changes: 32 additions & 2 deletions README.md
@@ -8,6 +8,10 @@ Implementation of <a href="https://arxiv.org/abs/2302.04761">Toolformer</a>, Lan

- <a href="https://stability.ai/">Stability.ai</a> for the generous sponsorship to work and open source cutting edge artificial intelligence research

- <a href="https://github.com/conceptofmind">Enrico</a> for getting the ball rolling with the initial commit of different tools!

- Thanks go out to ChatGPT for writing all the regular expressions in this repository that parse the functions and parameters for the API calls. I am terrible at regular expressions, so this was an enormous help from the AI (with no hitches, it was perfect).

## Install

```bash
@@ -16,6 +20,8 @@ $ pip install toolformer-pytorch

## Usage

The main novelty of the paper is a way to filter what is sampled from a bootstrapped transformer: a manual natural selection of the samples kept for further finetuning.

```python
import torch

@@ -65,10 +71,34 @@ filtered_results = filter_tokens_with_api_response(
)
```
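As a rough illustration of the filtering idea, the Toolformer paper keeps a sampled API call only when seeing the call together with its response lowers the loss on the following tokens by at least a threshold, compared with the better of two baselines (no call at all, or the call without its response). The sketch below works on plain floats; `keep_api_call` and its parameter names are hypothetical and are not part of the `toolformer-pytorch` API, which operates on tensors.

```python
def keep_api_call(
    loss_with_response: float,   # loss when the call and its response are inserted
    loss_with_call_only: float,  # loss when the call is inserted without a response
    loss_without_call: float,    # loss when nothing is inserted
    threshold: float,
) -> bool:
    # Hypothetical helper: compare against the stronger (lower-loss) baseline.
    baseline = min(loss_with_call_only, loss_without_call)
    # Keep the call only if the response reduces the loss by at least `threshold`.
    return baseline - loss_with_response >= threshold

print(keep_api_call(1.0, 2.0, 2.5, threshold=0.5))  # True: the loss drops by 1.0
print(keep_api_call(1.9, 2.0, 2.5, threshold=0.5))  # False: the loss drops by only 0.1
```

Taking the minimum of the two baselines means a call is not rewarded merely because the call text itself, without any response, already helps prediction.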

To invoke the tools on a string generated by the language model, use `invoke_tools`:

```python
from toolformer_pytorch import invoke_tools

def inc(i):
    return i + 1

def dec(i):
    return i - 1

function_registry = dict(
    inc = inc,
    dec = dec
)

text = 'make the following api calls: [inc(1)] and [dec(2)] and [ignored(3)]'

invoke_tools(function_registry, text)

# make the following api calls: [inc(1) → 2] and [dec(2) → 1] and [ignored(3)]
```
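To see which spans are treated as candidate API calls, the sketch below applies the bracket pattern from this commit's `invoke_tools` using only the standard library. The variable name `CALL_PATTERN` is made up for this illustration. Note that `ignored(3)` still matches the pattern; it is left untouched later only because `ignored` is not in the registry.

```python
import re

# Pattern copied from `invoke_tools` in this commit: "[name(args)]"
CALL_PATTERN = r'\[(\w+)\(([^)]*)\)\]'

text = 'make the following api calls: [inc(1)] and [dec(2)] and [ignored(3)]'

# Each match yields (function name, raw argument string).
calls = re.findall(CALL_PATTERN, text)
print(calls)  # [('inc', '1'), ('dec', '2'), ('ignored', '3')]

# The argument group excludes ')', so nested calls are not recognized.
nested = re.findall(CALL_PATTERN, '[f(g(1))]')
print(nested)  # []
```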

## Todo

- [ ] create custom generate function for palm that can do external API calls
- [ ] allow for generating tokens at different cursor indices
- [x] create custom generate function for palm that can do external API calls
- [x] allow for generating tokens at different cursor indices
- [ ] allow for customizing how errors in function name, parameters, or execution and output are handled
- [ ] do end-to-end training in `Toolformer`
- [ ] hook up gpt-j
- [ ] test for a simple calculator eval dataset
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'toolformer-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.6',
version = '0.0.7',
license='MIT',
description = 'Toolformer - Pytorch',
author = 'Phil Wang',
3 changes: 2 additions & 1 deletion toolformer_pytorch/__init__.py
@@ -4,5 +4,6 @@
Toolformer,
filter_tokens_with_api_response,
sample,
sample_with_api_call
sample_with_api_call,
invoke_tools
)
69 changes: 68 additions & 1 deletion toolformer_pytorch/toolformer_pytorch.py
@@ -1,3 +1,5 @@
import re

from functools import partial, wraps
from collections import namedtuple

@@ -9,7 +11,7 @@
from toolformer_pytorch.palm import PaLM

from beartype import beartype
from beartype.typing import Callable, Optional
from beartype.typing import Callable, Optional, Union

from tqdm import tqdm

@@ -58,7 +60,72 @@ def find_indices_of(t: torch.Tensor, token_id: int, occurrence = 1):

# invoking api call functions

def is_valid_string(s):
    return exists(re.fullmatch(r"'[^']*'|\"[^\"]*\"", s))

def is_valid_integer(s):
    return exists(re.fullmatch(r"[+-]?\d+", s))

def is_valid_float(s):
    return exists(re.fullmatch(r"[+-]?\d+(\.\d+)?", s))

def parse_param(s: str) -> Optional[Union[int, float, str]]:
    if is_valid_string(s):
        return str(s)
    elif is_valid_integer(s):
        return int(s)
    elif is_valid_float(s):
        return float(s)

    return None

@beartype
def replace_fn(
    registry: dict[str, Callable],
    m,
    delimiter = '→'
):
    orig_text = m.group(0)

    function_name = m.group(1)

    # unable to find function in registry

    if function_name not in registry:
        return orig_text

    fn = registry[function_name]

    params = m.group(2).split(',')
    params = list(map(lambda s: s.strip(), params))
    params = list(map(parse_param, params))

    # if any of the parameters are not parseable, return

    if any([(not exists(p)) for p in params]):
        return orig_text

    # just return original text if there is some error with the function

    try:
        out = fn(*params)
    except:
        return orig_text

    # return original text with the output delimiter and the stringified output

    return f'{orig_text[:-1]} {delimiter} {str(out)}]'

# main function, which takes a registry of functions and the text in question, makes all the appropriate api calls, and appends the outputs

def invoke_tools(
    registry: dict[str, Callable],
    text: str,
    delimiter: str = '→'
) -> str:
    find_functions_regex = r'\[(\w+)\(([^)]*)\)\]'
    replace_ = partial(replace_fn, registry, delimiter = delimiter)
    return re.sub(find_functions_regex, replace_, text)

# sampling api related functions
# they do greedy sampling, but encourage sampling api calls by auto-selecting <api> when that token is in the top k = 10
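To make the parameter handling in this diff easier to reason about, here is a standalone sketch that reuses the same three regular expressions in the same order, with `exists` replaced by plain truthiness checks so it runs without the package. It suggests two consequences worth knowing: quoted strings are returned with their quotes still attached, and an empty argument list parses to `None`, so a call such as `[now()]` would be left unchanged by `replace_fn`.

```python
import re
from typing import Optional, Union

def parse_param(s: str) -> Optional[Union[int, float, str]]:
    # Same checks, in the same order, as `parse_param` in the diff.
    if re.fullmatch(r"'[^']*'|\"[^\"]*\"", s):
        return str(s)  # the surrounding quotes are kept
    if re.fullmatch(r"[+-]?\d+", s):
        return int(s)
    if re.fullmatch(r"[+-]?\d+(\.\d+)?", s):
        return float(s)
    return None

print(parse_param("3"))            # 3
print(parse_param("-2.5"))         # -2.5
print(repr(parse_param("'hi'")))   # "'hi'"
print(parse_param("1e3"))          # None: exponent notation is not accepted
print(parse_param(".5"))           # None: a leading digit is required

# An empty argument list splits into one empty string, which is unparseable.
print("".split(","))               # ['']
print(parse_param(""))             # None
```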
