Merge pull request #14 from probcomp/alexlew-viz-071524
First pass at integrating Jacob Hoover's and Maddy Bowers's visualization code
alex-lew authored Jul 17, 2024
2 parents 06b8ee7 + e7e8382 commit 4921bfe
Showing 11 changed files with 3,241 additions and 668 deletions.
15 changes: 15 additions & 0 deletions docs/visualization.md
@@ -0,0 +1,15 @@
# Visualization

We provide a Web interface for visualizing the execution of a sequential Monte Carlo algorithm,
based on contributions from Maddy Bowers and Jacob Hoover.

First, update your model to support visualization by implementing the [`string_for_serialization`](hfppl.modeling.Model.string_for_serialization) method,
which should return a string summarizing the particle's current state.
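
For example, a minimal sketch of an override (the `MyModel` class and its `context` field are illustrative, not part of the library):

```python
from hfppl.modeling import Model


class MyModel(Model):
    # ... __init__ and step() as usual ...

    def string_for_serialization(self):
        # Return a one-line summary of this particle's state. This sketch assumes
        # the model keeps an LMContext in `self.context`; replacing newlines
        # keeps the summary on a single line in the interface.
        return str(self.context).replace("\n", "/")
```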

To run the interface, change to the `html` directory and run `python -m http.server`. This serves the files
in the `html` directory at `http://localhost:8000`. (If you are working on a remote machine over SSH, you may need
port forwarding; Visual Studio Code handles this automatically for some ports, including 8000.)
Then, when calling [`smc_standard`](hfppl.inference.smc_standard), set `visualization_dir`
to the path of the `html` directory. A JSON record of the run will automatically be saved
to that directory, and a URL will be printed to the console (`http://localhost:8000/smc.html?path=$json_file`).
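
For instance, here is a minimal sketch of a run with visualization enabled (the model, prompt, particle count, and file names are placeholders, and the import path is assumed):

```python
import asyncio

from hfppl.inference import smc_standard

# Assumes `python -m http.server` is already running in the `html` directory.
particles = asyncio.run(
    smc_standard(
        MyModel(prompt),                  # any Model implementing string_for_serialization
        20,                               # number of particles
        0.5,                              # ESS threshold, as a fraction of n_particles
        visualization_dir="html",         # directory being served
        json_file="results/my_run.json",  # saved relative to visualization_dir
    )
)
# Prints a URL like http://localhost:8000/smc.html?path=results/my_run.json
```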

31 changes: 24 additions & 7 deletions examples/haiku.py
@@ -33,7 +33,7 @@ def count_syllables(word, unknown_word_syllables=100):
if "HF_AUTH_TOKEN" in os.environ:
HF_AUTH_TOKEN = os.environ["HF_AUTH_TOKEN"]
LLM = CachedCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", auth_token=HF_AUTH_TOKEN
"meta-llama/Meta-Llama-3-8B", auth_token=HF_AUTH_TOKEN
)
else:
LLM = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
@@ -83,7 +83,7 @@ def count_syllables(word, unknown_word_syllables=100):
LLM.cache_kv(LLM.tokenizer.encode(poem_prompt))

# Useful constants
NEWLINE_TOKEN, EOS_TOKEN = 13, LLM.tokenizer.eos_token_id
NEWLINE_TOKEN, EOS_TOKEN = LLM.vocab.index("\n"), LLM.tokenizer.eos_token_id


# LLaMPPL model
@@ -93,8 +93,11 @@ def __init__(self, prompt, syllable_pattern=[5, 7, 5]):
super().__init__()
self.context = LMContext(LLM, prompt, 0.7)
self.syllable_pattern = syllable_pattern
self.previous_string = str(self.context)

async def step(self):
self.previous_string = str(self.context)

# Get the number of syllables required in the next line
syllables_remaining = self.syllable_pattern.pop(0)

@@ -122,12 +125,26 @@ async def step(self):
# Print current result
print(str(self.context))

def string_for_serialization(self):
# Replace newlines with slashes in str(self.context)
s = (
self.previous_string
+ "<<<"
+ str(self.context)[len(self.previous_string) :]
+ ">>>"
)
return s.replace("\n", "/")


# Run inference
SYLLABLES_PER_LINE = [5, 7, 5] # [5, 3, 5] for a Lune
particles = asyncio.run(smc_standard(Haiku(poem_prompt, SYLLABLES_PER_LINE), 120))
particles = asyncio.run(
smc_standard(
Haiku(poem_prompt, SYLLABLES_PER_LINE), 20, 0.5, "html", "results/haiku.json"
)
)

print("--------")
for i, particle in enumerate(particles):
print(f"Poem {i} (weight {particle.weight}):")
print(f"{particle.context}")
# print("--------")
# for i, particle in enumerate(particles):
# print(f"Poem {i} (weight {particle.weight}):")
# print(f"{particle.context}")
23 changes: 16 additions & 7 deletions examples/hard_constraints.py
@@ -10,8 +10,10 @@
# Load the language model.
# Mistral and Vicuna are open models; to use a model with restricted access, like LLaMA 2,
# pass your HuggingFace API key as the optional `auth_token` argument:
# LLM = CachedCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", auth_token=HF_AUTH_TOKEN)
LLM = CachedCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5")
LLM = CachedCausalLM.from_pretrained(
"meta-llama/Meta-Llama-3-8B", auth_token=HF_AUTH_TOKEN
)
# LLM = CachedCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5")
# LLM = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
LLM.batch_size = 40

@@ -35,13 +37,11 @@ def __init__(self, prompt, max_tokens):
self.context = LMContext(LLM, prompt)
self.max_tokens = max_tokens

async def step(self):
# Which tokens are allowed?
async def start(self):
mask = self.active_constraint_mask()

# Condition on next token being from mask
await self.observe(self.context.mask_dist(mask), True)

async def step(self):
# Generate proposed token.
token = await self.sample(self.context.next_token())

@@ -55,12 +55,19 @@ async def step(self):
self.finish()
return

# Observe that next token follows the constraint.
mask = self.active_constraint_mask()
await self.observe(self.context.mask_dist(mask), True)

def active_constraint_mask(self):
string_so_far = str(self.context)
words = string_so_far.split()
last_word = words[-1] if len(words) > 0 else ""
return MASKS[min(5, len(last_word))]

def string_for_serialization(self):
return f"{self.context}"


# From Politico.com
prompt = """3 things to watch …
@@ -76,7 +83,9 @@ def active_constraint_mask(self):

async def main():
constraint_model = ConstraintModel(prompt, 50)
particles = await smc_standard(constraint_model, 40)
particles = await smc_standard(
constraint_model, 20, 0.5, "html", "results/output.json"
)
for p in particles:
print(f"{p.context}")

71 changes: 71 additions & 0 deletions hfppl/inference/smc_record.py
@@ -0,0 +1,71 @@
import json
import numpy as np


class SMCRecord:
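"""Accumulates a JSON-serializable history of an SMC run (initial particles, SMC steps, and resampling events) for the web visualization interface."""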

def __init__(self, n):
self.history = []
self.most_recent_weights = [0.0 for _ in range(n)]
self.step_num = 1

def prepare_string(self, s):
# If the string doesn't have <<< and >>>, prepend <<<>>> at the front.
if "<<<" not in s and ">>>" not in s:
return f"<<<>>>{s}"
return s

def particle_dict(self, particles):
return [
{
"contents": self.prepare_string(p.string_for_serialization()),
"logweight": (
"-Infinity" if p.weight == float("-inf") else str(float(p.weight))
),
"weight_incr": str(
float(p.weight) - float(self.most_recent_weights[i])
),
}
for (i, p) in enumerate(particles)
]

def add_init(self, particles):
self.history.append(
{
"step": self.step_num,
"mode": "init",
"particles": self.particle_dict(particles),
}
)
self.most_recent_weights = [p.weight for p in particles]

def add_smc_step(self, particles):
self.step_num += 1
self.history.append(
{
"step": self.step_num,
"mode": "smc_step",
"particles": self.particle_dict(particles),
}
)
self.most_recent_weights = [p.weight for p in particles]

def add_resample(self, ancestor_indices, particles):
self.step_num += 1
self.most_recent_weights = [
self.most_recent_weights[i] for i in ancestor_indices
]

self.history.append(
{
"mode": "resample",
"step": self.step_num,
"ancestors": [int(a) for a in ancestor_indices],
"particles": self.particle_dict(particles),
}
)

self.most_recent_weights = [p.weight for p in particles]

def to_json(self):
return json.dumps(self.history)
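
For reference, a minimal sketch of how this record is driven during inference, mirroring the calls made in `smc_standard` below (the `particles` and `ancestor_indices` values are illustrative):

```python
record = SMCRecord(n_particles)

# Record the initial set of particles and their log-weights.
record.add_init(particles)

# Record an ordinary SMC step: per-particle log-weights and weight increments.
record.add_smc_step(particles)

# Record a resampling event, with one ancestor index per surviving particle.
record.add_resample(ancestor_indices, particles)

# Serialize the accumulated history to JSON for the web interface.
json_string = record.to_json()
```
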
62 changes: 56 additions & 6 deletions hfppl/inference/smc_standard.py
@@ -2,29 +2,48 @@
from ..util import logsumexp
import numpy as np
import asyncio
from .smc_record import SMCRecord
from datetime import datetime


async def smc_standard(model, n_particles, ess_threshold=0.5):
async def smc_standard(
model, n_particles, ess_threshold=0.5, visualization_dir=None, json_file=None
):
"""
Standard sequential Monte Carlo algorithm with multinomial resampling.
Args:
model (hfppl.modeling.Model): The model to perform inference on.
n_particles (int): Number of particles to execute concurrently.
ess_threshold (float): Effective sample size below which resampling is triggered, given as a fraction of `n_particles`.
visualization_dir (str): Path to the directory where the visualization server is running.
json_file (str): Path to the JSON file to save the record of the inference, relative to `visualization_dir` if provided.
Returns:
particles (list[hfppl.modeling.Model]): The completed particles after inference.
"""
particles = [copy.deepcopy(model) for _ in range(n_particles)]
weights = [0.0 for _ in range(n_particles)]
await asyncio.gather(*[p.start() for p in particles])
record = visualization_dir is not None or json_file is not None
history = SMCRecord(n_particles) if record else None

ancestor_indices = list(range(n_particles))
did_resample = False
while any(map(lambda p: not p.done_stepping(), particles)):
# Step each particle
for p in particles:
p.untwist()
await asyncio.gather(*[p.step() for p in particles if not p.done_stepping()])

# Record history
if record:
if len(history.history) == 0:
history.add_init(particles)
elif did_resample:
history.add_resample(ancestor_indices, particles)
else:
history.add_smc_step(particles)

# Normalize weights
W = np.array([p.weight for p in particles])
w_sum = logsumexp(W)
@@ -36,14 +55,45 @@ async def smc_standard(model, n_particles, ess_threshold=0.5):
):
# Alternative implementation uses a multinomial distribution and only makes n-1 copies, reusing existing one, but fine for now
probs = np.exp(normalized_weights)
particles = [
copy.deepcopy(
particles[np.random.choice(range(len(particles)), p=probs)]
)
ancestor_indices = [
np.random.choice(range(len(particles)), p=probs)
for _ in range(n_particles)
]

if record:
# Sort the ancestor indices
ancestor_indices.sort()

particles = [copy.deepcopy(particles[i]) for i in ancestor_indices]
avg_weight = w_sum - np.log(n_particles)
for p in particles:
p.weight = avg_weight

did_resample = True
else:
did_resample = False

if record:
# Figure out path to save JSON.
if visualization_dir is None:
json_path = json_file
else:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
json_relative = (
json_file
if json_file is not None
else f"{model.__class__.__name__}-{timestamp}.json"
)
json_path = f"{visualization_dir}/{json_file}"

# Save JSON
with open(json_path, "w") as f:
f.write(history.to_json())

# Web path is the part of the path after the html directory
if visualization_dir is not None:
print(f"Visualize at http://localhost:8000/smc.html?path={json_relative}")
else:
print(f"Saved record to {json_path}")

return particles
4 changes: 1 addition & 3 deletions hfppl/inference/smc_steer.py
@@ -68,9 +68,7 @@ async def smc_steer(model, n_particles, n_beam):
"""
# Create n_particles copies of the model
particles = [copy.deepcopy(model) for _ in range(n_particles)]

for particle in particles:
particle.start() # TODO: allow to be async?
await asyncio.gather(*[p.start() for p in particles])

while any(map(lambda p: not p.done_stepping(), particles)):
# Count the number of finished particles
10 changes: 9 additions & 1 deletion hfppl/modeling.py
@@ -150,7 +150,7 @@ async def step(self):
def __str__(self):
return "Particle"

def start(self):
async def start(self):
pass

def score(self, score):
@@ -256,3 +256,11 @@ async def sample(self, dist, proposal=None):

async def call(self, submodel):
return await submodel.run_with_parent(self)

def string_for_serialization(self):
"""Return a string representation of the particle for serialization purposes.
Returns:
str: a string representation of the particle.
"""
return str(self)
1 change: 1 addition & 0 deletions html/results/output.json

Large diffs are not rendered by default.

