diff --git a/historydag/dag.py b/historydag/dag.py index 9cbb29c..fab0e96 100644 --- a/historydag/dag.py +++ b/historydag/dag.py @@ -2500,7 +2500,7 @@ def sum_rf_distances( Note that when computing one-sided distances, or when the one_sided_coefficients values are not equal, this 'distance' is no longer symmetric. """ - s, t, _ = utils._process_rf_one_sided_coefficients( + s, t, _, _ = utils._process_rf_one_sided_coefficients( one_sided, one_sided_coefficients ) diff --git a/historydag/utils.py b/historydag/utils.py index 619c835..84b8aaa 100644 --- a/historydag/utils.py +++ b/historydag/utils.py @@ -547,11 +547,17 @@ def natural_edge_probability(parent, child): def _process_rf_one_sided_coefficients(one_sided, one_sided_coefficients): rf_type_suffix = "distance" - if one_sided_coefficients != (1, 1): - rf_type_suffix = "nonstandard" + RFType = IntState if one_sided is None: - pass + # Only then will one_sided_coefficients be considered + if one_sided_coefficients != (1, 1): + rf_type_suffix = "nonstandard" + # As long as both coefficients are integers, RF distances will + # be integers. Otherwise, we need to allow floats by using + # FloatState objects. + if not all(isinstance(it, int) for it in one_sided_coefficients): + RFType = FloatState elif one_sided.lower() == "left": one_sided_coefficients = (1, 0) rf_type_suffix = "left_difference" @@ -564,7 +570,7 @@ def _process_rf_one_sided_coefficients(one_sided, one_sided_coefficients): ) s, t = one_sided_coefficients - return s, t, rf_type_suffix + return s, t, rf_type_suffix, RFType def sum_rfdistance_funcs( @@ -616,7 +622,7 @@ def sum_rfdistance_funcs( Weights are represented by an IntState object and are shifted by a constant K, which is the sum of number of clades in each tree in the DAG. """ - s, t, rf_type_suffix = _process_rf_one_sided_coefficients( + s, t, rf_type_suffix, RFType = _process_rf_one_sided_coefficients( one_sided, one_sided_coefficients ) @@ -631,7 +637,7 @@ def sum_rfdistance_funcs( if rooted: def make_intstate(n): - return IntState(n + K, state=n) + return RFType(n + K, state=n) def edge_func(n1, n2): clade = n2.clade_union() @@ -668,7 +674,7 @@ def split(node): # added exactly once def make_intstate(tup): - return IntState(tup[0] + tup[1] + K, state=tup) + return RFType(tup[0] + tup[1] + K, state=tup) def summer(tupseq): tupseq = list(tupseq) @@ -758,7 +764,7 @@ def make_rfdistance_countfuncs( ``one_sided_coefficients`` ``(s, t)``, which affect how much weight is given to the right and left differences in the RF distance calculation: - ``|B ^ A| = s|B - A| + t|A - B| = s(|B| - |A n B|) + t|A - B|`` + ``d_{s,t}(A, B) = s|B - A| + t|A - B| = s(|B| - |A n B|) + t|A - B|`` When both ``s`` and ``t`` are 1, we get the standard RF distance. When ``s=1`` and ``t=0``, then we have a one-sided "left" RF difference, counting @@ -777,7 +783,7 @@ def make_rfdistance_countfuncs( of the IntState is computed as `a + sign(b) + |B|`, which on the UA node of the hDAG gives RF distance. """ - s, t, rf_type_suffix = _process_rf_one_sided_coefficients( + s, t, rf_type_suffix, RFType = _process_rf_one_sided_coefficients( one_sided, one_sided_coefficients ) @@ -815,7 +821,7 @@ def summer(tupseq): return (a, b) def make_intstate(tup): - return IntState(tup[0] + shift + sign(tup[1]), state=tup) + return RFType(tup[0] + shift + sign(tup[1]), state=tup) def edge_func(n1, n2): spl = split(n2) @@ -850,7 +856,7 @@ def edge_func(n1, n2): shift = s * len(ref_cus) def make_intstate(n): - return IntState(n + shift, state=n) + return RFType(n + shift, state=n) def edge_func(n1, n2): if n2.clade_union() in ref_cus: