Skip to content

Commit

Permalink
format and lint
Browse files Browse the repository at this point in the history
  • Loading branch information
willdumm committed Jan 30, 2024
1 parent 2224675 commit bab151d
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 38 deletions.
69 changes: 38 additions & 31 deletions historydag/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,15 @@ def convert(dag, newclass):
# Preorder tree creation class

TreeBuilderNode = Any


class PreorderTreeBuilder:
"""Any class implementing a PreorderTreeBuilder interface can be used as a tree sample
constructor in :meth:`HistoryDag.fast_sample`. Subclasses implementing this interface may
implement an arbitrary constructor interface, as the user will be responsible for creating
instances to be used for sampling. In addition, subclasses must implement the following methods:
"""Any class implementing a PreorderTreeBuilder interface can be used as a
tree sample constructor in :meth:`HistoryDag.fast_sample`. Subclasses
implementing this interface may implement an arbitrary constructor
interface, as the user will be responsible for creating instances to be
used for sampling. In addition, subclasses must implement the following
methods:
Methods:
add_node: This method must accept a :class:`HistoryDagNode` object ``dag_node`` and, optionally
Expand All @@ -119,11 +123,11 @@ class PreorderTreeBuilder:
sampled tree, after any necessary clean-up or final tree construction steps. Its
return value is the return value of :meth:`HistoryDag.fast_sample`.
"""

pass


class EteTreeBuilder(PreorderTreeBuilder):

def __init__(
self,
name_func: Callable[[HistoryDagNode], str] = lambda n: "unnamed",
Expand All @@ -138,6 +142,7 @@ def __init__(

def feature_func(node):
return getattr(node.label, feature)

self.feature_funcs[feature] = feature_func

self.feature_funcs = tuple(self.feature_funcs.items())
Expand All @@ -164,6 +169,7 @@ def add_node(
def get_finished_tree(self):
return self.treeroot


class PreorderHistoryBuilder(PreorderTreeBuilder):
def __init__(
self,
Expand Down Expand Up @@ -191,7 +197,6 @@ def get_finished_tree(self):
parent.add_edge(child)
return self.dag_type(self.root_node)



class HistoryDag:
r"""An object to represent a collection of internally labeled trees. A
Expand Down Expand Up @@ -400,17 +405,18 @@ def __getitem__(self, key) -> "HistoryDag":

def get_histories_by_index(self, key_iterator, tree_builder_func=None):
"""Retrieving a history by index is slow, since each retrieval requires
running the ``trim_optimal_weight`` method on the entire DAG to populate
node counts. This method instead runs that method a single time and
yields a history for each index yielded by ``key_iterator``.
running the ``trim_optimal_weight`` method on the entire DAG to
populate node counts. This method instead runs that method a single
time and yields a history for each index yielded by ``key_iterator``.
Args:
key_iterator: An iterator over the desired history indices. It may be a single-use
iterator, since it is consumed only once.
tree_builder_func: A function accepting an index and returning a
:class:`PreorderTreeBuilder` instance to be used to build the history
with that index. If None (default), then tree-shaped HistoryDag objects
will be yielded using :class:`PreorderHistoryBuilder`."""
will be yielded using :class:`PreorderHistoryBuilder`.
"""
if tree_builder_func is None:

def tree_builder_func(idx):
Expand All @@ -422,12 +428,13 @@ def tree_builder_func(idx):
if key < 0:
key = length + key
if not (key >= 0 and key < length):
raise IndexError(f"Invalid index {key} in DAG containing {length} histories")
raise IndexError(
f"Invalid index {key} in DAG containing {length} histories"
)
builder = tree_builder_func(key)
self.dagroot._get_subhistory_by_subid(key, builder)
yield builder.get_finished_tree()


def get_label_type(self) -> type:
    """Return the type of the labels carried by this dag's nodes.

    The type is read from the label of the first node yielded by
    ``self.dagroot.children()``; no other node is inspected.
    """
    first_child = next(self.dagroot.children())
    label = first_child.label
    return type(label)
Expand Down Expand Up @@ -811,14 +818,13 @@ def find_node(

def fast_sample(
self,
tree_builder: PreorderTreeBuilder=None,
tree_builder: PreorderTreeBuilder = None,
log_probabilities=False,
):
"""This is a non-recursive alternative to :meth:`HistoryDag.sample`, which is likely
to be slower on small DAGs, but may allow significant optimizations on large DAGs, or
in the case that the data format being sampled is something other than a tree-shaped
HistoryDag object.
"""This is a non-recursive alternative to :meth:`HistoryDag.sample`,
which is likely to be slower on small DAGs, but may allow significant
optimizations on large DAGs, or in the case that the data format being
sampled is something other than a tree-shaped HistoryDag object.
This method does not provide an edge_selector argument like :meth:`HistoryDag.sample`.
Instead, any masking of edges should be done prior to sampling using the :meth:`HistoryDag.set_sample_mask`
Expand All @@ -827,7 +833,8 @@ def fast_sample(
Args:
tree_builder: a PreorderTreeBuilder instance to handle construction of the sampled tree.
log_probabilities: Whether edge probabilities annotated on this DAG (using, for example,
:meth:`HistoryDag.probability_annotate`) are on a log-scale."""
:meth:`HistoryDag.probability_annotate`) are on a log-scale.
"""
if tree_builder is None:
tree_builder = PreorderHistoryBuilder(type(self))

Expand All @@ -846,7 +853,6 @@ def get_sampled_children(node):

return tree_builder.get_finished_tree()


def sample(
self, edge_selector=lambda e: True, log_probabilities=False
) -> "HistoryDag":
Expand Down Expand Up @@ -1341,17 +1347,18 @@ def to_ete(
"""
# First build a dictionary of ete3 nodes keyed by HDagNodes.
if features is None:
features = list(
list(self.dagroot.children())[0].label._asdict().keys()
)
features = list(list(self.dagroot.children())[0].label._asdict().keys())


tree_builder = EteTreeBuilder(name_func=name_func, features=features, feature_funcs=feature_funcs)
tree_builder = EteTreeBuilder(
name_func=name_func, features=features, feature_funcs=feature_funcs
)
nodes_to_process = [(self.dagroot, tree_builder.add_node(self.dagroot))]
while len(nodes_to_process) > 0:
node, node_repr = nodes_to_process.pop()
for child in node.children():
nodes_to_process.append((child, tree_builder.add_node(child, parent=node_repr)))
nodes_to_process.append(
(child, tree_builder.add_node(child, parent=node_repr))
)

return tree_builder.get_finished_tree()

Expand Down Expand Up @@ -3110,8 +3117,9 @@ def edge_probabilities(
return {key: aggregate_func(val) for key, val in edge_probabilities.items()}

def set_sample_mask(self, edge_selector, log_probabilities=False):
"""Zero out edge weights for masked edges before calling :meth:`HistoryDag.fast_sample`.
This should be equivalent to passing the same edge_selector function to :meth:`HistoryDag.sample`.
"""Zero out edge weights for masked edges before calling
:meth:`HistoryDag.fast_sample`. This should be equivalent to passing
the same edge_selector function to :meth:`HistoryDag.sample`.
Args:
edge_selector: A function accepting an edge (a tuple of HistoryDagNode objects) and
Expand All @@ -3121,7 +3129,8 @@ def set_sample_mask(self, edge_selector, log_probabilities=False):
whether those probabilities are on a log scale.
Take care to verify that you shouldn't instead use :meth:`HistoryDag.probability_annotate` with
a choice of ``edge_weight_func`` that takes into account the masking preferences."""
a choice of ``edge_weight_func`` that takes into account the masking preferences.
"""

if log_probabilities:
mask_value = float("-inf")
Expand All @@ -3137,7 +3146,6 @@ def set_sample_mask(self, edge_selector, log_probabilities=False):
if not val:
eset.probs[i] = mask_value


def probability_annotate(
self,
edge_weight_func,
Expand Down Expand Up @@ -3636,7 +3644,6 @@ def traverse(node: HistoryDagNode):
yield from gen



# DAG creation functions


Expand Down
21 changes: 14 additions & 7 deletions historydag/dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@
Tuple,
Dict,
FrozenSet,
TYPE_CHECKING,
)
from copy import deepcopy
from historydag import utils
from historydag.utils import Weight, Label, UALabel

if TYPE_CHECKING:
from historydag.dag import PreorderTreeBuilder


class HistoryDagNode:
r"""A recursive representation of a history DAG object.
Expand Down Expand Up @@ -183,15 +187,15 @@ def add_edge(
prob_norm=prob_norm,
)

def _get_subhistory_by_subid(self, subid: int, tree_builder, parent_repr=None) -> "HistoryDagNode":
r"""Returns the subtree below the current HistoryDagNode corresponding
to the given index."""
def _get_subhistory_by_subid(
self, subid: int, tree_builder: "PreorderTreeBuilder", parent_repr=None
):
r"""Uses ``tree_builder`` to build the subtree below the current
HistoryDagNode corresponding to the given index."""
this_repr = tree_builder.add_node(self, parent=parent_repr)
if self.is_leaf(): # base case - the node is a leaf
return
else:
history = self.empty_copy()

# get the subtree for each of the clades
for clade, eset in self.clades.items():
# get the sum of subtrees of the edges for this clade
Expand All @@ -206,7 +210,9 @@ def _get_subhistory_by_subid(self, subid: int, tree_builder, parent_repr=None) -
curr_index = curr_index - child._dp_data
else:
# add this edge to the tree somehow
child._get_subhistory_by_subid(curr_index, tree_builder, parent_repr=this_repr)
child._get_subhistory_by_subid(
curr_index, tree_builder, parent_repr=this_repr
)
break

subid = subid / num_subtrees
Expand Down Expand Up @@ -525,7 +531,8 @@ def sample(
) -> Tuple[HistoryDagNode, float]:
"""Returns a randomly sampled child edge, and its corresponding weight.
When possible, only edges with nonzero mask value will be sampled.
When possible, only edges with nonzero mask value will be
sampled.
"""
if log_probabilities:
weights = [exp(weight) for weight in self.probs]
Expand Down
3 changes: 3 additions & 0 deletions tests/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,10 @@ def test_fast_sample_with_node():
]
for node in least_supported_nodes:
mask_true = dag.nodes_above_node(node)

def edge_selector(edge):
return edge[-1] in mask_true

dag.make_uniform()
dag.set_sample_mask(edge_selector)
tree_samples = [dag.fast_sample() for _ in range(min_count * 7)]
Expand All @@ -411,6 +413,7 @@ def edge_selector(edge):
# print(norms)
# assert all(is_close(norm, avg) for norm in norms)


def test_sample_with_node():
random.seed(1)
dag = dags[-1]
Expand Down

0 comments on commit bab151d

Please sign in to comment.