Skip to content

Commit

Permalink
expand use of PreorderTreeBuilder class for sampling and indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
willdumm committed Jan 30, 2024
1 parent 4de85e8 commit 2224675
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 24 deletions.
127 changes: 109 additions & 18 deletions historydag/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,23 @@ def convert(dag, newclass):

TreeBuilderNode = Any
class PreorderTreeBuilder:
    """Interface for objects that assemble a sampled tree one node at a time.

    Any class implementing the PreorderTreeBuilder interface can be used as a
    tree sample constructor in :meth:`HistoryDag.fast_sample`. Subclasses may
    implement an arbitrary constructor interface, as the user is responsible
    for creating the instances to be used for sampling. In addition,
    subclasses must implement the following methods:

    Methods:
        add_node: This method must accept a :class:`HistoryDagNode` object
            ``dag_node`` and, optionally, a TreeBuilderNode instance
            ``parent`` representing the parent node of the node to be added.
            It returns a TreeBuilderNode instance representing the added node
            in the sampled tree. TreeBuilderNode can be any type which is
            convenient for the internal implementation of the subclass. This
            method can expect to be called on the nodes of a sampled tree in
            a pre-ordering. A parent node will always be provided unless
            ``dag_node`` is the root node.
        get_finished_tree: This method takes no arguments and returns the
            data defining the sampled tree, after any necessary clean-up or
            final tree construction steps. Its return value is the return
            value of :meth:`HistoryDag.fast_sample`.
    """

    # The stubs below make the documented contract explicit: a subclass that
    # forgets to override one of them fails with a clear NotImplementedError
    # instead of an AttributeError raised in the middle of sampling.

    def add_node(
        self,
        dag_node: "HistoryDagNode",
        parent: "TreeBuilderNode" = None,
    ) -> "TreeBuilderNode":
        """Add ``dag_node`` below ``parent`` and return its representation.

        Args:
            dag_node: The DAG node to add to the tree under construction.
            parent: The value previously returned by ``add_node`` for the
                parent of ``dag_node``, or None when ``dag_node`` is the root.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement add_node"
        )

    def get_finished_tree(self):
        """Return the data defining the finished tree.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_finished_tree"
        )


Expand Down Expand Up @@ -152,30 +169,27 @@ def __init__(
self,
dag_type,
):
self.nodes = []
self.root_node = None
self.edges = []
self.dag_type = dag_type

def add_node(
    self,
    dag_node: "HistoryDagNode",
    parent: "HistoryDagNode" = None,
) -> "HistoryDagNode":
    """Record an empty copy of ``dag_node`` in the tree being built.

    Nodes are identified by the copies themselves rather than by integer
    positions, so the value returned here is what callers pass back as
    ``parent`` when adding this node's children.

    Args:
        dag_node: The DAG node to copy (via ``empty_copy``) into the tree.
        parent: The copy previously returned for the parent node, or None
            when ``dag_node`` is the root of the tree.

    Returns:
        The empty copy of ``dag_node`` that now represents it in the tree.
    """
    new_node = dag_node.empty_copy()
    if parent is None:
        # Only one node may be added without a parent: the root.
        assert self.root_node is None
        self.root_node = new_node
    else:
        # Edges are only recorded here; they are attached to the copies
        # later, in get_finished_tree.
        self.edges.append((parent, new_node))
    return new_node

def get_finished_tree(self):
    """Attach all recorded edges and return the finished tree.

    The recorded (parent, child) pairs are visited in reverse order of
    insertion and each child is attached with ``parent.add_edge(child)``.

    Returns:
        ``self.dag_type`` called on the recorded root node.
    """
    for parent, child in reversed(self.edges):
        parent.add_edge(child)
    return self.dag_type(self.root_node)



Expand Down Expand Up @@ -375,13 +389,45 @@ def __getitem__(self, key) -> "HistoryDag":
key = length + key
if not (key >= 0 and key < length):
raise IndexError
return self.__class__(self.dagroot._get_subhistory_by_subid(key))
builder = PreorderHistoryBuilder(type(self))
self.dagroot._get_subhistory_by_subid(key, builder)
return builder.get_finished_tree()
else:
raise TypeError(
f"History DAG indices must be integers or utils.HistoryDagFilter"
f" objects, not {type(key)}"
)

def get_histories_by_index(self, key_iterator, tree_builder_func=None):
    """Yield the history at each index produced by ``key_iterator``.

    Retrieving histories one at a time by index is slow, since each
    retrieval requires a pass over the entire DAG to populate node counts.
    This method instead performs that pass (``count_histories``) a single
    time and yields a history for each index yielded by ``key_iterator``.

    Args:
        key_iterator: An iterator on desired history indices. May be
            consumable, as it will only be used once. Negative indices count
            from the end, as in ordinary sequence indexing.
        tree_builder_func: A function accepting an index and returning a
            :class:`PreorderTreeBuilder` instance to be used to build the
            history with that index. The index passed is the normalized
            (non-negative) one. If None (default), then tree-shaped
            HistoryDag objects will be yielded using
            :class:`PreorderHistoryBuilder`.

    Yields:
        The return value of ``get_finished_tree`` on the builder for each
        index, in the order the indices are provided.

    Raises:
        IndexError: When an index is out of range. Because this is a
            generator, the error is raised when the offending index is
            reached during iteration, not when the method is called.
    """
    if tree_builder_func is None:

        def tree_builder_func(idx):
            return PreorderHistoryBuilder(type(self))

    length = self.count_histories()

    for requested in key_iterator:
        key = length + requested if requested < 0 else requested
        if not 0 <= key < length:
            # Report the index the caller actually asked for, not the value
            # left over after shifting a negative index by ``length``.
            raise IndexError(
                f"Invalid index {requested} in DAG containing {length} histories"
            )
        builder = tree_builder_func(key)
        self.dagroot._get_subhistory_by_subid(key, builder)
        yield builder.get_finished_tree()


def get_label_type(self) -> type:
"""Return the type for labels on this dag's nodes."""
return type(next(self.dagroot.children()).label)
Expand Down Expand Up @@ -766,15 +812,28 @@ def find_node(
def fast_sample(
self,
tree_builder: PreorderTreeBuilder=None,
edge_selector=lambda e: True,
log_probabilities=False,
):
"""This is a non-recursive alternative to :meth:`HistoryDag.sample`, which is likely
to be slower on small DAGs, but may allow significant optimizations on large DAGs, or
in the case that the data format being sampled is something other than a tree-shaped
HistoryDag object.
This method does not provide an edge_selector argument like :meth:`HistoryDag.sample`.
Instead, any masking of edges should be done prior to sampling using the :meth:`HistoryDag.set_sample_mask`
method, or by modifying the arguments to :meth:`HistoryDag.probability_annotate`.
Args:
tree_builder: a PreorderTreeBuilder instance to handle construction of the sampled tree.
log_probabilities: Whether edge probabilities annotated on this DAG (using, for example,
:meth:`HistoryDag.probability_annotate`) are on a log-scale."""
if tree_builder is None:
tree_builder = PreorderHistoryBuilder(type(self))

def get_sampled_children(node):
for clade, eset in node.clades.items():
mask = [edge_selector((node, target)) for target in eset.targets]
sampled_target, _ = eset.sample(mask=mask, log_probabilities=log_probabilities)
sampled_target, _ = eset.sample(log_probabilities=log_probabilities)
yield sampled_target

node_queue = [(self.dagroot, tree_builder.add_node(self.dagroot))]
Expand All @@ -795,10 +854,13 @@ def sample(
DAG containing the root and all leaf nodes) For reproducibility, set
``random.seed`` before sampling.
When there is an option, edges pointing to nodes on which `selection_func` is True
When there is an option, edges pointing to nodes on which `edge_selector` is True
will always be chosen.
Returns a new HistoryDag object.
To use the more general sampling pattern which allows an arbitrary PreorderTreeBuilder
object, use :meth:`HistoryDag.fast_sample` instead.
"""
return self.__class__(
self.dagroot._sample(
Expand Down Expand Up @@ -1287,7 +1349,7 @@ def to_ete(
tree_builder = EteTreeBuilder(name_func=name_func, features=features, feature_funcs=feature_funcs)
nodes_to_process = [(self.dagroot, tree_builder.add_node(self.dagroot))]
while len(nodes_to_process) > 0:
node, node_repr = nodes_to_proces.pop()
node, node_repr = nodes_to_process.pop()
for child in node.children():
nodes_to_process.append((child, tree_builder.add_node(child, parent=node_repr)))

Expand Down Expand Up @@ -3047,6 +3109,35 @@ def edge_probabilities(

return {key: aggregate_func(val) for key, val in edge_probabilities.items()}

def set_sample_mask(self, edge_selector, log_probabilities=False):
    """Zero out edge weights for masked edges before calling :meth:`HistoryDag.fast_sample`.

    This is intended to be equivalent to passing the same ``edge_selector``
    function to :meth:`HistoryDag.sample`. The mask is applied in place by
    overwriting entries of each edge set's ``probs``.

    Args:
        edge_selector: A function accepting an edge (a tuple of
            HistoryDagNode objects) and returning True or False. An edge
            marked False will be ineligible for sampling, unless all other
            edges in the same edge set are also marked False.
        log_probabilities: Since the mask is applied by modifying edge
            probabilities, one must specify whether those probabilities are
            on a log scale.

    Take care to verify that you shouldn't instead use
    :meth:`HistoryDag.probability_annotate` with a choice of
    ``edge_weight_func`` that takes into account the masking preferences.
    """
    # A masked edge gets probability zero, which is -inf on a log scale.
    masked_prob = float("-inf") if log_probabilities else 0

    for parent in self.preorder():
        for edge_set in parent.clades.values():
            keep = tuple(
                edge_selector((parent, target)) for target in edge_set.targets
            )
            # When nothing in the edge set is selected, leave it untouched
            # so that sampling below this clade remains possible.
            if not any(keep):
                continue
            for position, selected in enumerate(keep):
                if not selected:
                    edge_set.probs[position] = masked_prob


def probability_annotate(
self,
edge_weight_func,
Expand Down
10 changes: 4 additions & 6 deletions historydag/dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,12 @@ def add_edge(
prob_norm=prob_norm,
)

def _get_subhistory_by_subid(self, subid: int) -> "HistoryDagNode":
def _get_subhistory_by_subid(self, subid: int, tree_builder, parent_repr=None) -> "HistoryDagNode":
r"""Returns the subtree below the current HistoryDagNode corresponding
to the given index."""
this_repr = tree_builder.add_node(self, parent=parent_repr)
if self.is_leaf(): # base case - the node is a leaf
return self
return
else:
history = self.empty_copy()

Expand All @@ -205,13 +206,10 @@ def _get_subhistory_by_subid(self, subid: int) -> "HistoryDagNode":
curr_index = curr_index - child._dp_data
else:
# add this edge to the tree somehow
history.clades[clade].add_to_edgeset(
child._get_subhistory_by_subid(curr_index)
)
child._get_subhistory_by_subid(curr_index, tree_builder, parent_repr=this_repr)
break

subid = subid / num_subtrees
return history

def remove_edge_by_clade_and_id(self, target: "HistoryDagNode", clade: frozenset):
key: frozenset
Expand Down
28 changes: 28 additions & 0 deletions tests/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,34 @@ def test_from_nodes():
assert wc == ndag.weight_count()


def test_fast_sample_with_node():
    """Check masked ``fast_sample`` against each least-supported node.

    For every node with the minimum count reported by ``count_nodes``, a
    sample mask selecting edges whose target is in ``nodes_above_node(node)``
    is installed, and repeated fast samples are expected to cover exactly
    ``min_count`` distinct trees, all of which contain the node.
    """
    random.seed(1)
    dag = dags[-1]
    dag.make_uniform()
    support = dag.count_nodes()
    fewest = min(support.values())
    rarest_nodes = [
        candidate for candidate, count in support.items() if count == fewest
    ]
    for node in rarest_nodes:
        allowed_targets = dag.nodes_above_node(node)

        def targets_allowed(edge):
            return edge[-1] in allowed_targets

        # Reset edge probabilities before applying this node's mask.
        dag.make_uniform()
        dag.set_sample_mask(targets_allowed)
        samples = []
        for _ in range(fewest * 7):
            samples.append(dag.fast_sample())
        samples[0]._check_valid()
        distinct_newicks = {tree.to_newick() for tree in samples}
        # We sampled all trees possible containing the node
        assert len(distinct_newicks) == fewest
        # All trees sampled contained the node
        assert all(node in set(tree.preorder()) for tree in samples)
        # NOTE: uniformity of the sampled trees is deliberately not asserted
        # here. Comparing normalize_counts(Counter(newicks)) against the
        # average with is_close is slow, though it seems to work.

def test_sample_with_node():
random.seed(1)
dag = dags[-1]
Expand Down
12 changes: 12 additions & 0 deletions tests/test_historydag.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,18 @@ def test_sample():
sample.to_graphviz(namedict=namedict)


def test_fast_sample():
    """Check that ``fast_sample`` returns valid, renderable histories.

    Builds a DAG from four small newick strings, asserts that ten
    independent fast samples are histories, then validates one further
    sample and renders it with ``to_graphviz``.
    """
    newicks = ["((1, 2)2, 3)3;", "((1, 2)3, 3)3;", "((1, 2)1, 3)3;", "((1, 2)4, 3)4;"]
    namedict = {(str(x),): x for x in range(5)}
    dag = history_dag_from_newicks(newicks, ["name"])
    for _ in range(10):
        assert dag.fast_sample().is_history()
    sample = dag.fast_sample()
    sample._check_valid()
    sample.to_graphviz(namedict=namedict)


def test_unifurcation():
# Make sure that unifurcations are handled correctly
# First make sure the call works when the problem is fixed:
Expand Down

0 comments on commit 2224675

Please sign in to comment.