Merge branch 'main' of https://github.com/probcomp/hfppl

postylem committed Nov 7, 2023
2 parents c7b0974 + 003cde8 commit c32623f
Showing 1 changed file with 8 additions and 8 deletions: docs/getting_started.md
```
cd hfppl
pip install .
```

You can then run an example. The first time you run it, the example may prompt you to download model weights from the HuggingFace model repository.

```
python examples/hard_constraints.py
```

```python
class MyModel(Model):

    # The __init__ method is used to process arguments
    # and initialize instance variables.
    def __init__(self, lm, prompt, forbidden_letter):

        # Always call the superclass's __init__.
        super().__init__()

        # A stateful context object for the LLM, initialized with the prompt
        self.context = LMContext(lm, prompt)

        # Token ids of all vocabulary entries containing the forbidden letter
        self.forbidden_tokens = [i for (i, v) in enumerate(lm.vocab)
                                 if forbidden_letter in v]

    # The step method is used to perform a single 'step' of generation.
    # This might be a single token, a single phrase, or any other division.
    # Here, we generate one token at a time.
    async def step(self):
        # Sample a token from the LLM -- automatically extends `self.context`.
        # We use `await` so that LLaMPPL can automatically batch language model calls.
        token = await self.sample(self.context.next_token(),
                                  proposal=self.proposal())

        # Condition on the token not having the forbidden letter
        self.condition(token.token_id not in self.forbidden_tokens)

        # Check for EOS or end of sentence
        if token.token_id == self.context.lm.tokenizer.eos_token_id or str(token) in ['.', '!', '?']:
            # Finish generation
            self.finish()

    # Helper method to define a custom proposal
    def proposal(self):
        logits = self.context.next_token_logprobs.copy()
        # ... (remainder of this method is elided in the diff)
```
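The core idea behind the custom proposal — copy the model's next-token log-probabilities, rule out forbidden tokens, and renormalize what remains — can be sketched with the standard library alone. This is an illustrative sketch, not hfppl's implementation; the function name `masked_proposal` and the three-token toy vocabulary are ours.

```python
import math

def masked_proposal(logprobs, forbidden):
    """Mask forbidden token ids (set to -inf) and renormalize in log space."""
    masked = [lp if i not in forbidden else -math.inf
              for i, lp in enumerate(logprobs)]
    # Total probability mass of the surviving tokens
    total = math.log(sum(math.exp(lp) for lp in masked if lp > -math.inf))
    return [lp - total for lp in masked]

# Toy distribution over a 3-token vocabulary, with token 1 forbidden
logprobs = [math.log(0.5), math.log(0.3), math.log(0.2)]
q = masked_proposal(logprobs, forbidden={1})
probs = [math.exp(lp) for lp in q]
# The 0.5/0.2 mass is rescaled to sum to 1; token 1 gets probability 0
```

Conditioning alone would reject samples containing the forbidden letter; proposing from this masked distribution instead avoids wasting samples on tokens that would be rejected.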
```python
model = MyModel(lm, "The weather today is expected to be", "e")
particles = asyncio.run(smc_steer(model, 5, 3))  # number of particles N, and beam factor K
```
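`smc_steer` runs sequential Monte Carlo inference over the particles. The resampling step at the heart of SMC — duplicating high-weight particles and discarding low-weight ones — can be illustrated in plain Python. This is a conceptual sketch, not hfppl's implementation; the function name `resample` is ours.

```python
import random

def resample(particles, weights, n):
    """Multinomial resampling: draw n particles in proportion to their weights."""
    total = sum(weights)
    return random.choices(particles, weights=[w / total for w in weights], k=n)

random.seed(0)
survivors = resample(["a", "b", "c"], [0.1, 0.8, 0.1], n=100)
# With weight 0.8, particle "b" should dominate the resampled population
```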

Each returned particle is an instance of the `MyModel` class that has been `step`-ped to completion.
The generated strings can be printed along with the particle weights:

```python
# Sketch: assumes each particle exposes the `context` attribute set in
# __init__ above, and a `weight` attribute maintained by the inference engine.
for particle in particles:
    print(f"{particle.context} (weight: {particle.weight})")
```
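SMC implementations typically store particle weights in log space for numerical stability. Assuming hfppl's weights are log-weights (an assumption here, not stated in this excerpt), they can be turned into normalized probabilities with the log-sum-exp trick:

```python
import math

def normalize_log_weights(log_ws):
    """Stable softmax over log-weights: subtract the max before exponentiating."""
    m = max(log_ws)
    exps = [math.exp(w - m) for w in log_ws]
    total = sum(exps)
    return [e / total for e in exps]

norm = normalize_log_weights([-1.0, -2.0, -3.0])
# Higher log-weight -> larger share of the normalized probability mass
```

Subtracting the maximum first prevents overflow/underflow when log-weights are large in magnitude, which naively exponentiating would cause.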
