Merge branch 'main' of https://github.com/probcomp/hfppl
postylem committed Sep 22, 2023
2 parents 2f7af78 + 1faff43 commit c7b0974
Showing 9 changed files with 175 additions and 134 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -31,7 +31,7 @@ If everything is working, you should see the model generate political news using
A LLaMPPL program is a subclass of the `hfppl.Model` class.

```python
from hfppl import Model, StatefulLM, TokenCategorical, CachedCausalLM
from hfppl import Model, LMContext, TokenCategorical, CachedCausalLM

# A LLaMPPL model subclasses the Model class
class MyModel(Model):
@@ -42,7 +42,7 @@ class MyModel(Model):
super().__init__()

# A stateful context object for the LLM, initialized with the prompt
self.context = StatefulLM(lm, prompt)
self.context = LMContext(lm, prompt)
self.lm = lm

# The forbidden letter
4 changes: 2 additions & 2 deletions docs/getting_started.md
@@ -31,7 +31,7 @@ To do so, we subclass the [`Model`](hfppl.modeling.Model) class:
```python
# examples/no_e.py

from hfppl import Model, StatefulLM, TokenCategorical, CachedCausalLM
from hfppl import Model, LMContext, TokenCategorical, CachedCausalLM

# A LLaMPPL model subclasses the Model class
class MyModel(Model):
@@ -44,7 +44,7 @@ class MyModel(Model):
super().__init__()

# A stateful context object for the LLM, initialized with the prompt
self.context = StatefulLM(lm, prompt)
self.context = LMContext(lm, prompt)

# The forbidden letter
self.forbidden_tokens = [i for (i, v) in enumerate(lm.vocab)
15 changes: 13 additions & 2 deletions docs/transformers.md
@@ -2,8 +2,19 @@

## Load your Transformer as a `CachedCausalLM`

The easiest way to load a Transformer model is to use the [`CachedCausalLM.from_pretrained`][hfppl.llms.CachedCausalLM.from_pretrained] static method, which accepts as input a HuggingFace model identifier. This loads the model's weights into memory, and also loads the appropriate tokenizer. The optional `auth_token` parameter can be provided if the model in question requires HuggingFace authorization (e.g., Meta's Llama 2 models).
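
For example, a minimal loading sketch might look like the following (the model identifier and the `HF_AUTH_TOKEN` environment variable are placeholders, not requirements of the library):

```python
import os

from hfppl import CachedCausalLM

# Load the model weights and the matching tokenizer from the HuggingFace Hub.
# The auth_token argument is only needed for gated models such as Llama 2;
# here we assume a token is stored in the HF_AUTH_TOKEN environment variable.
LLM = CachedCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    auth_token=os.environ.get("HF_AUTH_TOKEN"),
)
```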

## Use the LLM within your model via the `Transformer` distribution

## Use the LLM within your model via the `StatefulLM` class
Within a model, you can `sample` or `observe` from the [`Transformer`][hfppl.distributions.transformer.Transformer] distribution. It accepts as arguments a [`CachedCausalLM`][hfppl.llms.CachedCausalLM] instance, as well as a list of integer token ids specifying the context. It returns a distribution over next tokens. The [`Transformer`][hfppl.distributions.transformer.Transformer] distribution is stateless, and so your model will need to manually extend the context with newly sampled tokens.
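
As a rough sketch of that pattern (the `StatelessGeneration` class is hypothetical, and the top-level import of `Transformer` and its argument order are assumed from the description above):

```python
from hfppl import Model, CachedCausalLM, Transformer

class StatelessGeneration(Model):
    """Hypothetical example model, not part of the library."""

    def __init__(self, lm, prompt, max_tokens=50):
        super().__init__()
        self.lm = lm
        # Because Transformer is stateless, the model keeps the token-id
        # context itself and extends it after every sampled token.
        self.context_ids = lm.tokenizer.encode(prompt)
        self.max_tokens = max_tokens

    async def step(self):
        token = await self.sample(Transformer(self.lm, self.context_ids))
        self.context_ids.append(token.token_id)
        self.max_tokens -= 1
        if token.token_id == self.lm.tokenizer.eos_token_id or self.max_tokens == 0:
            self.finish()
```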

## Use the LLM within your model via the `LMContext` class

Alternatively, you can initialize an [`LMContext`][hfppl.distributions.lmcontext.LMContext] object with a [`CachedCausalLM`][hfppl.llms.CachedCausalLM] instance and a string-valued prompt. It maintains a growing context as state, and exposes a [`next_token`][hfppl.distributions.lmcontext.LMContext.next_token] distribution that, when sampled, observed, or intervened, grows the context. It also supports a form of 'sub-token' generation, via the [`mask_dist`][hfppl.distributions.lmcontext.LMContext.mask_dist] distribution.
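
Below is a small sketch of that pattern (the `CapitalizedContinuation` class and its upper-case heuristic over `lm.vocab` are illustrative, not part of the library):

```python
from hfppl import Model, CachedCausalLM, LMContext

class CapitalizedContinuation(Model):
    """Hypothetical example model, not part of the library."""

    def __init__(self, lm, prompt, max_tokens=20):
        super().__init__()
        self.lm = lm
        self.context = LMContext(lm, prompt)
        self.max_tokens = max_tokens
        # Crude heuristic mask: token ids whose vocabulary string starts with
        # an upper-case character (tokenizer prefix markers are ignored here).
        self.upper_mask = set(
            i for (i, v) in enumerate(lm.vocab) if v[:1].isupper()
        )

    async def step(self):
        # Sub-token generation: first commit to the mask ("the next token is
        # upper-case"), then sample the token itself from the narrowed context.
        await self.intervene(self.context.mask_dist(self.upper_mask), True)
        token = await self.sample(self.context.next_token())
        self.max_tokens -= 1
        if token.token_id == self.lm.tokenizer.eos_token_id or self.max_tokens == 0:
            self.finish()
```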

## Create custom token distributions with `TokenCategorical`

You may also create a custom distribution over the vocabulary of a language model using the [`TokenCategorical`][hfppl.distributions.tokencategorical.TokenCategorical] distribution. It is parameterized by a [`CachedCausalLM`][hfppl.llms.CachedCausalLM] instance, and an array of logits equal in length to the language model's vocabulary size.
This distribution is particularly useful as a proposal distribution; for example, a model might `sample` with `dist` set
to the LM's next token distribution, but with `proposal` set to a modified distribution that uses a heuristic to upweight
'good' tokens and downweight 'bad' ones.
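
A sketch of that pattern follows (the `HeuristicProposal` class, the `preferred_token_ids` argument, and the fixed logit bonus are illustrative; `next_token_logprobs` is the `LMContext` attribute holding the current next-token log probabilities):

```python
import numpy as np

from hfppl import Model, CachedCausalLM, LMContext, TokenCategorical

class HeuristicProposal(Model):
    """Hypothetical example model, not part of the library."""

    def __init__(self, lm, prompt, preferred_token_ids, max_tokens=50):
        super().__init__()
        self.lm = lm
        self.context = LMContext(lm, prompt)
        self.max_tokens = max_tokens
        # Logit bonus applied to heuristically 'good' tokens.
        self.boost = np.zeros(len(lm.vocab))
        self.boost[list(preferred_token_ids)] = 2.0

    async def step(self):
        # dist: the LM's own next-token distribution (the target).
        # proposal: the same log-probabilities, nudged toward preferred tokens.
        proposal = TokenCategorical(self.lm, self.context.next_token_logprobs + self.boost)
        token = await self.sample(self.context.next_token(), proposal=proposal)
        self.max_tokens -= 1
        if token.token_id == self.lm.tokenizer.eos_token_id or self.max_tokens == 0:
            self.finish()
```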
21 changes: 9 additions & 12 deletions examples/hard_constraints.py
@@ -1,6 +1,6 @@
import string
import asyncio
from hfppl import Model, CachedCausalLM, Token, StatefulLM, smc_standard
from hfppl import Model, CachedCausalLM, Token, LMContext, smc_standard

import os
HF_AUTH_TOKEN = os.environ['HF_AUTH_TOKEN']
@@ -21,18 +21,16 @@
class ConstraintModel(Model):
def __init__(self, prompt, max_tokens):
super().__init__()
self.lm = StatefulLM(LLM, prompt)
self.q = StatefulLM(LLM, prompt)
self.prompt_len = len(str(self.lm.s))
self.context = LMContext(LLM, prompt)
self.q = LMContext(LLM, prompt)
self.max_tokens = max_tokens


async def step(self):
# Which tokens are allowed?
mask = self.active_constraint_mask()

# Generate proposed token.
token = await self.sample(self.lm.next_token(),
token = await self.sample(self.context.next_token(),
proposal = await self.proposal(mask))

# Condition on constraint — a no-op since proposal already guarantees the constraint
@@ -41,24 +39,23 @@ async def step(self):
# Reduce number of max tokens remaining
self.max_tokens -= 1

#if self.max_tokens % 5 == 0:
print(str(self.lm.s)[self.prompt_len:])
print(f"{self.context}")

# Check if done
if token == LLM.tokenizer.eos_token_id or self.max_tokens == 0:
self.finish()
return

def active_constraint_mask(self):
string_so_far = str(self.lm.s)
string_so_far = str(self.context.s)
words = string_so_far.split()
last_word = words[-1] if len(words) > 0 else ""
return MASKS[min(5, len(last_word))]

async def proposal(self, mask):
string_so_far = str(self.lm.s)
string_so_far = str(self.context)

# Force the proposal StatefulLM to adhere to this mask
# Force the proposal LMContext to adhere to this mask
await self.intervene(self.q.mask_dist(mask), True)

# Return the proposal's modified next-token distribution
@@ -80,6 +77,6 @@ async def main():
constraint_model = ConstraintModel(prompt, 50)
particles = await smc_standard(constraint_model, 40)
for p in particles:
print(str(p.lm.s)[p.prompt_len:])
print(f"{p.context}")

asyncio.run(main())
6 changes: 3 additions & 3 deletions hfppl/distributions/__init__.py
@@ -7,14 +7,14 @@
* `LogCategorical(logits: array) -> int`
* `TokenCategorical(lm: hfppl.llms.CachedCausalLM, logits: array) -> hfppl.llms.Token`
* `Transformer(lm: hfppl.llms.CachedCausalLM) -> hfppl.llms.Token`
* `StatefulLM(lm: hfppl.llms.CachedCausalLM, prompt: list[int]).next_token() -> hfppl.llms.Token`
* `StatefulLM(lm: hfppl.llms.CachedCausalLM, prompt: list[int]).mask_dist(mask: set[int]) -> bool`
* `LMContext(lm: hfppl.llms.CachedCausalLM, prompt: list[int]).next_token() -> hfppl.llms.Token`
* `LMContext(lm: hfppl.llms.CachedCausalLM, prompt: list[int]).mask_dist(mask: set[int]) -> bool`
"""

from .distribution import Distribution
from .geometric import Geometric
from .logcategorical import LogCategorical
from .tokencategorical import TokenCategorical
from .transformer import Transformer
from .statefullm import StatefulLM
from .lmcontext import LMContext
from .bernoulli import Bernoulli
143 changes: 143 additions & 0 deletions hfppl/distributions/lmcontext.py
@@ -0,0 +1,143 @@
from ..util import log_softmax, logsumexp
from .distribution import Distribution
from ..llms import Token, TokenSequence
import numpy as np
import copy

class LMNextToken(Distribution):

def __init__(self, ctx):
self.ctx = ctx

async def log_prob(self, x):
if isinstance(x, Token):
x = x.token_id

lp = self.ctx.next_token_logprobs[x]
self.ctx.s += x
updated_logprobs = await self.ctx.lm.next_token_logprobs(self.ctx.s.seq)
self.ctx.next_token_logprobs = log_softmax(updated_logprobs / self.ctx.temp)
self.ctx.model_mask = self.ctx.NO_MASK

return lp

async def sample(self):
probs = np.exp(self.ctx.next_token_logprobs)
token_id = np.random.choice(len(probs), p=(probs))
logprob = self.ctx.next_token_logprobs[token_id]
t = Token(self.ctx.lm, token_id, self.ctx.lm.tokenizer.convert_ids_to_tokens(token_id))
self.ctx.s += t
self.ctx.model_mask = self.ctx.NO_MASK
updated_logprobs = await self.ctx.lm.next_token_logprobs(self.ctx.s.seq)
self.ctx.next_token_logprobs = log_softmax(updated_logprobs / self.ctx.temp)
return t, logprob


class LMTokenMask(Distribution):
def __init__(self, ctx, mask):
self.ctx = ctx
self.mask = mask

async def sample(self):
newly_bad_tokens = [i for i in self.ctx.model_mask if i not in self.mask]
good_tokens = [i for i in self.ctx.model_mask if i in self.mask]
logprob_no_mask = logsumexp(self.ctx.next_token_logprobs[newly_bad_tokens])
logprob_yes_mask = np.log1p(-np.exp(logprob_no_mask))
decide_no_mask = np.random.rand() < np.exp(logprob_no_mask)
if decide_no_mask:
self.ctx.model_mask = self.ctx.model_mask - self.mask
self.ctx.next_token_logprobs[good_tokens] = float('-inf')
self.ctx.next_token_logprobs -= logprob_no_mask
return False, logprob_no_mask
else:
self.ctx.model_mask = self.ctx.model_mask.intersection(self.mask)
self.ctx.next_token_logprobs[newly_bad_tokens] = float('-inf')
self.ctx.next_token_logprobs -= logprob_yes_mask
return True, logprob_yes_mask

async def log_prob(self, v):
good_tokens = self.ctx.model_mask.intersection(self.mask) if v else self.ctx.model_mask - self.mask
bad_tokens = [i for i in self.ctx.model_mask if i not in good_tokens]
logprob_good = logsumexp(self.ctx.next_token_logprobs[list(good_tokens)])
self.ctx.next_token_logprobs[bad_tokens] = float('-inf')
self.ctx.next_token_logprobs -= logprob_good
self.ctx.model_mask = good_tokens
return logprob_good


class LMContext:
"""Represents a generation-in-progress from a language model.
The state tracks two pieces of information:
* A sequence of tokens — the ever-growing context for the language model.
* A *current mask* — a set of tokens that have not yet been ruled out as the next token.
Storing a mask enables _sub-token_ generation: models can use `LMContext` to sample
the next token in _stages_, first deciding, e.g., whether to use an upper-case or lower-case
first letter, and only later deciding which upper-case or lower-case token to generate.
The state of a `LMContext` can be advanced in two ways:
1. Sampling, observing, or intervening the `next_token()` distribution. This causes a token
to be added to the growing sequence of tokens. Supports auto-batching.
2. Sampling, observing, or intervening the `mask_dist(mask)` distribution for a given mask (set of
token ids). This changes the current mask.
Attributes:
lm (hfppl.llms.CachedCausalLM): the language model for which this is a context
s (hfppl.llms.TokenSequence): the underlying sequence of tokens, including prompt, in this context
next_token_logprobs (numpy.array): numpy array holding the log probabilities for the next token. Unlike the log probabilities reported by `CachedCausalLM.next_token_logprobs`, these probabilities are rescaled for this `LMContext`'s temperature parameter, and for any active masks. This vector is managed by the `LMContext` object internally; do not mutate.
temp (float): temperature for next-token distribution (0 < temp < float('inf'))
model_mask (set[int]): set of tokens that have not been ruled out as the next token. This mask is managed by the `LMContext` object internally; do not mutate.
show_prompt (bool): controls whether the string representation of this `LMContext` includes the initial prompt or not. Defaults to `False`.
"""

def __init__(self, lm, prompt, temp=1.0):
"""Create a new `LMContext` with a given prompt and temperature.
Args:
lm (hfppl.llms.CachedCausalLM): the language model for which this is a context.
prompt (str): a string with which to initialize the context. Will be tokenized using `lm.tokenizer`.
temp (float): temperature for next-token distribution (0 < temp < float('inf'))"""
self.lm = lm
self.s = TokenSequence(lm, prompt)
self.next_token_logprobs = log_softmax(lm.next_token_logprobs_unbatched(self.s.seq) / temp)
self.temp = temp
self.NO_MASK = set(range(len(self.lm.vocab)))
self.model_mask = self.NO_MASK
self.prompt_string_length = len(str(self.s))
self.show_prompt = False

def next_token(self):
"""Distribution over the next token.
Sampling or observing from this distribution advances the state of this `LMContext` instance."""
return LMNextToken(self)

def mask_dist(self, mask):
"""Bernoulli distribution, with probability of True equal to the probability that the next token of this `LMContext` belongs
to the given mask.
Sampling or observing from this distribution modifies the state of this `LMContext` instance, so that
the `next_token()` distribution either *will* (if True) or *will not* (if False) generate a token from
the given mask.
Args:
mask: a `set(int)` specifying which token ids are included within the mask."""
return LMTokenMask(self, mask)

def __str__(self):
base = 0 if self.show_prompt else self.prompt_string_length
return str(self.s)[base:]

def __deepcopy__(self, memo):
cpy = type(self).__new__(type(self))

for k, v in self.__dict__.items():
if k in set(['lm', 'NO_MASK']):
setattr(cpy, k, v)
else:
setattr(cpy, k, copy.deepcopy(v, memo))

return cpy