Skip to content

Commit

Permalink
Replaced iter_pairs() with itertools.product(). Added sentence thresholds. (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
wejradford committed Apr 24, 2014
1 parent 09b9b4b commit d22f305
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 41 deletions.
64 changes: 36 additions & 28 deletions gigacluster/comparators.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
#!/usr/bin/env python3

from collections import Counter
import itertools
import math
import re
import string
from stopwords import STOPWORDS
import sys

from idf import IDF
from model import dr, prev_next_sentences
from .idf import IDF
from .model import dr, prev_next_sentences
from .stopwords import STOPWORDS

class Match(object):
def __init__(self, a, b, score, info):
Expand Down Expand Up @@ -69,15 +71,6 @@ def sentence_id(doc, sentence):
def sentence_text(doc, sentence):
    """Return the sentence's raw tokens joined by spaces, with tabs flattened.

    Tabs are replaced so the text can be safely embedded in tab-separated
    match output.
    """
    raw_tokens = (token.raw for token in doc.tokens[sentence.span])
    return ' '.join(raw_tokens).replace('\t', ' ')

def iter_pairs(a_items, b_items, hook):
    """Yield (a, b) pairs where b's index is strictly less than a's.

    This avoids producing both (a, b) and (b, a) when the two sequences are
    the same collection. hook() is invoked on each a before its inner loop
    runs, and on each b just before the pair is yielded (so hook may be
    called on the same b repeatedly across outer iterations).
    """
    for limit, a in enumerate(a_items):
        hook(a)
        # Only the first `limit` items of b_items are paired with a.
        for b in itertools.islice(b_items, limit):
            hook(b)
            yield a, b

def iter_long(sentences, length):
for s in sentences:
if s.span.stop - s.span.start > length:
Expand All @@ -101,9 +94,21 @@ def __init__(self, threshold):
self.stats = Counter()

def __call__(self, docs_a, docs_b):
    """Prime features on both collections, then yield the matches produced
    by _handle() for every (a, b) document pair.

    Coarse progress (roughly every 10% of comparisons) is written to
    stderr. This is a generator: priming and progress output begin on the
    first next() call, not when __call__ is invoked.
    """
    for doc in docs_a:
        self.prime_features(doc)
    for doc in docs_b:
        self.prime_features(doc)
    req_comparisons = len(docs_a) * len(docs_b)
    # Clamp to 1 so `comparisons % step` below never divides by zero when
    # there are fewer than 10 comparisons (the original `// 10` gave 0).
    step = max(1, req_comparisons // 10)
    print('{} comparisons'.format(req_comparisons), file=sys.stderr, end='')
    comparisons = 0
    for a, b in itertools.product(docs_a, docs_b):
        if comparisons % step == 0:
            print(' ...{}'.format(comparisons), file=sys.stderr, end='')
        comparisons += 1
        for m in self._handle(a, b):
            yield m
    print('', file=sys.stderr)

def _handle(self, a, b):
    """Compare documents a and b; subclasses return an iterable of matches
    (see __call__, which iterates the result)."""
    raise NotImplementedError
Expand All @@ -112,9 +117,11 @@ class DocSentenceComparator(Comparator):
def __init__(self, threshold, sentence_threshold, idf_path):
    """Configure document/sentence thresholds and the IDF lookup.

    threshold: document-level score cutoff (handled by the base class).
    sentence_threshold: minimum sentence-pair score; lower-scoring pairs
        are skipped in _handle.
    idf_path: path handed to IDF() for inverse-document-frequency lookups.
    """
    super(DocSentenceComparator, self).__init__(threshold)
    self.idf = IDF(idf_path)
    self.sentence_threshold = sentence_threshold
    # Histogram of observed sentence scores, keyed by '%.3f'-formatted strings.
    self.sentence_stats = Counter()

def __str__(self):
    """Readable summary: class name, doc threshold (t), sentence threshold
    (st) and the IDF source."""
    # The span previously carried a stale pre-change return line (diff
    # residue) that omitted the sentence threshold; keep only the current form.
    return '<{} t={} st={} idf={}>'.format(
        self.__class__.__name__, self.threshold, self.sentence_threshold, self.idf)

def _handle(self, a, b):
matches = []
Expand All @@ -126,15 +133,14 @@ def _handle(self, a, b):
#features.sort(reverse=True)

# Check for sentence similarity.
# Perhaps some extra calls, but we don't know until this point whether we'll need to prime.
prime_sentence_features(a)
prime_sentence_features(b)
for i, j in iter_pairs(a.sentences, b.sentences, lambda i: i):
for i, j in itertools.product(a.sentences, b.sentences):
if not (i in a.sentence_features and j in b.sentence_features):
continue
a_f = a.sentence_features[i]
b_f = b.sentence_features[j]
card_intersection, card_union, card_a, card_b, sentence_score = overlap(set(a_f.keys()), set(b_f.keys()))
if not sentence_score or sentence_score < self.sentence_threshold:
continue
dot_dict = {k: a_f[k] * b_f[k] for k in set(a_f.keys()).intersection(set(b_f.keys()))}
matches.append(SentenceMatch(sentence_id(a, i), sentence_id(b, j),
score=score,
Expand All @@ -148,26 +154,28 @@ def _handle(self, a, b):
norm=norm(a_f) * norm(b_f),
info='{}\t{}'.format(sentence_text(a, i),
sentence_text(b, j))))
matches.sort(key=lambda m: m.sentence_score, reverse=True)
return matches[:5]
self.sentence_stats['{:.3f}'.format(sentence_score)] += 1
return matches

def prime_features(self, doc):
    """Attach document-level tf-idf unigram features to doc (only once),
    then prime its per-sentence features."""
    needs_doc_features = not hasattr(doc, 'features')
    if needs_doc_features:
        doc.features = sq_tfidf_unigrams(doc, self.idf)
    prime_sentence_features(doc)

@property
def deciles(self):
    """Pair of (top-decile doc-score sample, top-decile sentence-score sample)."""
    return self._get_decile(self.stats), self._get_decile(self.sentence_stats)

def _get_decile(self, stats):
    """Return a one-element list holding the sample one decile from the top
    of the sorted samples, or [] when stats is empty.

    Fixes two slicing bugs in the original `samples[-dec:-dec + 1]`:
    with 10-19 samples dec was 1 and slice(-1, 0) was always empty; with
    fewer than 10 samples dec was 0 and samples[0:1] returned the
    *smallest* sample rather than a top-decile one.
    """
    samples = list(self._iter_samples(stats))
    if not samples:
        return []
    dec = max(1, len(samples) // 10)
    return [samples[-dec]]

def _iter_samples(self, stats):
    """Expand a {sample: count} Counter into individual samples, yielding
    each key `count` times in sorted key order.

    The span previously carried the stale pre-rename `def iter_samples`
    line (diff residue); only the renamed method is kept.
    """
    for sample, count in sorted(stats.items()):
        yield from itertools.repeat(sample, count)


class SentenceBOWOverlap(Comparator):
def __init__(self, threshold, length, idf_path=None):
super(SentenceBOWOverlap, self).__init__(threshold)
Expand All @@ -184,7 +192,7 @@ def _handle(self, a, b):


matches = []
for s, t in iter_pairs(iter_long(a.sentences, self.length), iter_long(b.sentences, self.length), lambda i: i):
for s, t in itertools.product(iter_long(a.sentences, self.length), iter_long(b.sentences, self.length)):
if s in a.sentence_features and t in b.sentence_features:
score = self._score(a.sentence_features[s], b.sentence_features[t])
if score > self.threshold:
Expand All @@ -198,7 +206,7 @@ def prime_features(self, doc):
def _score(self, a_f, b_f):
    """Overlap score for two sentence feature sets, optionally IDF-weighted.

    When self.idf is set, the raw overlap score is scaled by the summed
    IDF weight of the features the two sets share. The span previously
    carried a stale line referencing undefined `a_features`/`b_features`
    (diff residue); only the corrected a_f/b_f form is kept.
    """
    card_intersection, card_union, card_a, card_b, sentence_score = overlap(a_f, b_f)
    if self.idf:
        shared = a_f.intersection(b_f)
        sentence_score *= sum(self.idf.get(t) for t in shared)
    return sentence_score

class SentenceBOWCosine(SentenceBOWOverlap):
Expand Down
2 changes: 1 addition & 1 deletion gigacluster/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import datetime
import itertools
import os
from model import dr, Doc
from .model import dr, Doc

class StreamError(Exception): pass

Expand Down
14 changes: 9 additions & 5 deletions gigacluster/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def seek(self, date=None):
self._fill(start, end)
return True
else:
self._drop()
return False

def _calculate_date_bounds(self, date):
Expand All @@ -59,7 +60,7 @@ def _calculate_date_bounds(self, date):
self._push_front((d, docs))
start = date = d
for i in range(self.before):
start =- DAY
start -= DAY
end = date
for i in range(self.after):
end += DAY
Expand All @@ -75,10 +76,7 @@ def _calculate_date_bounds(self, date):

def _fill(self, start, end):
""" Fills buckets from the stream between start and end. """
# Drop unused buckets.
for d in list(self.dates.keys()):
if d < start:
del self.dates[d]
self._drop(start)
# Fill from the stream.
while True:
d, docs = self._next_item()
Expand All @@ -93,6 +91,12 @@ def _fill(self, start, end):
self._push_front((d, docs)) # Save for later.
break

def _drop(self, start=None):
    """Drop date buckets older than start; drop every bucket when start is None."""
    # Collect keys first so we never delete while iterating the dict.
    doomed = [d for d in self.dates if start is None or d < start]
    for d in doomed:
        del self.dates[d]

def _next_item(self):
""" Return the next item. """
if self.extra:
Expand Down
17 changes: 10 additions & 7 deletions gigacluster/print_clusters.py → print_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import datetime
import sys

from stream import Stream
from window import Window
from comparators import *
from gigacluster.stream import Stream
from gigacluster.window import Window
from gigacluster.comparators import *

METRICS = {
'SentenceBOWOverlap': SentenceBOWOverlap,
Expand All @@ -19,7 +19,8 @@
parser.add_argument('-s', '--streams', default=[], action='append')
parser.add_argument('-S', '--stream-exp', help='RE to match stream filenames')
parser.add_argument('-m', '--metric', default='SentenceBOWOverlap', help='Metric, available={}'.format(METRICS.keys()))
parser.add_argument('-t', '--threshold', type=float, default=0.25)
parser.add_argument('-t', '--threshold', type=float, default=0.029)
parser.add_argument('-T', '--sentence-threshold', type=float, default=0.125)
parser.add_argument('-l', '--length', type=int, default=1)
parser.add_argument('-i', '--idf-path')
parser.add_argument('-e', '--end-date')
Expand All @@ -37,7 +38,7 @@
m = METRICS.get(args.metric)
if m is None:
parser.error('Require valid metric {}'.format(METRICS.keys()))
comparator = m(args.threshold, args.threshold, idf_path=args.idf_path)
comparator = m(args.threshold, args.sentence_threshold, idf_path=args.idf_path)
#comparator = m(args.threshold, length=args.length, idf_path=args.idf_path)

print(comparator, file=sys.stderr)
Expand All @@ -49,10 +50,12 @@
for w in secondaries:
w.seek(date)
print(' ', w, file=sys.stderr)
for match in comparator(docs, w.iter_docs()):
for match in comparator(docs, list(w.iter_docs())):
print('{}\t{}'.format(date, match))
print('Distribution: score={}\tsentence_score{}'.format(*comparator.deciles), file=sys.stderr)
sys.stdout.flush()
print('Distribution: {}\t{}'.format(*comparator.decile_quartile), file=sys.stderr)
sys.stderr.flush()

more = primary.seek()
if end_date and date == end_date:
break
14 changes: 14 additions & 0 deletions tests/test_comparator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from gigacluster.comparators import *
from gigacluster.model import Doc, Token

def test_overlap():
    # overlap() returns a tuple whose last element is the similarity score:
    # identical sets score 1, disjoint sets score 0, and 'ab' vs 'bc' gives
    # |intersection| / |union| = 1/3 (Jaccard-style).
    assert 1 == overlap(set('abc'), set('bca'))[-1]
    assert 0 == overlap(set('abc'), set('def'))[-1]
    assert 1/3 == overlap(set('ab'), set('bc'))[-1]

def test_unigram_tf():
    # Build a Doc from raw tokens and check term frequencies over the full
    # span. Only 'cat' and 'mat' survive — presumably unigram_tf filters
    # stopwords ('The', 'in', 'the') and punctuation ('.'); confirm against
    # its implementation.
    d = Doc()
    for i in 'The cat in the mat .'.split():
        d.tokens.append(Token(raw=i))
    tf = unigram_tf(d, slice(0, len(d.tokens) + 1))
    assert {'cat': 1, 'mat': 1} == dict(tf)
14 changes: 14 additions & 0 deletions tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
STREAM = [
(datetime.date(2014, 1, 1), ['20140101.a', '20140101.b']),
(datetime.date(2014, 1, 2), ['20140102.a', '20140102.b']),
(datetime.date(2014, 1, 3), ['20140103.a', '20140103.b']),
]

def test_window():
Expand All @@ -21,4 +22,17 @@ def test_window_gap():
w.seek()
print(list(w.iter_docs()))
assert sorted(STREAM[0][1] + STREAM[1][1]) == sorted(w.iter_docs())
assert w.seek()
assert sorted(STREAM[1][1] + STREAM[2][1]) == sorted(w.iter_docs())
assert not w.seek()

def test_empty_seek():
    # With before=1 the window retains one prior day, so seeking to
    # 2014-01-04 (one day past the stream's last date) still succeeds and
    # yields docs — presumably the 2014-01-03 bucket; verify against Window.
    w = Window(STREAM, before=1)
    assert w.seek(datetime.date(2014, 1, 4))
    print(list(w.iter_docs()))
    assert list(w.iter_docs())
    print(w.date)
    # The stream is exhausted: the next seek fails, and after a failed seek
    # the window holds no docs (buckets are dropped on failure).
    assert not w.seek()
    print(w.date)
    print(list(w.iter_docs()))
    assert not list(w.iter_docs())

0 comments on commit d22f305

Please sign in to comment.