Skip to content

Commit

Permalink
implemented kl divergence and simple tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex Markham committed Jan 18, 2024
1 parent 919dc49 commit 8b1b6b2
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 8 deletions.
24 changes: 16 additions & 8 deletions src/cstrees/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Evaluate estimated CStrees."""
from itertools import product, pairwise, tee
from functools import reduce
import operator

from scipy.special import rel_entr
Expand All @@ -14,23 +15,30 @@ def kl_divergence(estimated: CStree, true: CStree) -> float:
See the `KL divergence
<https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`_.
"""
factorized_outcomes = (range(card) for card in enumerate(true.cards))
factorized_outcomes = (range(card) for card in true.cards)
outcomes = product(*factorized_outcomes)

def _rel_entr_of_outcome(outcome):
nodes = (outcome[:idx] for idx in range(true.p + 1))
edges = pairwise(nodes)

def _prob_map(edge):
def _probs_map(edge):
est = estimated.tree[edge[0]][edge[1]]["cond_prob"]
tru = true.tree[edge[0]][edge[1]]["cond_prob"]
return est, tru

zipped_probs = map(_prob_map, edges)
estimated_probs, true_probs = tee(zipped_probs)
zipped_probs = map(_probs_map, edges)

estimated_prob_outcome = reduce(operator.mul, estimated_probs)
true_prob_outcome = reduce(operator.mul, true_probs)
return rel_entr(estimated_prob_outcome, true_prob_outcome)
def _probs_of_outcome(prev_pair, next_pair):
return prev_pair[0] * next_pair[0], prev_pair[1] * next_pair[1]

return sum(map(_rel_entr_of_outcome, outcomes)
est_prob_outcome, true_prob_outcome = reduce(_probs_of_outcome, zipped_probs)
return rel_entr(est_prob_outcome, true_prob_outcome)

return sum(map(_rel_entr_of_outcome, outcomes))


# because CStrees are created on the fly while sampling, computing KL
# divergence (or making prediction) may produce key error; can catch
# these errors and set prob of corresponding outcome to 0? or sample
# more? or some other way to generate full tree?
26 changes: 26 additions & 0 deletions src/cstrees/tests/test_evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import random

import numpy as np

from cstrees import cstree as ct
from cstrees.evaluate import kl_divergence


def test_kl_divergence():
    """Smoke-test KL divergence: zero for a tree against itself, positive otherwise."""
    # Seed both RNGs so the sampled trees are reproducible.
    np.random.seed(22)
    random.seed(22)

    cardinalities = [3, 2, 2, 3]

    def _sampled_tree():
        # Draw a random CStree, then its stage parameters and observations.
        tree = ct.sample_cstree(
            cardinalities, max_cvars=2, prob_cvar=0.5, prop_nonsingleton=1
        )
        tree.sample_stage_parameters(alpha=2)
        tree.sample(1000)
        return tree

    true_tree = _sampled_tree()
    # KL(P || P) must be exactly zero.
    assert kl_divergence(true_tree, true_tree) == 0

    estimated_tree = _sampled_tree()
    # Two independently sampled trees should have strictly positive divergence.
    assert kl_divergence(estimated_tree, true_tree) > 0

0 comments on commit 8b1b6b2

Please sign in to comment.