Skip to content

Commit

Permalink
MAINT: work to speed up AssignmentSet
Browse files Browse the repository at this point in the history
  • Loading branch information
johntyree committed Jan 31, 2016
1 parent 1ff1d6f commit 11848c9
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 51 deletions.
101 changes: 61 additions & 40 deletions simplesat/sat/assignment_set.py
Original file line number Diff line number Diff line change
@@ -1,77 +1,89 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from collections import OrderedDict

import six


class _MISSING(object):
def __str__(self):
return '<MISSING>'
MISSING = _MISSING()


class AssignmentSet(object):

"""A collection of literals and their assignments."""

def __init__(self, assignments=None):
# Changelog is a dict of id -> (original value, new value)
# FIXME: Verify that we really need ordering here
self._data = OrderedDict()
self._data = {}
self._orig = {}
self._seen = set()
self._cached_changelog = None
self._assigned_literals = set()
self._assigned_ids = set()
self.new_keys = set()
for k, v in (assignments or {}).items():
self[k] = v

def __setitem__(self, key, value):
assert key > 0

prev_value = self.get(key)
abskey = abs(key)

if prev_value is not None:
self._assigned_literals.difference_update((key, -key))
if abskey not in self._seen:
self.new_keys.add(abskey)

if value is not None:
self._assigned_literals.add(key if value else -key)
if value is None:
del self[key]
else:
self._update_diff(key, value)
self._data[key] = value
self._data[-key] = not value
self._assigned_ids.add(abs(key))

self._update_diff(key, value)
self._data[key] = value
self._seen.add(abskey)

def __delitem__(self, key):
self._update_diff(key, MISSING)
prev = self._data.pop(key)
if prev is not None:
self._assigned_literals.difference_update((key, -key))
self._seen.discard(abs(key))
if key not in self._data:
return
self._update_diff(key, None)
del self._data[key]
del self._data[-key]
self._assigned_ids.discard(abs(key))

def __getitem__(self, key):
return self._data[key]
val = self._data.get(key)
if val is None and abs(key) not in self._seen:
raise KeyError(key)
return val

def get(self, key, default=None):
return self._data.get(key, default)

def __len__(self):
return len(self._data)
return len(self._seen)

def __iter__(self):
return iter(self.keys())

def __contains__(self, key):
return key in self._data
return abs(key) in self._seen

def items(self):
return list(self._data.items())
return sorted(
(k, self._data.get(k))
for k in self._seen)

def iteritems(self):
return six.iteritems(self._data)
return iter(self.items())

def keys(self):
return list(self._data.keys())
return [k for k, _ in self.items()]

def values(self):
return list(self._data.values())
return [v for _, v in self.items()]

def _update_diff(self, key, value):
prev = self._data.get(key, MISSING)
# This must be called before _data is updated
if key < 0 and value is not None:
key = -key
value = not value
prev = self._data.get(key)
self._orig.setdefault(key, prev)
# If a value changes, dump the cached changelog
self._cached_changelog = None
Expand All @@ -81,7 +93,7 @@ def get_changelog(self):
self._cached_changelog = {
key: (old, new)
for key, old in six.iteritems(self._orig)
for new in [self._data.get(key, MISSING)]
for new in [self._data.get(key)]
if new != old
}
return self._cached_changelog
Expand All @@ -90,24 +102,33 @@ def consume_changelog(self):
old = self.get_changelog()
self._orig = {}
self._cached_changelog = {}
self.new_keys.clear()
return old

def copy(self):
new = AssignmentSet()
new._data = self._data.copy()
new._orig = self._orig.copy()
new._assigned_literals = self._assigned_literals.copy()
new._seen = self._seen.copy()
new._assigned_ids = self._assigned_ids.copy()
new.new_keys = self.new_keys.copy()
return new

def to_dict(self):
return dict(self.items())

def value(self, lit):
""" Return the value of literal in terms of the positive. """
if lit in self._assigned_literals:
return True
elif -lit in self._assigned_literals:
return False
else:
return None
""" Return the value of literal. """
return self._data.get(lit)

@property
def num_assigned(self):
return len(self._assigned_literals)
return len(self._assigned_ids)

@property
def assigned_ids(self):
return self._assigned_ids

@property
def unassigned_ids(self):
return self._seen.difference(self._assigned_ids)
32 changes: 24 additions & 8 deletions simplesat/sat/tests/test_assignment_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import unittest

from ..assignment_set import AssignmentSet, MISSING
from ..assignment_set import AssignmentSet


class TestAssignmentSet(unittest.TestCase):
Expand Down Expand Up @@ -66,7 +66,7 @@ def test_container(self):
del AS[5]
self.assertNotIn(5, AS)

expected = [(1, True), (2, False), (4, None), (3, True)]
expected = [(1, True), (2, False), (3, True), (4, None)]

manual_result = list(zip(AS.keys(), AS.values()))
self.assertEqual(AS.items(), expected)
Expand Down Expand Up @@ -94,9 +94,9 @@ def test_copy(self):
copied = AS.copy()

self.assertIsNot(copied._data, AS._data)
self.assertEqual(copied._data, expected)
self.assertEqual(copied.to_dict(), expected)

expected = {k: MISSING for k in expected}
expected = {k: None for k, v in expected.items() if v is not None}

self.assertIsNot(copied._orig, AS._orig)
self.assertEqual(copied._orig, expected)
Expand All @@ -122,6 +122,22 @@ def test_value(self):
self.assertIs(AS.value(3), None)
self.assertIs(AS.value(-3), None)

del AS[2]
self.assertIs(AS.value(-2), None)
self.assertIs(AS.value(2), None)

AS[1] = None
self.assertIs(AS.value(-1), None)
self.assertIs(AS.value(1), None)

AS[3] = True
self.assertIs(AS.value(-3), False)
self.assertIs(AS.value(3), True)

AS[3] = False
self.assertIs(AS.value(-3), True)
self.assertIs(AS.value(3), False)

def test_getitem(self):
AS = AssignmentSet()

Expand All @@ -147,15 +163,15 @@ def test_changelog(self):

AS[1] = None

expected = {1: (MISSING, None)}
expected = {}
self.assertEqual(AS.get_changelog(), expected)

AS[2] = True
expected[2] = (MISSING, True)
expected[2] = (None, True)
self.assertEqual(AS.get_changelog(), expected)

AS[2] = False
expected[2] = (MISSING, False)
expected[2] = (None, False)
self.assertEqual(AS.get_changelog(), expected)

del AS[2]
Expand All @@ -177,5 +193,5 @@ def test_changelog(self):
self.assertEqual(AS.get_changelog(), expected)

del AS[1]
expected = {1: (None, MISSING)}
expected = {}
self.assertEqual(AS.get_changelog(), expected)
6 changes: 3 additions & 3 deletions simplesat/sat/tests/test_minisat.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def test_propagation_with_queue(self):

# Then
self.assertIsNone(conflict)
self.assertEqual(s.assignments._data,
self.assertEqual(s.assignments.to_dict(),
{1: True, 2: False, 3: None, 4: None})
self.assertEqual(s.trail, [-2, 1])
six.assertCountEqual(self, s.watches[-1], [cl1, cl2])
Expand All @@ -302,7 +302,7 @@ def test_propagation_with_queue_multiple_implications(self):

# Then
self.assertIsNone(conflict)
self.assertEqual(s.assignments._data,
self.assertEqual(s.assignments.to_dict(),
{1: False, 2: False, 3: False, 4: False})
self.assertEqual(s.trail, [-1, -2, -3, -4])

Expand Down Expand Up @@ -377,7 +377,7 @@ def test_undo_one(self):
s.undo_one()

# Then
self.assertEqual(s.assignments._data, {1: None, 2: None, 3: None})
self.assertEqual(s.assignments.to_dict(), {1: None, 2: None, 3: None})
self.assertEqual(s.trail, [1, 2])

def test_cancel(self):
Expand Down

0 comments on commit 11848c9

Please sign in to comment.