Skip to content

Commit

Permalink
add rootings and one-sided distances to all RF methods
Browse files Browse the repository at this point in the history
  • Loading branch information
willdumm committed Dec 13, 2023
1 parent c566556 commit 3707fea
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 128 deletions.
104 changes: 76 additions & 28 deletions historydag/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from collections import Counter, namedtuple
from copy import deepcopy
from historydag import utils
from historydag.utils import Weight, Label, UALabel, prod
from historydag.utils import Weight, Label, UALabel, prod, TaxaError
from historydag.counterops import counter_sum, counter_prod
import historydag.parsimony_utils as parsimony_utils
from historydag.dag_node import (
Expand Down Expand Up @@ -2049,8 +2049,10 @@ def count_nodes(self, collapse=False, rooted=True) -> Dict[HistoryDagNode, int]:
split2adjustment = {}
all_taxa = next(self.dagroot.children()).clade_union()
if any(all_taxa != n.clade_union() for n in self.dagroot.children()):
raise ValueError("Unrooted splits cannot be counted properly because"
" trees in this dag are on different sets of taxa.")
raise TaxaError(
"Unrooted splits cannot be counted properly because"
" trees in this dag are on different sets of taxa."
)
for treeroot in self.dagroot.children():
if len(treeroot.clades) == 2:
split = frozenset(treeroot.clades.keys())
Expand Down Expand Up @@ -2309,7 +2311,7 @@ def optimal_sum_rf_distance(
reference_dag,
rooted=rooted,
one_sided=one_sided,
one_sided_coefficients=one_sided_coefficients
one_sided_coefficients=one_sided_coefficients,
)
return self.optimal_weight_annotate(**kwargs, optimal_func=optimal_func)

Expand All @@ -2324,6 +2326,9 @@ def trim_optimal_sum_rf_distance(
"""Trims the DAG to contain only histories with the optimal (min or
max) sum rooted RF distance to the given reference DAG.
See :meth:`utils.sum_rfdistance_funcs` for detailed documentation of
arguments.
Trimming to the minimum sum RF distance is equivalent to finding 'median' topologies,
and trimming to the maximum sum RF distance is equivalent to finding topological outliers.
Expand All @@ -2337,7 +2342,7 @@ def trim_optimal_sum_rf_distance(
reference_dag,
rooted=rooted,
one_sided=one_sided,
one_sided_coefficients=one_sided_coefficients
one_sided_coefficients=one_sided_coefficients,
)
return self.trim_optimal_weight(**kwargs, optimal_func=optimal_func)

Expand All @@ -2352,6 +2357,9 @@ def trim_optimal_rf_distance(
"""Trims this history DAG to the optimal (min or max) RF distance to a
given history.
See :meth:`utils.make_rfdistance_countfuncs` for detailed documentation of
arguments.
Also returns that optimal RF distance.
The given history must be on the same taxa as all trees in the DAG.
Expand All @@ -2378,6 +2386,9 @@ def optimal_rf_distance(
):
"""Returns the optimal (min or max) RF distance to a given history.
See :meth:`utils.make_rfdistance_countfuncs` for detailed documentation of
arguments.
The given history must be on the same taxa as all trees in the DAG.
Since computing reference splits is expensive, it is better to use
:meth:`optimal_weight_annotate` and :meth:`utils.make_rfdistance_countfuncs`
Expand All @@ -2403,6 +2414,9 @@ def count_rf_distances(
The given history must be on the same taxa as all trees in the DAG.
See :meth:`utils.make_rfdistance_countfuncs` for detailed documentation of
arguments.
Since computing reference splits is expensive, it is better to use
:meth:`weight_count` and :meth:`utils.make_rfdistance_countfuncs`
instead of making multiple calls to this method with the same reference
Expand All @@ -2416,25 +2430,62 @@ def count_rf_distances(
)
return self.weight_count(**kwargs)

def count_sum_rf_distances(self, reference_dag: "HistoryDag", rooted: bool = False):
def count_sum_rf_distances(
self,
reference_dag: "HistoryDag",
rooted: bool = True,
one_sided: str = None,
one_sided_coefficients: Tuple[float, float] = (1, 1),
):
"""Returns a Counter containing all sum RF distances to a given
reference DAG.
See :meth:`utils.sum_rfdistance_funcs` for detailed documentation of
arguments.
The reference history DAG must be on the same taxa as all trees in this DAG.
Since computing reference splits is expensive, it is better to use
:meth:`weight_count` and :meth:`utils.sum_rfdistance_funcs`
instead of making multiple calls to this method with the same reference
history DAG.
"""
kwargs = utils.sum_rfdistance_funcs(reference_dag)
kwargs = utils.sum_rfdistance_funcs(
reference_dag,
rooted=rooted,
one_sided=one_sided,
one_sided_coefficients=one_sided_coefficients,
)
return self.weight_count(**kwargs)

def sum_rf_distances(self, reference_dag: "HistoryDag" = None):
r"""Computes the sum of all Robinson-Foulds distances between a history
in this DAG and a history in the reference DAG.
def sum_rf_distances(
self,
reference_dag: "HistoryDag" = None,
rooted: bool = True,
one_sided: str = None,
one_sided_coefficients: Tuple[float, float] = (1, 1),
):
r"""Computes the sum of Robinson-Foulds distances over all pairs of
histories in this DAG and the provided reference DAG.
Args:
reference_dag: If None, the sum of pairwise distances between histories in this DAG
is computed. If provided, the sum is over pairs containing one history in this DAG and
one from ``reference_dag``.
rooted: If False, use edges' splits for RF distance computation. Otherwise, use
the clade below each edge.
one_sided: May be 'left', 'right', or None. 'left' means that we count
splits (or clades, in the rooted case) which are in the reference trees but not
in the DAG tree, especially useful if trees in the DAG might be resolutions of
multifurcating trees in the reference DAG. 'right' means that we count splits or clades in
the DAG tree which are not in the reference trees, useful if the reference trees
are possibly resolutions of multifurcating trees in the DAG. If not None,
one_sided_coefficients are ignored.
one_sided_coefficients: coefficients for non-standard symmetric difference calculations.
See :meth:`utils.make_rfdistance_countfuncs` for more details.
This is rooted RF distance.
Returns:
An integer sum of RF distances.
If T is the set of histories in the reference DAG, and T' is the set of histories in
this DAG, then the returned sum is:
Expand All @@ -2446,22 +2497,16 @@ def sum_rf_distances(self, reference_dag: "HistoryDag" = None):
That is, since RF distance is symmetric, when T = T' (such as when ``reference_dag=None``),
or when the intersection of T and T' is nonempty, some distances are counted twice.
Args:
reference_dag: If None, the sum of pairwise distances between histories in this DAG
is computed. If provided, the sum is over pairs containing one history in this DAG and
one from ``reference_dag``.
Returns:
An integer sum of RF distances.
Note that when computing one-sided distances, or when the one_sided_coefficients values are not
equal, this 'distance' is no longer symmetric.
"""
s, t, _ = utils._process_rf_one_sided_coefficients(
one_sided, one_sided_coefficients
)

def get_data(dag):
n_histories = dag.count_histories()
N = dag.count_nodes(collapse=True)
try:
N.pop(frozenset())
except KeyError:
pass
N = dag.count_nodes(collapse=True, rooted=rooted)

clade_count_sum = sum(N.values())
return (n_histories, N, clade_count_sum)
Expand All @@ -2484,13 +2529,13 @@ def get_data(dag):
)

return (
n_histories * clade_count_sum_prime
+ n_histories_prime * clade_count_sum
- 2 * intersection_term
t * n_histories * clade_count_sum_prime
+ s * n_histories_prime * clade_count_sum
- (s + t) * intersection_term
)

def average_pairwise_rf_distance(
self, reference_dag: "HistoryDag" = None, non_identical=True
self, reference_dag: "HistoryDag" = None, non_identical=True, **kwargs
):
"""Return the average Robinson-Foulds distance between pairs of
histories.
Expand All @@ -2499,6 +2544,7 @@ def average_pairwise_rf_distance(
reference_dag: A history DAG from which to take the second history in
each pair. If None, ``self`` will be used as the reference.
non_identical: If True, mean divisor will be the number of non-identical pairs.
kwargs: See :meth:`HistoryDag.sum_rf_distances` for additional keyword arguments.
Returns:
The average rf-distance between pairs of histories, where the first history
Expand All @@ -2507,7 +2553,9 @@ def average_pairwise_rf_distance(
``non_identical`` is True, in which case the number of histories which appear
in both DAGs is subtracted from this constant.
"""
sum_pairwise_distance = self.sum_rf_distances(reference_dag=reference_dag)
sum_pairwise_distance = self.sum_rf_distances(
reference_dag=reference_dag, **kwargs
)
if reference_dag is None:
# ignore the diagonal in the distance matrix, since it contains
# zeros:
Expand Down
Loading

0 comments on commit 3707fea

Please sign in to comment.