Skip to content

Commit

Permalink
IntState vs FloatState fix for RF distances (#81)
Browse files Browse the repository at this point in the history
* use FloatState for RF distances when necessary

* format and lint
  • Loading branch information
willdumm authored Dec 13, 2023
1 parent 262eae4 commit 26fcb93
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
2 changes: 1 addition & 1 deletion historydag/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2500,7 +2500,7 @@ def sum_rf_distances(
Note that when computing one-sided distances, or when the one_sided_coefficients values are not
equal, this 'distance' is no longer symmetric.
"""
s, t, _ = utils._process_rf_one_sided_coefficients(
s, t, _, _ = utils._process_rf_one_sided_coefficients(
one_sided, one_sided_coefficients
)

Expand Down
28 changes: 17 additions & 11 deletions historydag/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,11 +547,17 @@ def natural_edge_probability(parent, child):

def _process_rf_one_sided_coefficients(one_sided, one_sided_coefficients):
rf_type_suffix = "distance"
if one_sided_coefficients != (1, 1):
rf_type_suffix = "nonstandard"
RFType = IntState

if one_sided is None:
pass
# Only then will one_sided_coefficients be considered
if one_sided_coefficients != (1, 1):
rf_type_suffix = "nonstandard"
# As long as both coefficients are integers, RF distances will
# be integers. Otherwise, we need to allow floats by using
# FloatState objects.
if not all(isinstance(it, int) for it in one_sided_coefficients):
RFType = FloatState
elif one_sided.lower() == "left":
one_sided_coefficients = (1, 0)
rf_type_suffix = "left_difference"
Expand All @@ -564,7 +570,7 @@ def _process_rf_one_sided_coefficients(one_sided, one_sided_coefficients):
)

s, t = one_sided_coefficients
return s, t, rf_type_suffix
return s, t, rf_type_suffix, RFType


def sum_rfdistance_funcs(
Expand Down Expand Up @@ -616,7 +622,7 @@ def sum_rfdistance_funcs(
Weights are represented by an IntState object and are shifted by a constant K,
which is the sum of number of clades in each tree in the DAG.
"""
s, t, rf_type_suffix = _process_rf_one_sided_coefficients(
s, t, rf_type_suffix, RFType = _process_rf_one_sided_coefficients(
one_sided, one_sided_coefficients
)

Expand All @@ -631,7 +637,7 @@ def sum_rfdistance_funcs(
if rooted:

def make_intstate(n):
return IntState(n + K, state=n)
return RFType(n + K, state=n)

def edge_func(n1, n2):
clade = n2.clade_union()
Expand Down Expand Up @@ -668,7 +674,7 @@ def split(node):
# added exactly once

def make_intstate(tup):
return IntState(tup[0] + tup[1] + K, state=tup)
return RFType(tup[0] + tup[1] + K, state=tup)

def summer(tupseq):
tupseq = list(tupseq)
Expand Down Expand Up @@ -758,7 +764,7 @@ def make_rfdistance_countfuncs(
``one_sided_coefficients`` ``(s, t)``, which affect how much weight is given to
the right and left differences in the RF distance calculation:
``|B ^ A| = s|B - A| + t|A - B| = s(|B| - |A n B|) + t|A - B|``
``d_{s,t}(A, B) = s|B - A| + t|A - B| = s(|B| - |A n B|) + t|A - B|``
When both ``s`` and ``t`` are 1, we get the standard RF distance.
When ``s=1`` and ``t=0``, then we have a one-sided "left" RF difference, counting
Expand All @@ -777,7 +783,7 @@ def make_rfdistance_countfuncs(
of the IntState is computed as `a + sign(b) + |B|`, which on the UA node of the hDAG gives RF distance.
"""

s, t, rf_type_suffix = _process_rf_one_sided_coefficients(
s, t, rf_type_suffix, RFType = _process_rf_one_sided_coefficients(
one_sided, one_sided_coefficients
)

Expand Down Expand Up @@ -815,7 +821,7 @@ def summer(tupseq):
return (a, b)

def make_intstate(tup):
return IntState(tup[0] + shift + sign(tup[1]), state=tup)
return RFType(tup[0] + shift + sign(tup[1]), state=tup)

def edge_func(n1, n2):
spl = split(n2)
Expand Down Expand Up @@ -850,7 +856,7 @@ def edge_func(n1, n2):
shift = s * len(ref_cus)

def make_intstate(n):
return IntState(n + shift, state=n)
return RFType(n + shift, state=n)

def edge_func(n1, n2):
if n2.clade_union() in ref_cus:
Expand Down

0 comments on commit 26fcb93

Please sign in to comment.