Skip to content

Commit

Permalink
Merge pull request #155 from mattdahl/issue-147-fix-hashing
Browse files Browse the repository at this point in the history
Issue 147, 154 - Improve citation object hashing behavior
  • Loading branch information
flooie authored Sep 22, 2023
2 parents 0985138 + a519873 commit 65832f9
Show file tree
Hide file tree
Showing 4 changed files with 302 additions and 48 deletions.
152 changes: 122 additions & 30 deletions eyecite/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
from collections import UserString
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from datetime import datetime
from typing import (
Any,
Expand All @@ -15,7 +15,7 @@
cast,
)

from eyecite.utils import HashableDict
from eyecite.utils import hash_sha256

ResourceType = Hashable

Expand Down Expand Up @@ -60,7 +60,7 @@ def includes_year(
)


@dataclass(eq=True, unsafe_hash=True)
@dataclass(eq=False, unsafe_hash=False)
class CitationBase:
"""Base class for objects returned by `eyecite.find.get_citations`. We
define several subclasses of this class below, representing the various
Expand All @@ -79,7 +79,7 @@ class CitationBase:
def __post_init__(self):
"""Set up groups and metadata."""
# Allow groups to be used in comparisons:
self.groups = HashableDict(self.token.groups)
self.groups = self.token.groups
# Make metadata a self.Metadata object:
self.metadata = (
self.Metadata(**self.metadata)
Expand All @@ -101,21 +101,52 @@ def __repr__(self):
+ ")"
)

def __hash__(self) -> int:
"""In general, citations are considered equivalent if they have the
same group values (i.e., the same regex group content that is extracted
from the matched text). Subclasses may override this method in order to
specify equivalence behavior that is more appropriate for certain
kinds of citations (e.g., see CaseCitation override).
self.groups typically contains different keys for different objects:
FullLawCitation (non-exhaustive and non-guaranteed):
- chapter
- reporter
- law_section
- issue
- page
- docket_number
- pamphlet
- title
FullJournalCitation (non-exhaustive and non-guaranteed):
- volume
- reporter
- page
FullCaseCitation (see CaseCitation.__hash__() notes)
"""
return hash(
hash_sha256(
{**dict(self.groups.items()), **{"class": type(self).__name__}}
)
)

def __eq__(self, other):
"""This method is inherited by all subclasses and should not be
overridden. It implements object equality in exactly the same way as
defined in an object's __hash__() function, which should be overridden
instead if desired.
"""
return self.__hash__() == other.__hash__()

@dataclass(eq=True, unsafe_hash=True)
class Metadata:
"""Define fields on self.metadata."""

parenthetical: Optional[str] = None

def comparison_hash(self) -> int:
"""Return hash that will be the same if two cites are semantically
equivalent, unless the citation is a CaseCitation missing a page.
"""
if isinstance(self, CaseCitation) and self.groups["page"] is None:
return id(self)
else:
return hash((type(self), tuple(self.groups.items())))

def corrected_citation(self):
"""Return citation with any variations normalized."""
return self.matched_text()
Expand Down Expand Up @@ -170,7 +201,7 @@ def full_span(self) -> Tuple[int, int]:
return start, end


@dataclass(eq=True, unsafe_hash=True, repr=False)
@dataclass(eq=False, unsafe_hash=False, repr=False)
class ResourceCitation(CitationBase):
"""Base class for a case, law, or journal citation. Could be short or
long."""
Expand All @@ -194,18 +225,33 @@ def __post_init__(self):
)
super().__post_init__()

def __hash__(self) -> int:
"""ResourceCitation objects are hashed in the same way as their
parent class (CitationBase) objects, except that we also take into
consideration the all_editions field.
"""
return hash(
hash_sha256(
{
**dict(self.groups.items()),
**{
"all_editions": sorted(
[asdict(e) for e in self.all_editions],
key=lambda d: d["short_name"], # type: ignore
),
"class": type(self).__name__,
},
}
)
)

@dataclass(eq=True, unsafe_hash=True)
class Metadata(CitationBase.Metadata):
"""Define fields on self.metadata."""

pin_cite: Optional[str] = None
year: Optional[str] = None

def comparison_hash(self) -> int:
"""Return hash that will be the same if two cites are semantically
equivalent."""
return hash((super().comparison_hash(), self.all_editions))

def add_metadata(self, words: "Tokens"):
"""Extract metadata from text before and after citation."""
self.guess_edition()
Expand Down Expand Up @@ -248,13 +294,13 @@ def guess_edition(self):
self.edition_guess = editions[0]


@dataclass(eq=True, unsafe_hash=True, repr=False)
@dataclass(eq=False, unsafe_hash=False, repr=False)
class FullCitation(ResourceCitation):
"""Abstract base class indicating that a citation fully identifies a
resource."""


@dataclass(eq=True, unsafe_hash=True, repr=False)
@dataclass(eq=False, unsafe_hash=False, repr=False)
class FullLawCitation(FullCitation):
"""Citation to a source from `reporters_db/laws.json`."""

Expand Down Expand Up @@ -291,7 +337,7 @@ def corrected_citation_full(self):
return "".join(parts)


@dataclass(eq=True, unsafe_hash=True, repr=False)
@dataclass(eq=False, unsafe_hash=False, repr=False)
class FullJournalCitation(FullCitation):
"""Citation to a source from `reporters_db/journals.json`."""

Expand All @@ -317,12 +363,43 @@ def corrected_citation_full(self):
return "".join(parts)


@dataclass(eq=True, unsafe_hash=True, repr=False)
@dataclass(eq=False, unsafe_hash=False, repr=False)
class CaseCitation(ResourceCitation):
"""Convenience class which represents a single citation found in a
document.
"""

def __hash__(self) -> int:
"""CaseCitation objects that have the same volume, reporter, and page
are considered equivalent, unless the citation is missing a page, in
which case the object's hash will be unique for safety.
self.groups for CaseCitation objects usually contains these keys:
- page (guaranteed here: https://github.com/freelawproject/reporters-db/blob/main/tests.py#L129) # noqa: E501
- reporter (guaranteed here: https://github.com/freelawproject/reporters-db/blob/main/tests.py#L129) # noqa: E501
- volume (almost always present, but some tax court citations don't have volumes) # noqa: E501
- reporter_nominative (sometimes)
- volumes_nominative (sometimes)
"""
if self.groups["page"] is None:
return id(self)
else:
return hash(
hash_sha256(
{
**{
k: self.groups[k]
for k in ["volume", "page"]
if k in self.groups
},
**{
"reporter": self.corrected_reporter(),
"class": type(self).__name__,
},
}
)
)

@dataclass(eq=True, unsafe_hash=True)
class Metadata(FullCitation.Metadata):
"""Define fields on self.metadata."""
Expand All @@ -339,7 +416,7 @@ def guess_court(self):
self.metadata.court = "scotus"


@dataclass(eq=True, unsafe_hash=True, repr=False)
@dataclass(eq=False, unsafe_hash=False, repr=False)
class FullCaseCitation(CaseCitation, FullCitation):
"""Convenience class which represents a standard, fully named citation,
i.e., the kind of citation that marks the first time a document is cited.
Expand Down Expand Up @@ -389,7 +466,7 @@ def corrected_citation_full(self):
return "".join(parts)


@dataclass(eq=True, unsafe_hash=True, repr=False)
@dataclass(eq=False, unsafe_hash=False, repr=False)
class ShortCaseCitation(CaseCitation):
"""Convenience class which represents a short form citation, i.e., the kind
of citation made after a full citation has already appeared. This kind of
Expand Down Expand Up @@ -419,7 +496,7 @@ def corrected_citation_full(self):
return "".join(parts)


@dataclass(eq=True, unsafe_hash=True, repr=False)
@dataclass(eq=False, unsafe_hash=False, repr=False)
class SupraCitation(CitationBase):
"""Convenience class which represents a 'supra' citation, i.e., a citation
to something that is above in the document. Like a short form citation,
Expand Down Expand Up @@ -458,7 +535,7 @@ def formatted(self):
return "".join(parts)


@dataclass(eq=True, unsafe_hash=True, repr=False)
@dataclass(eq=False, unsafe_hash=False, repr=False)
class IdCitation(CitationBase):
"""Convenience class which represents an 'id' or 'ibid' citation, i.e., a
citation to the document referenced immediately prior. An 'id' citation is
Expand All @@ -469,6 +546,10 @@ class IdCitation(CitationBase):
Example: "... foo bar," id., at 240
"""

def __hash__(self) -> int:
"""IdCitation objects are always considered unique for safety."""
return id(self)

@dataclass(eq=True, unsafe_hash=True)
class Metadata(CitationBase.Metadata):
"""Define fields on self.metadata."""
Expand All @@ -483,14 +564,18 @@ def formatted(self):
return "".join(parts)


@dataclass(eq=True, unsafe_hash=True, repr=False)
@dataclass(eq=False, unsafe_hash=False, repr=False)
class UnknownCitation(CitationBase):
"""Convenience class which represents an unknown citation. A recognized
citation should theoretically be parsed as a CaseCitation, FullLawCitation,
or a FullJournalCitation. If it's something else, this class serves as
a naive catch-all.
"""

def __hash__(self) -> int:
"""UnknownCitation objects are always considered unique for safety."""
return id(self)


@dataclass(eq=True, unsafe_hash=True)
class Token(UserString):
Expand Down Expand Up @@ -636,13 +721,20 @@ class Resource(ResourceType):

def __hash__(self):
"""Resources are the same if their citations are semantically
equivalent.
equivalent, as defined by their hash function.
Note: Resources composed of citations with missing page numbers are
NOT considered the same, even if their other attributes are identical.
This is to avoid potential false positives.
"""
return self.citation.comparison_hash()
return hash(
hash_sha256(
{
"citation": hash(self.citation),
"class": type(self).__name__,
}
)
)

def __eq__(self, other):
return self.__hash__() == other.__hash__()
4 changes: 2 additions & 2 deletions eyecite/test_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def case_citation(


def law_citation(
source_text,
reporter,
source_text=None,
reporter="Mass. Gen. Laws",
**kwargs,
):
"""Convenience function for creating mock FullLawCitation objects."""
Expand Down
26 changes: 19 additions & 7 deletions eyecite/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import hashlib
import json
import re

from lxml import etree
Expand Down Expand Up @@ -72,13 +74,6 @@ def on_match(index, start, end, flags, context):
return matches


class HashableDict(dict):
"""Dict that works as an attribute of a hashable dataclass."""

def __hash__(self):
return hash(frozenset(self.items()))


def dump_citations(citations, text, context_chars=30):
"""Dump citations extracted from text, for debugging. Example:
>>> text = "blah. Foo v. Bar, 1 U.S. 1, 2 (1999). blah"
Expand Down Expand Up @@ -117,3 +112,20 @@ def dump_citations(citations, text, context_chars=30):
else:
out.append(f" * {key}={repr(value)}")
return "\n".join(out)


def hash_sha256(dictionary: dict) -> int:
"""Hash dictionaries in a deterministic way.
:param dictionary: The dictionary to hash
:return: An integer hash
"""

# Convert the dictionary to a JSON string
json_str: str = json.dumps(dictionary, sort_keys=True)

# Convert the JSON string to bytes
json_bytes: bytes = json_str.encode("utf-8")

# Calculate the hash of the bytes, convert to an int, and return
return int.from_bytes(hashlib.sha256(json_bytes).digest(), byteorder="big")
Loading

0 comments on commit 65832f9

Please sign in to comment.