From ef3871edcccca86138b05ae6b8f2fa645b2ab691 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 6 Apr 2023 13:50:15 -0700 Subject: [PATCH] another big part needed for the repository to work, also give appreciation to @conceptofmind as well as the current best AI --- README.md | 34 +++++++++++- setup.py | 2 +- toolformer_pytorch/__init__.py | 3 +- toolformer_pytorch/toolformer_pytorch.py | 69 +++++++++++++++++++++++- 4 files changed, 103 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index bbe6981..c81264f 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,10 @@ Implementation of Toolformer, Lan - Stability.ai for the generous sponsorship to work and open source cutting edge artificial intelligence research +- Enrico for getting the ball rolling with the initial commit of different tools! + +- Thanks goes out to ChatGPT for doing all the regular expressions in this repository for parsing the functions and parameters for the API calls. I am terrible at regular expressions, so this was enormous help from the AI (with no hitches, it was perfect). + ## Install ```bash @@ -16,6 +20,8 @@ $ pip install toolformer-pytorch ## Usage +The main novelty of the paper is a way to filter out what is sampled from a bootstrapped transformer, a manual natural selection of what is sampled for further finetuning. 
+ ```python import torch @@ -65,10 +71,34 @@ filtered_results = filter_tokens_with_api_response( ) ``` +To invoke the tools on a string generated by the language model, use `invoke_tools` + +```python +from toolformer_pytorch import invoke_tools + +def inc(i): + return i + 1 + +def dec(i): + return i - 1 + +function_registry = dict( + inc = inc, + dec = dec +) + +text = 'make the following api calls: [inc(1)] and [dec(2)] and [ignored(3)]' + +invoke_tools(function_registry, text) + +# make the following api calls: [inc(1) → 2] and [dec(2) → 1] and [ignored(3)] +``` + ## Todo -- [ ] create custom generate function for palm that can do external API calls - - [ ] allow for generating tokens at different cursor indices +- [x] create custom generate function for palm that can do external API calls + - [x] allow for generating tokens at different cursor indices + - [ ] allow for customizing how to handle errors in function name, parameters, or execution and output - [ ] do end-to-end training in `Toolformer` - [ ] hook up gpt-j - [ ] test for a simple calculator eval dataset diff --git a/setup.py b/setup.py index 042d519..a24b9a0 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'toolformer-pytorch', packages = find_packages(exclude=[]), - version = '0.0.6', + version = '0.0.7', license='MIT', description = 'Toolformer - Pytorch', author = 'Phil Wang', diff --git a/toolformer_pytorch/__init__.py b/toolformer_pytorch/__init__.py index 97d61be..0348c0b 100644 --- a/toolformer_pytorch/__init__.py +++ b/toolformer_pytorch/__init__.py @@ -4,5 +4,6 @@ Toolformer, filter_tokens_with_api_response, sample, - sample_with_api_call + sample_with_api_call, + invoke_tools ) diff --git a/toolformer_pytorch/toolformer_pytorch.py b/toolformer_pytorch/toolformer_pytorch.py index 31d6729..36b010f 100644 --- a/toolformer_pytorch/toolformer_pytorch.py +++ b/toolformer_pytorch/toolformer_pytorch.py @@ -1,3 +1,5 @@ +import re + from functools import partial, wraps from 
# invoking api call functions

def is_valid_string(s):
    """Return whether `s` is, in full, a single- or double-quoted string literal.

    The quoted content may not contain the enclosing quote character; escape
    sequences are not supported. `exists` is a helper that is not defined in
    this block.
    """
    return exists(re.fullmatch(r"'[^']*'|\"[^\"]*\"", s))

def is_valid_integer(s):
    """Return whether `s` is, in full, an optionally signed run of decimal digits."""
    return exists(re.fullmatch(r"[+-]?\d+", s))

def is_valid_float(s):
    """Return whether `s` is, in full, an optionally signed decimal number.

    Accepts digits with an optional `.digits` fractional part. Exponent
    notation and a bare leading or trailing dot are not accepted.
    """
    return exists(re.fullmatch(r"[+-]?\d+(\.\d+)?", s))

def parse_param(s: str) -> Optional[Union[int, float, str]]:
    """Convert one textual argument into a Python value.

    Returns the content of a quoted string (without the surrounding quotes),
    an `int`, or a `float`. Returns `None` when `s` matches none of these
    forms. The caller is expected to have stripped surrounding whitespace.
    """
    if is_valid_string(s):
        # hand the tool the string's content, not the literal with its quote characters
        return s[1:-1]
    elif is_valid_integer(s):
        # integers are tested before floats so that '3' becomes 3 rather than 3.0
        return int(s)
    elif is_valid_float(s):
        return float(s)

    return None

def _split_params(s: str) -> list:
    """Split a raw argument list on commas that sit outside quoted strings."""
    parts = []
    current = []
    open_quote = None

    for char in s:
        if open_quote is not None:
            # inside a quoted string, everything is literal until the matching quote
            current.append(char)
            if char == open_quote:
                open_quote = None
        elif char in ('"', "'"):
            open_quote = char
            current.append(char)
        elif char == ',':
            parts.append(''.join(current))
            current = []
        else:
            current.append(char)

    parts.append(''.join(current))
    return parts

@beartype
def replace_fn(
    registry: dict[str, Callable],
    m,
    delimiter = '→'
):
    """Substitution callback for a single `[name(args)]` match.

    `m` is a regex match whose group 1 is the function name and group 2 the
    raw argument text. When the call succeeds, returns the matched text with
    ` {delimiter} {output}` inserted before the closing bracket. Returns the
    matched text unchanged when the function is not in `registry`, when any
    argument cannot be parsed, or when the function raises.
    """
    orig_text = m.group(0)

    function_name = m.group(1)

    # unable to find function in registry

    if function_name not in registry:
        return orig_text

    fn = registry[function_name]

    raw_params = m.group(2)

    # an empty argument list is a zero-argument call, not one unparseable empty argument

    if raw_params.strip():
        params = [parse_param(p.strip()) for p in _split_params(raw_params)]
    else:
        params = []

    # if any of the parameters are not parseable, return
    # (parse_param signals failure with None; an empty string '' is a valid value)

    if any(p is None for p in params):
        return orig_text

    # just return original text if there is some error with the function
    # Exception rather than a bare except, so KeyboardInterrupt / SystemExit still propagate

    try:
        out = fn(*params)
    except Exception:
        return orig_text

    # return original text with the output delimiter and the stringified output

    return f'{orig_text[:-1]} {delimiter} {str(out)}]'

# main function, which takes a registry of functions, the text in question, and makes all the appropriate api calls and append the output

def invoke_tools(
    registry: dict[str, Callable],
    text: str,
    delimiter: str = '→'
) -> str:
    """Run every `[name(args)]` call found in `text` and append its output.

    Each call that can be executed is rewritten as `[name(args) {delimiter} output]`.
    Calls that cannot be executed (see `replace_fn`) are left exactly as written.
    The pattern stops the argument list at the first `)`, so an argument that
    itself contains `)` is not recognised as a call.
    """
    find_functions_regex = r'\[(\w+)\(([^)]*)\)\]'
    replace_ = partial(replace_fn, registry, delimiter = delimiter)
    return re.sub(find_functions_regex, replace_, text)

# sampling api related functions
# they do greedy sampling, but encourage sampling api calls by auto-selecting when that token is in the top k = 10