diff --git a/historydag/dag.py b/historydag/dag.py index bb701b5..2ad19d9 100644 --- a/historydag/dag.py +++ b/historydag/dag.py @@ -101,11 +101,15 @@ def convert(dag, newclass): # Preorder tree creation class TreeBuilderNode = Any + + class PreorderTreeBuilder: - """Any class implementing a PreorderTreeBuilder interface can be used as a tree sample - constructor in :meth:`HistoryDag.fast_sample`. Subclasses implementing this interface may - implement an arbitrary constructor interface, as the user will be responsible for creating - instances to be used for sampling. In addition, subclasses must implement the following methods: + """Any class implementing a PreorderTreeBuilder interface can be used as a + tree sample constructor in :meth:`HistoryDag.fast_sample`. Subclasses + implementing this interface may implement an arbitrary constructor + interface, as the user will be responsible for creating instances to be + used for sampling. In addition, subclasses must implement the following + methods: Methods: add_node: This method must accept a :class:HistoryDagNode object ``dag_node`` and, optionally @@ -119,11 +123,11 @@ class PreorderTreeBuilder: sampled tree, after any necessary clean-up or final tree construction steps. Its return value is the return value of :meth:`HistoryDag.fast_sample`. """ + pass class EteTreeBuilder(PreorderTreeBuilder): - def __init__( self, name_func: Callable[[HistoryDagNode], str] = lambda n: "unnamed", @@ -138,6 +142,7 @@ def __init__( def feature_func(node): return getattr(node.label, feature) + self.feature_funcs[feature] = feature_func self.feature_funcs = tuple(self.feature_funcs.items()) @@ -164,6 +169,7 @@ def add_node( def get_finished_tree(self): return self.treeroot + class PreorderHistoryBuilder(PreorderTreeBuilder): def __init__( self, @@ -191,7 +197,6 @@ def get_finished_tree(self): parent.add_edge(child) return self.dag_type(self.root_node) - class HistoryDag: r"""An object to represent a collection of internally labeled trees. A @@ -400,9 +405,9 @@ def __getitem__(self, key) -> "HistoryDag": def get_histories_by_index(self, key_iterator, tree_builder_func=None): """Retrieving a history by index is slow, since each retrieval requires - running the ``trim_optimal_weight`` method on the entire DAG to populate - node counts. This method instead runs that method a single time and - yields a history for each index yielded by ``key_iterator``. + running the ``trim_optimal_weight`` method on the entire DAG to + populate node counts. This method instead runs that method a single + time and yields a history for each index yielded by ``key_iterator``. Args: key_iterator: An iterator on desired history indices. May be consumable, as @@ -410,7 +415,8 @@ def get_histories_by_index(self, key_iterator, tree_builder_func=None): tree_builder_func: A function accepting an index and returning a :class:`PreorderTreeBuilder` instance to be used to build the history with that index. If None (default), then tree-shaped HistoryDag objects - will be yielded using :class:`PreorderHistoryBuilder`.""" + will be yielded using :class:`PreorderHistoryBuilder`. + """ if tree_builder_func is None: def tree_builder_func(idx): @@ -422,12 +428,13 @@ def tree_builder_func(idx): if key < 0: key = length + key if not (key >= 0 and key < length): - raise IndexError(f"Invalid index {key} in DAG containing {length} histories") + raise IndexError( + f"Invalid index {key} in DAG containing {length} histories" + ) builder = tree_builder_func(key) self.dagroot._get_subhistory_by_subid(key, builder) yield builder.get_finished_tree() - def get_label_type(self) -> type: """Return the type for labels on this dag's nodes.""" return type(next(self.dagroot.children()).label) @@ -811,14 +818,13 @@ def find_node( def fast_sample( self, - tree_builder: PreorderTreeBuilder=None, + tree_builder: PreorderTreeBuilder = None, log_probabilities=False, ): - """This is a non-recursive alternative to :meth:`HistoryDag.sample`, which is likely - to be slower on small DAGs, but may allow significant optimizations on large DAGs, or - in the case that the data format being sampled is something other than a tree-shaped - HistoryDag object. - + """This is a non-recursive alternative to :meth:`HistoryDag.sample`, + which is likely to be slower on small DAGs, but may allow significant + optimizations on large DAGs, or in the case that the data format being + sampled is something other than a tree-shaped HistoryDag object. This method does not provide an edge_selector argument like :meth:`HistoryDag.sample`. Instead, any masking of edges should be done prior to sampling using the :meth:`HistoryDag.set_sample_mask` @@ -827,7 +833,8 @@ def fast_sample( Args: tree_builder: a PreorderTreeBuilder instance to handle construction of the sampled tree. log_probabilities: Whether edge probabilities annotated on this DAG (using, for example, - :meth:`HistoryDag.probability_annotate`) are on a log-scale.""" + :meth:`HistoryDag.probability_annotate`) are on a log-scale. + """ if tree_builder is None: tree_builder = PreorderHistoryBuilder(type(self)) @@ -846,7 +853,6 @@ def get_sampled_children(node): return tree_builder.get_finished_tree() - def sample( self, edge_selector=lambda e: True, log_probabilities=False ) -> "HistoryDag": @@ -1341,17 +1347,18 @@ def to_ete( """ # First build a dictionary of ete3 nodes keyed by HDagNodes. if features is None: - features = list( - list(self.dagroot.children())[0].label._asdict().keys() - ) + features = list(list(self.dagroot.children())[0].label._asdict().keys()) - - tree_builder = EteTreeBuilder(name_func=name_func, features=features, feature_funcs=feature_funcs) + tree_builder = EteTreeBuilder( + name_func=name_func, features=features, feature_funcs=feature_funcs + ) nodes_to_process = [(self.dagroot, tree_builder.add_node(self.dagroot))] while len(nodes_to_process) > 0: node, node_repr = nodes_to_process.pop() for child in node.children(): - nodes_to_process.append((child, tree_builder.add_node(child, parent=node_repr))) + nodes_to_process.append( + (child, tree_builder.add_node(child, parent=node_repr)) + ) return tree_builder.get_finished_tree() @@ -3110,8 +3117,9 @@ def edge_probabilities( return {key: aggregate_func(val) for key, val in edge_probabilities.items()} def set_sample_mask(self, edge_selector, log_probabilities=False): - """Zero out edge weights for masked edges before calling :meth:`HistoryDag.fast_sample`. - This should be equivalent to passing the same edge_selector function to :meth:`HistoryDag.sample`. + """Zero out edge weights for masked edges before calling + :meth:`HistoryDag.fast_sample`. This should be equivalent to passing + the same edge_selector function to :meth:`HistoryDag.sample`. Args: edge_selector: A function accepting an edge (a tuple of HistoryDagNode objects) and @@ -3121,7 +3129,8 @@ def set_sample_mask(self, edge_selector, log_probabilities=False): whether those probabilities are on a log scale. Take care to verify that you shouldn't instead use :meth:`HistoryDag.probability_annotate` with - a choice of ``edge_weight_func`` that takes into account the masking preferences.""" + a choice of ``edge_weight_func`` that takes into account the masking preferences. + """ if log_probabilities: mask_value = float("-inf") @@ -3137,7 +3146,6 @@ def set_sample_mask(self, edge_selector, log_probabilities=False): if not val: eset.probs[i] = mask_value - def probability_annotate( self, edge_weight_func, @@ -3636,7 +3644,6 @@ def traverse(node: HistoryDagNode): yield from gen - # DAG creation functions diff --git a/historydag/dag_node.py b/historydag/dag_node.py index 60a03ff..a74be45 100644 --- a/historydag/dag_node.py +++ b/historydag/dag_node.py @@ -13,11 +13,15 @@ Tuple, Dict, FrozenSet, + TYPE_CHECKING, ) from copy import deepcopy from historydag import utils from historydag.utils import Weight, Label, UALabel +if TYPE_CHECKING: + from historydag.dag import PreorderTreeBuilder + class HistoryDagNode: r"""A recursive representation of a history DAG object. @@ -183,15 +187,15 @@ def add_edge( prob_norm=prob_norm, ) - def _get_subhistory_by_subid(self, subid: int, tree_builder, parent_repr=None) -> "HistoryDagNode": - r"""Returns the subtree below the current HistoryDagNode corresponding - to the given index.""" + def _get_subhistory_by_subid( + self, subid: int, tree_builder: "PreorderTreeBuilder", parent_repr=None + ): + r"""Uses ``tree_builder`` to build the subtree below the current + HistoryDagNode corresponding to the given index.""" this_repr = tree_builder.add_node(self, parent=parent_repr) if self.is_leaf(): # base case - the node is a leaf return else: - history = self.empty_copy() - # get the subtree for each of the clades for clade, eset in self.clades.items(): # get the sum of subtrees of the edges for this clade @@ -206,7 +210,9 @@ def _get_subhistory_by_subid(self, subid: int, tree_builder, parent_repr=None) - curr_index = curr_index - child._dp_data else: # add this edge to the tree somehow - child._get_subhistory_by_subid(curr_index, tree_builder, parent_repr=this_repr) + child._get_subhistory_by_subid( + curr_index, tree_builder, parent_repr=this_repr + ) break subid = subid / num_subtrees @@ -525,7 +531,8 @@ def sample( ) -> Tuple[HistoryDagNode, float]: """Returns a randomly sampled child edge, and its corresponding weight. - When possible, only edges with nonzero mask value will be sampled. + When possible, only edges with nonzero mask value will be + sampled. """ if log_probabilities: weights = [exp(weight) for weight in self.probs] diff --git a/tests/test_factory.py b/tests/test_factory.py index e3c175c..28aaa0f 100644 --- a/tests/test_factory.py +++ b/tests/test_factory.py @@ -394,8 +394,10 @@ def test_fast_sample_with_node(): ] for node in least_supported_nodes: mask_true = dag.nodes_above_node(node) + def edge_selector(edge): return edge[-1] in mask_true + dag.make_uniform() dag.set_sample_mask(edge_selector) tree_samples = [dag.fast_sample() for _ in range(min_count * 7)] @@ -411,6 +413,7 @@ def edge_selector(edge): # print(norms) # assert all(is_close(norm, avg) for norm in norms) + def test_sample_with_node(): random.seed(1) dag = dags[-1]