Skip to content

Commit

Permalink
A function needed for the final sampling with API calling
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Apr 4, 2023
1 parent 240ebab commit c794e75
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'toolformer-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.4',
version = '0.0.5',
license='MIT',
description = 'Toolformer - Pytorch',
author = 'Phil Wang',
Expand Down
3 changes: 2 additions & 1 deletion toolformer_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
from toolformer_pytorch.toolformer_pytorch import (
Toolformer,
filter_tokens_with_api_response,
sample
sample,
sample_with_api_call
)
52 changes: 51 additions & 1 deletion toolformer_pytorch/toolformer_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,19 @@ def all_contains_id(t: torch.Tensor, token_id: int):
mask = t == token_id
return mask.any(dim = -1).all()

def find_indices_of(t: torch.Tensor, token_id: int, occurrence = 1):
    """Per batch row, return the position just past the `occurrence`-th
    appearance of `token_id` (seq_len + 1 when it appears fewer times)."""
    assert occurrence > 0

    # running count of matches seen up to and including each position
    seen = (t == token_id).cumsum(dim = -1)

    # prepend a zero column so each slot counts strictly-earlier matches
    zero_col = torch.zeros_like(seen[..., :1])
    seen = torch.cat((zero_col, seen), dim = -1)

    # number of leading slots that precede the target occurrence
    return (seen < occurrence).sum(dim = -1).long()

# sampling api related functions
# they do greedy sampling, but encourage sampling api calls by auto-selecting <api> when that token is in the top k = 10

@beartype
@torch.no_grad()
def sample(
model: nn.Module,
Expand Down Expand Up @@ -81,7 +91,7 @@ def sample(
# sampling positions - different sequences have different cursors

positions = default(positions, torch.zeros((batch_size,), device = device, dtype = torch.long))
assert (positions <= prime_length).all() and (positions < max_seq_len).all(), 'all positions must be less then initial prime length as well as the total sequence length + 1 (plus one for noop if one sequence finished sampling before the other)'
assert (positions <= (prime_length + 1)).all() and (positions <= max_seq_len).all(), 'all positions must be less then initial prime length as well as the total sequence length + 1 (plus one for noop if one sequence finished sampling before the other)'

# eval model

Expand Down Expand Up @@ -166,6 +176,46 @@ def create_api_token_mask(num_tokens, api_start_token_id):

return output

@beartype
@torch.no_grad()
def sample_with_api_call(
    model: nn.Module,
    *,
    seq_len,
    call_apis: Callable,
    prime: torch.Tensor,
    api_end_token_id: int,
    occurrence = 1,
    **kwargs
):
    """Sample once, hand the result to `call_apis`, then resume sampling each
    row from the position immediately after its `occurrence`-th </api> token.

    `call_apis` is a caller-supplied hook that receives the sampled token
    tensor and returns a token tensor — presumably with API responses
    inserted after the api-call spans (exact contract is up to the caller).
    """

    # first pass: plain sampling, during which <api> calls may be emitted
    first_pass = sample(
        model = model,
        prime = prime,
        seq_len = seq_len,
        **kwargs
    )

    with_responses = call_apis(first_pass)

    total_len = with_responses.shape[-1]

    # per-row position of the chosen </api>; rows without one land at total_len + 1
    after_api = find_indices_of(
        with_responses,
        api_end_token_id,
        occurrence = occurrence
    )

    # no-op cursor for rows that never made an api call
    noop_position = total_len + 1

    # second pass: continue generation right after each </api>
    return sample(
        model = model,
        prime = with_responses,
        seq_len = total_len,
        positions = (after_api + 1).clamp(max = noop_position),
        **kwargs
    )

# the main contribution of the paper is simply the filtering equations presented in section 2

def default_weight_fn(t):
Expand Down

0 comments on commit c794e75

Please sign in to comment.