Skip to content

Commit

Permalink
Merge pull request #217 from probcomp/092524-emilyaf-model7-gibbs
Browse files Browse the repository at this point in the history
Add transition_reference method to GenDB.
  • Loading branch information
emilyfertig authored Oct 2, 2024
2 parents fca1b5b + 93755f6 commit f02a443
Show file tree
Hide file tree
Showing 21 changed files with 1,064 additions and 297 deletions.
78 changes: 59 additions & 19 deletions cxx/clean_relation.hh
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ class CleanRelation : public Relation<T> {
}
T_items z = get_cluster_assignment(items);
if (!clusters.contains(z)) {
assert(prng != nullptr);
clusters[z] = make_new_distribution(prng);
}
return z;
Expand All @@ -116,29 +117,44 @@ class CleanRelation : public Relation<T> {
return value;
}

void unincorporate(const T_items& items) {
assert(data.contains(items));
ValueType value = data.at(items);
std::vector<int> z = get_cluster_assignment(items);
clusters.at(z)->unincorporate(value);
if (clusters.at(z)->N == 0) {
delete clusters.at(z);
clusters.erase(z);
}
// TODO: add a test.
void cleanup_data(const T_items& items) {
for (int i = 0; i < std::ssize(domains); ++i) {
const std::string& name = domains[i]->name;
if (data_r.at(name).contains(items[i])) {
data_r.at(name).at(items[i]).erase(items);
if (data_r.at(name).at(items[i]).size() == 0) {
// It's safe to unincorporate this element since no other data point
// refers to it.
data_r.at(name).erase(items[i]);
}
}
}
data.erase(items);
}

// TODO: add a test.
void cleanup_clusters() {
for (auto it = clusters.cbegin(); it != clusters.cend();) {
if (it->second->N == 0) {
delete it->second;
clusters.erase(it++);
} else {
++it;
}
}
}

void unincorporate(const T_items& items) {
assert(data.contains(items));
ValueType value = data.at(items);
std::vector<int> z = get_cluster_assignment(items);
clusters.at(z)->unincorporate(value);
if (clusters.at(z)->N == 0) {
delete clusters.at(z);
clusters.erase(z);
}
cleanup_data(items);
}

// incorporate_to_cluster and unincorporate_from_cluster should be used with
// extreme care, since they mutate the clusters only and not the relation. In
// particular, for every call to unincorporate_from_cluster, there must be a
Expand Down Expand Up @@ -172,24 +188,48 @@ class CleanRelation : public Relation<T> {
}

bool clusters_contains(const T_items& items) const {
std::vector<int> z = get_cluster_assignment(items);
assert(items.size() == domains.size());
std::vector<int> z;
z.reserve(items.size());
for (int i = 0; i < std::ssize(domains); ++i) {
if (!domains[i]->items.contains(items[i])) {
return false;
}
z.push_back(domains[i]->get_cluster_assignment(items[i]));
}
return clusters.contains(z);
}

double cluster_or_prior_logp(std::mt19937* prng, const T_items& items,
const ValueType& value) const {
if (clusters.contains(items)) {
return clusters.at(items)->logp(value);
}
double prior_logp(std::mt19937* prng, const ValueType& value) const {
assert(prng != nullptr);
Distribution<ValueType>* prior = make_new_distribution(prng);
double prior_logp = prior->logp(value);
delete prior;
return prior_logp;
}

double cluster_or_prior_logp_from_items(std::mt19937* prng,
const T_items& items,
const ValueType& value) const {
if (clusters_contains(items)) {
T_items z = get_cluster_assignment(items);
return clusters.at(z)->logp(value);
}
return prior_logp(prng, value);
}

double cluster_or_prior_logp(std::mt19937* prng, const std::vector<int>& z,
const ValueType& value) const {
if (clusters.contains(z)) {
return clusters.at(z)->logp(value);
}
return prior_logp(prng, value);
}

ValueType sample_at_items(std::mt19937* prng, const T_items& items) const {
if (clusters.contains(items)) {
return clusters.at(items)->sample(prng);
if (clusters_contains(items)) {
T_items z = get_cluster_assignment(items);
return clusters.at(z)->sample(prng);
}
Distribution<ValueType>* prior = make_new_distribution(prng);
ValueType prior_sample = prior->sample(prng);
Expand Down
6 changes: 5 additions & 1 deletion cxx/distributions/base.hh
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include<random>
#include <cassert>
#include <random>

template <typename T>
class Distribution {
Expand All @@ -18,6 +19,9 @@ class Distribution {
// have been previously passed to incorporate().
virtual void unincorporate(const T& x) {
incorporate(x, -1.0);
// TODO: Debug why this fails sometimes, e.g. for the bigram_string
// emission.
// assert(N >= 0);
}

// The log probability of x according to the posterior predictive
Expand Down
5 changes: 3 additions & 2 deletions cxx/distributions/bigram.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ std::vector<size_t> Bigram::string_to_indices(const std::string& str) const {

void Bigram::incorporate(const std::string& x, double weight) {
if ((max_length > 0) && (x.length() > max_length)) {
printf("String %s has length %ld, but max length is %ld.\n",
x.c_str(), x.length(), max_length);
printf("String %s has length %ld, but max length is %ld.\n", x.c_str(),
x.length(), max_length);
std::exit(1);
}
const std::vector<size_t> indices = string_to_indices(x);
Expand Down Expand Up @@ -95,6 +95,7 @@ std::string Bigram::sample(std::mt19937* prng) {
transition_dists[current_ind].incorporate(next_ind);
current_ind = next_ind;
}
N += 1; // Correct for calling unincorporate below.
unincorporate(sampled_string);
return sampled_string;
}
Expand Down
13 changes: 5 additions & 8 deletions cxx/distributions/bigram.hh
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,15 @@ class Bigram : public Distribution<std::string> {
std::vector<size_t> string_to_indices(const std::string& str) const;

public:
double alpha = 1; // hyperparameter for all transition distributions.
double alpha = 1; // hyperparameter for all transition distributions.
size_t max_length = 0; // 0 means no maximum length
char min_char; // Character with smallest ASCII value.
char max_char; // Character with largest ASCII value.
char min_char; // Character with smallest ASCII value.
char max_char; // Character with largest ASCII value.
size_t num_chars;
mutable std::vector<DirichletCategorical> transition_dists;

Bigram(size_t _max_length = 80,
char _min_char = ' ',
char _max_char = '~'):
max_length(_max_length), min_char(_min_char), max_char(_max_char)
{
Bigram(size_t _max_length = 80, char _min_char = ' ', char _max_char = '~')
: max_length(_max_length), min_char(_min_char), max_char(_max_char) {
num_chars = max_char - min_char + 1;
const size_t total_chars = num_chars + 1; // Include a start/stop symbol.

Expand Down
6 changes: 6 additions & 0 deletions cxx/distributions/bigram_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ BOOST_AUTO_TEST_CASE(test_simple) {

bg.incorporate("fractions", 1.23);
BOOST_TEST(bg.N == 2.23);

// Ensure that `sample` doesn't change N.
std::mt19937 prng;
double init_N = bg.N;
bg.sample(&prng);
BOOST_TEST(init_N == bg.N);
}

BOOST_AUTO_TEST_CASE(test_max_length) {
Expand Down
2 changes: 2 additions & 0 deletions cxx/distributions/crp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ int CRP::sample(std::mt19937* prng) {
return items[idx];
}

double CRP::logp_new_table() const { return log(alpha) - log(N + alpha); }

double CRP::logp(int table) const {
auto dist = tables_weights();
if (!dist.contains(table)) {
Expand Down
2 changes: 2 additions & 0 deletions cxx/distributions/crp.hh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class CRP {

int sample(std::mt19937* prng);

double logp_new_table() const;

double logp(int table) const;

double logp_score() const;
Expand Down
3 changes: 3 additions & 0 deletions cxx/domain.hh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ class Domain {
if (items.contains(item)) {
assert(table == -1);
} else {
if (table == -1) {
assert(prng != nullptr);
}
items.insert(item);
int t = 0 <= table ? table : crp.sample(prng);
crp.incorporate(item, t);
Expand Down
Loading

0 comments on commit f02a443

Please sign in to comment.