diff --git a/CHANGES.rst b/CHANGES.rst index 11a728c..3b3fcd2 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -10,6 +10,9 @@ Enhancements * Details relating to unsatisfiable scenarios are captured in an ``UNSAT`` object and attached to the ``SatisifiabilityError`` raised (#101). +* Adds a second policy based on a priority queue (#131) +* Major speed improvements in the policies and assignment tracker (#131) + Bugs Fixed ---------- @@ -18,6 +21,7 @@ Bugs Fixed * Some sort operations that were using non-unique keys have been fixed (#101). * Assumptions are now represented as an empty Clause object (#101). + Version 0.1.0 ============= diff --git a/scripts/solve.py b/scripts/solve.py index 33724f5..7144962 100644 --- a/scripts/solve.py +++ b/scripts/solve.py @@ -36,7 +36,9 @@ def solve_and_print(request, remote_repositories, installed_repository, print(e.unsat._find_requirement_time.pretty(fmt), file=sys.stderr) if debug: - report = solver._policy._log_report(detailed=(debug > 1)) + counts, hist = solver._policy._log_histogram() + print(hist, file=sys.stderr) + report = solver._policy._log_report(with_assignments=debug > 1) print(report, file=sys.stderr) print(solver._last_rules_time.pretty(fmt), file=sys.stderr) print(solver._last_solver_init_time.pretty(fmt), file=sys.stderr) diff --git a/simplesat/priority_queue.py b/simplesat/priority_queue.py new file mode 100644 index 0000000..15df6d0 --- /dev/null +++ b/simplesat/priority_queue.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from collections import defaultdict +from functools import partial +from heapq import heappush, heappop +from itertools import count + +import six + + +class _REMOVED_TASK(object): + pass +REMOVED_TASK = _REMOVED_TASK() + + +class PriorityQueue(object): + """ A priority queue implementation that supports reprioritizing or + removing tasks, given that tasks are unique. + + Borrowed from: https://docs.python.org/3/library/heapq.html + """ + + def __init__(self): + # list of entries arranged in a heap + self._pq = [] + # mapping of tasks to entries + self._entry_finder = {} + # unique id genrator for tie-breaking + self._next_id = partial(next, count()) + + def __len__(self): + return len(self._entry_finder) + + def __bool__(self): + return bool(len(self)) + + def __contains__(self, task): + return task in self._entry_finder + + def clear(self): + self._pq = [] + self._entry_finder = {} + + def push(self, task, priority=0): + "Add a new task or update the priority of an existing task" + return self._push(priority, self._next_id(), task) + + def peek(self): + """ Return the task with the lowest priority. + + This will pop and repush if a REMOVED task is found. + """ + if not self._pq: + raise KeyError('peek from an empty priority queue') + entry = self._pq[0] + if entry[-1] is REMOVED_TASK: + entry = self._pop() + self._push(*entry) + return entry[-1] + + def pop(self): + 'Remove and return the lowest priority task. Raise KeyError if empty.' + _, _, task = self._pop() + return task + + def pop_many(self, n=None): + """ Return a list of length n of popped elements. If n is not + specified, pop the entire queue. """ + if n is None: + n = len(self) + result = [] + for _ in range(n): + result.append(self.pop()) + return result + + def discard(self, task): + "Remove an existing task if present. If not, do nothing." + try: + self.remove(task) + except KeyError: + pass + + def remove(self, task): + "Remove an existing task. Raise KeyError if not found." + entry = self._entry_finder.pop(task) + entry[-1] = REMOVED_TASK + + def _pop(self): + while self._pq: + entry = heappop(self._pq) + if entry[-1] is not REMOVED_TASK: + del self._entry_finder[entry[-1]] + return entry + raise KeyError('pop from an empty priority queue') + + def _push(self, priority, task_id, task): + if task in self: + o_priority, _, o_task = self._entry_finder[task] + # Still check the task, which might now be REMOVED + if priority == o_priority and task == o_task: + # We're pushing something we already have, do nothing + return + else: + # Make space for the new entry + self.remove(task) + entry = [priority, task_id, task] + self._entry_finder[task] = entry + heappush(self._pq, entry) + + +class GroupPrioritizer(object): + + """ A helper for assigning hierarchical priorities to items + according to priority groups. """ + + def __init__(self, order_key_func=lambda x: x): + """ + Parameters + ---------- + `order_key_func` : callable + used to sort items in each group. + """ + self.key_func = order_key_func + self._priority_groups = defaultdict(set) + self._item_priority = {} + self.known = frozenset() + self.dirty = True + + def __contains__(self, item): + return item in self._item_priority + + def __getitem__(self, item): + "Return the priority of an item." + if self.dirty: + self._prioritize() + return self._item_priority[item] + + def get(self, item, default=None): + if item in self: + return self[item] + return default + + def items(self): + "Return an (item, priority) iterator for all items." + if self.dirty: + self._prioritize() + return six.iteritems(self._item_priority) + + def update(self, items, group): + """Add `items` to the `group`, remove `items` from all other groups, + and update all priority values.""" + self.known = self.known.union(items) + for _group, _items in self._priority_groups.items(): + if _group != group: + _items.difference_update(items) + self._priority_groups[group].update(items) + self.dirty = True + + def group(self, group): + "Return the set of items in `group`." + if group not in self._priority_groups: + raise KeyError(repr(group)) + return self._priority_groups[group] + + def _prioritize(self): + item_priority = {} + + for group, items in six.iteritems(self._priority_groups): + ordered_items = sorted(items, key=self.key_func) + for rank, item in enumerate(ordered_items): + priority = (group, rank) + item_priority[item] = priority + + self._item_priority = item_priority + self.dirty = False diff --git a/simplesat/sat/assignment_set.py b/simplesat/sat/assignment_set.py index d36b8d2..ad0998f 100644 --- a/simplesat/sat/assignment_set.py +++ b/simplesat/sat/assignment_set.py @@ -1,17 +1,9 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -from collections import OrderedDict - import six -class _MISSING(object): - def __str__(self): - return '' -MISSING = _MISSING() - - class AssignmentSet(object): """A collection of literals and their assignments.""" @@ -19,59 +11,79 @@ class AssignmentSet(object): def __init__(self, assignments=None): # Changelog is a dict of id -> (original value, new value) # FIXME: Verify that we really need ordering here - self._data = OrderedDict() + self._data = {} self._orig = {} + self._seen = set() self._cached_changelog = None - self._assigned_literals = set() + self._assigned_ids = set() + self.new_keys = set() for k, v in (assignments or {}).items(): self[k] = v def __setitem__(self, key, value): - assert key > 0 - prev_value = self.get(key) + abskey = abs(key) - if prev_value is not None: - self._assigned_literals.difference_update((key, -key)) + if abskey not in self._seen: + self.new_keys.add(abskey) - if value is not None: - self._assigned_literals.add(key if value else -key) + if value is None: + del self[key] + else: + self._update_diff(key, value) + self._data[key] = value + self._data[-key] = not value + self._assigned_ids.add(abs(key)) - self._update_diff(key, value) - self._data[key] = value + self._seen.add(abskey) def __delitem__(self, key): - self._update_diff(key, MISSING) - prev = self._data.pop(key) - if prev is not None: - self._assigned_literals.difference_update((key, -key)) + self._seen.discard(abs(key)) + if key not in self._data: + return + self._update_diff(key, None) + del self._data[key] + del self._data[-key] + self._assigned_ids.discard(abs(key)) def __getitem__(self, key): - return self._data[key] + val = self._data.get(key) + if val is None and abs(key) not in self._seen: + raise KeyError(key) + return val def get(self, key, default=None): return self._data.get(key, default) def __len__(self): - return len(self._data) + return len(self._seen) + + def __iter__(self): + return iter(self.keys()) def __contains__(self, key): - return key in self._data + return abs(key) in self._seen def items(self): - return list(self._data.items()) + return sorted( + (k, self._data.get(k)) + for k in self._seen) def iteritems(self): - return six.iteritems(self._data) + return iter(self.items()) def keys(self): - return list(self._data.keys()) + return [k for k, _ in self.items()] def values(self): - return list(self._data.values()) + return [v for _, v in self.items()] def _update_diff(self, key, value): - prev = self._data.get(key, MISSING) + # This must be called before _data is updated + if key < 0 and value is not None: + key = -key + value = not value + prev = self._data.get(key) self._orig.setdefault(key, prev) # If a value changes, dump the cached changelog self._cached_changelog = None @@ -81,7 +93,7 @@ def get_changelog(self): self._cached_changelog = { key: (old, new) for key, old in six.iteritems(self._orig) - for new in [self._data.get(key, MISSING)] + for new in [self._data.get(key)] if new != old } return self._cached_changelog @@ -90,24 +102,33 @@ def consume_changelog(self): old = self.get_changelog() self._orig = {} self._cached_changelog = {} + self.new_keys.clear() return old def copy(self): new = AssignmentSet() new._data = self._data.copy() new._orig = self._orig.copy() - new._assigned_literals = self._assigned_literals.copy() + new._seen = self._seen.copy() + new._assigned_ids = self._assigned_ids.copy() + new.new_keys = self.new_keys.copy() return new + def to_dict(self): + return dict(self.items()) + def value(self, lit): - """ Return the value of literal in terms of the positive. """ - if lit in self._assigned_literals: - return True - elif -lit in self._assigned_literals: - return False - else: - return None + """ Return the value of literal. """ + return self._data.get(lit) @property def num_assigned(self): - return len(self._assigned_literals) + return len(self._assigned_ids) + + @property + def assigned_ids(self): + return self._assigned_ids + + @property + def unassigned_ids(self): + return self._seen.difference(self._assigned_ids) diff --git a/simplesat/sat/policy/__init__.py b/simplesat/sat/policy/__init__.py index 2e8da63..12354a6 100644 --- a/simplesat/sat/policy/__init__.py +++ b/simplesat/sat/policy/__init__.py @@ -4,5 +4,9 @@ from .undetermined_clause_policy import ( LoggedUndeterminedClausePolicy, UndeterminedClausePolicy ) +from .priority_queue_policy import ( + LoggedPriorityQueuePolicty, PriorityQueuePolicy +) -InstalledFirstPolicy = LoggedUndeterminedClausePolicy +# InstalledFirstPolicy = LoggedUndeterminedClausePolicy +InstalledFirstPolicy = LoggedPriorityQueuePolicty diff --git a/simplesat/sat/policy/policy.py b/simplesat/sat/policy/policy.py index b19fd3d..2a2e16a 100644 --- a/simplesat/sat/policy/policy.py +++ b/simplesat/sat/policy/policy.py @@ -37,7 +37,7 @@ class DefaultPolicy(IPolicy): def add_requirements(self, assignments): pass - def get_next_package_id(self, assignments, _): + def get_next_package_id(self, assignments, *_): # Given a dictionary of partial assignments, get an undecided variable # to be decided next. undecided = ( diff --git a/simplesat/sat/policy/policy_logger.py b/simplesat/sat/policy/policy_logger.py index 8e9e859..64c8d89 100644 --- a/simplesat/sat/policy/policy_logger.py +++ b/simplesat/sat/policy/policy_logger.py @@ -10,8 +10,8 @@ class PolicyLogger(IPolicy): def __init__(self, policy, args=None, kwargs=None): self._policy = policy self._log_pool = args[0] - self._log_installed = getattr(policy, '_installed_ids', set()).copy() - self._log_preferred = getattr(policy, '_preferred_ids', set()).copy() + self._log_installed = set(getattr(policy, '_installed_ids', ())) + self._log_preferred = set(getattr(policy, '_preferred_ids', ())) self._log_args = args self._log_kwargs = kwargs self._log_required = [] @@ -27,7 +27,6 @@ def get_next_package_id(self, assignments, clauses): def add_requirements(self, package_ids): self._log_required.extend(package_ids) - self._log_preferred.difference_update(package_ids) self._log_installed.difference_update(package_ids) self._policy.add_requirements(package_ids) @@ -36,7 +35,7 @@ def _log_histogram(self, pkg_ids=None): pkg_ids = map(abs, self._log_suggestions) c = Counter(pkg_ids) lines = ( - "{:>25} {}".format(self._log_pretty_pkg_id(k), v) + "{:>25} {:>5}".format(self._log_pretty_pkg_id(k), v) for k, v in c.most_common() ) pretty = '\n'.join(lines) @@ -52,7 +51,7 @@ def _log_pretty_pkg_id(self, pkg_id): repo = 'installed' return "{:{fill}<30} {:3} {}".format(name_ver, pkg_id, repo, fill=fill) - def _log_report(self, detailed=True): + def _log_report(self, with_assignments=True): def pkg_name(pkg_id): return pkg_key(pkg_id)[0] @@ -64,7 +63,7 @@ def pkg_key(pkg_id): ids = map(abs, self._log_suggestions) report = [] changes = [] - if self._log_assignment_changes: + if self._log_assignment_changes and with_assignments: for pkg, change in self._log_assignment_changes[0].items(): name = self._log_pretty_pkg_id(pkg) if change[1] is not None: @@ -72,33 +71,31 @@ def pkg_key(pkg_id): report.append('\n'.join(changes)) required = set(self._log_required) - preferred = set(self._log_preferred) installed = set(self._log_installed) for (i, sugg) in enumerate(ids): pretty = self._log_pretty_pkg_id(sugg) R = 'R' if sugg in required else ' ' - P = 'P' if sugg in preferred else ' ' I = 'I' if sugg in installed else ' ' - changes = [] + change_str = "" try: - change_items = self._log_assignment_changes[i + 1].items() - if detailed: - change_items = sorted( - change_items, key=lambda p: pkg_key(p[0])) - for pkg, change in change_items: + items = self._log_assignment_changes[i + 1].items() + sorted_items = sorted(items, key=lambda p: pkg_key(p[0])) + if with_assignments: + changes = [] + for pkg, change in sorted_items: if pkg_name(pkg) != pkg_name(sugg): _pretty = self._log_pretty_pkg_id(pkg) fro, to = map(str, change) msg = "{:10} - {:10} : {}" changes.append(msg.format(fro, to, _pretty)) - if changes: - changes = '\n\t\t'.join([''] + changes) - else: - changes = "" + if changes: + change_str = '\n\t'.join([''] + changes) + msg = "{:>4} {}{} - {}{}" + report.append(msg.format(i, R, I, pretty, change_str)) + if any(v[1] is None for _, v in sorted_items): + report.append("BACKTRACKED\n") except IndexError: - changes = "" - msg = "{:>4} {}{}{} - {}{}" - report.append(msg.format(i, R, P, I, pretty, changes)) + pass return '\n'.join(report) diff --git a/simplesat/sat/policy/priority_queue_policy.py b/simplesat/sat/policy/priority_queue_policy.py new file mode 100644 index 0000000..80bf3be --- /dev/null +++ b/simplesat/sat/policy/priority_queue_policy.py @@ -0,0 +1,221 @@ +# -*- coding: utf-8 -*- + +from collections import defaultdict + +import six + +from simplesat.constraints.requirement import InstallRequirement +from simplesat.utils import DefaultOrderedDict, toposort, transitive_neighbors +from simplesat.priority_queue import PriorityQueue, GroupPrioritizer +from .policy import IPolicy +from .policy_logger import LoggedPolicy + + +class PriorityQueuePolicy(IPolicy): + + """ An IPolicy that uses a priority queue to determine which package id + should be suggested next. + + Packages are split into groups: + + 1. currently installed, + 2. explicitly specified as a requirement, + 3. everything else, + + where each group is arranged in topological order by dependency + relationships and then descending order by version number. + The groups are then searched in order and the first unassigned package id + is suggested. + """ + + def __init__(self, pool, installed_repository, prefer_installed=True): + self._pool = pool + self._installed_ids = set(map(pool.package_id, installed_repository)) + + package_ids = pool.package_ids + self._package_id_to_rank = None # set the first time we check + self._all_ids = set(package_ids) + self._required_ids = set() + self._name_to_package_ids = self._group_packages_by_name(package_ids) + + def priority_func(p): + return self._package_id_to_rank[p] + + self._unassigned_pkg_ids = PriorityQueue() + + self.DEFAULT = 0 + if prefer_installed: + self.INSTALLED = -2 + self.REQUIRED = -1 + else: + self.REQUIRED = -1 + self.INSTALLED = self.DEFAULT + + self._prioritizer = GroupPrioritizer(priority_func) + self._add_packages(self._installed_ids.copy(), self.INSTALLED) + + def add_requirements(self, package_ids): + self._required_ids.update(package_ids) + if self.REQUIRED < self.INSTALLED: + self._installed_ids.difference_update(package_ids) + else: + package_ids = set(package_ids).difference(self._installed_ids) + self._add_packages(package_ids, self.REQUIRED) + + def get_next_package_id(self, assignments, clauses): + self._update_cache_from_assignments(assignments) + # Grab the most interesting looking currently unassigned id + p_id = self._unassigned_pkg_ids.peek() + return p_id + + def _add_packages(self, package_ids, group): + prioritizer = self._prioritizer + prioritizer.update(package_ids, group=group) + + # Removing an item from an ordering always maintains the ordering, + # so we only need to update priorities on groups that had items added + for pkg_id in prioritizer.group(group): + if pkg_id in self._unassigned_pkg_ids: + self._unassigned_pkg_ids.push(pkg_id, prioritizer[pkg_id]) + + def pkg_key(self, package_id): + """ Return the key used to compare two packages. """ + package = self._pool.id_to_package(package_id) + try: + installed = package.repository_info.name == 'installed' + except AttributeError: + installed = False + return (package.version, installed) + + def _rank_packages(self, package_ids): + """ Return a dictionary of package_id to priority rank. + + Currently we build a dependency tree of all the relevant packages and + then rank them topologically, starting with those at the top. + + This strategy causes packages which force more assignments via + unit propagation in the solver to be preferred. + """ + pool = self._pool + R = InstallRequirement + + # The direct dependencies of each package + dependencies = defaultdict(set) + for package_id in package_ids: + dependencies[package_id].update( + pool.package_id(package) + for cons in pool.id_to_package(package_id).install_requires + for package in pool.what_provides(R.from_constraints(cons)) + ) + + # This is a flattened version of `dependencies` above + transitive = transitive_neighbors(dependencies) + + packages_by_name = self._group_packages_by_name(package_ids) + + # Some packages have unversioned dependencies, such as simply 'pandas'. + # This can produce cycles in the dependency graph which much be removed + # before topological sorting can be done. + # The strategy is to ignore the dependencies of any package that is + # present in its own transitive dependency list + removed_deps = [] + for package_id in package_ids: + package = pool.id_to_package(package_id) + deps = dependencies[package_id] + package_group = packages_by_name[package.name] + for dep in list(deps): + circular = transitive[dep].intersection(package_group) + if circular: + packages = [pool.id_to_package(p) for p in circular] + depkg = pool.id_to_package(dep) + pkg_strings = [ + "{}-{}".format(pkg.name, pkg.version) + for pkg in packages + ] + msg = "Circular Deps: {}-{} -> {}-{} -> {}".format( + package.name, package.version, + depkg.name, depkg.version, + pkg_strings + ) + removed_deps.append(msg) + deps.remove(dep) + + # Mark packages as depending on older versions of themselves so that + # they will come out first in the toposort + for package_id in package_ids: + package = pool.id_to_package(package_id) + package_group = packages_by_name[package.name] + idx = package_group.index(package_id) + other_older = package_group[:idx + 1] + dependencies[package_id].update(other_older) + + # Finally toposort the packages, preferring higher version and + # already-installed packages to break ties + ordered = [ + package_id + for group in tuple(toposort(dependencies)) + for package_id in sorted(group, key=self.pkg_key, reverse=True) + ] + + package_id_to_rank = { + package_id: rank + for rank, package_id in enumerate(ordered) + } + + return package_id_to_rank + + def _group_packages_by_name(self, package_ids): + """ Return a dictionary from package name to all package ids + corresponding to packages with that name. """ + pool = self._pool + + name_map = DefaultOrderedDict(list) + for package_id in package_ids: + package = pool.id_to_package(package_id) + name_map[package.name].append(package_id) + + name_to_package_ids = {} + + for name, package_ids in name_map.items(): + ordered = sorted(package_ids, key=self.pkg_key, reverse=True) + name_to_package_ids[name] = ordered + + return name_to_package_ids + + def _update_cache_from_assignments(self, assignments): + new_keys = assignments.new_keys.copy() + changelog = assignments.consume_changelog() + + if new_keys: + unknown_ids = new_keys.difference(self._prioritizer.known) + self._all_ids.update(new_keys) + self._package_id_to_rank = self._rank_packages(self._all_ids) + self._prioritizer.update(unknown_ids, group=self.DEFAULT) + + # Newly unassigned + self._unassigned_pkg_ids.clear() + for key in assignments.unassigned_ids: + priority = self._prioritizer[key] + self._unassigned_pkg_ids.push(key, priority=priority) + else: + for key, (old, new) in six.iteritems(changelog): + if new is None: + # Newly unassigned + priority = self._prioritizer[key] + self._unassigned_pkg_ids.push(key, priority=priority) + elif old is None: + # No longer unassigned (because new is not None) + self._unassigned_pkg_ids.remove(key) + + # The remaining case is True -> False, False -> True or + # MISSING -> (True|False) + + # A very cheap sanity check + ours = len(self._unassigned_pkg_ids) + theirs = len(assignments) - assignments.num_assigned + has_new_keys = len(new_keys) + msg = "We failed to track variable assignments {} {} {}" + assert ours == theirs, msg.format(ours, theirs, has_new_keys) + + +LoggedPriorityQueuePolicty = LoggedPolicy(PriorityQueuePolicy) diff --git a/simplesat/sat/policy/undetermined_clause_policy.py b/simplesat/sat/policy/undetermined_clause_policy.py index 1617b4d..4bfbf52 100644 --- a/simplesat/sat/policy/undetermined_clause_policy.py +++ b/simplesat/sat/policy/undetermined_clause_policy.py @@ -1,6 +1,8 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +from collections import defaultdict + import six from simplesat.utils import DefaultOrderedDict @@ -17,8 +19,10 @@ class UndeterminedClausePolicy(IPolicy): def __init__(self, pool, installed_repository, prefer_installed=True): self._pool = pool self.prefer_installed = prefer_installed - self._installed_ids = set( - pool.package_id(package) for package in installed_repository + by_version = six.functools.partial(pkg_id_to_version, self._pool) + self._installed_ids = sorted( + (pool.package_id(package) for package in installed_repository), + key=by_version ) self._preferred_package_ids = { self._package_key(package_id): package_id @@ -26,6 +30,9 @@ def __init__(self, pool, installed_repository, prefer_installed=True): } self._decision_set = set() self._requirements = set() + self._unsatisfied_clauses = set() + self._id_to_clauses = defaultdict(list) + self._all_ids = set() def _package_key(self, package_id): package = self._pool.id_to_package(package_id) @@ -34,30 +41,50 @@ def _package_key(self, package_id): def add_requirements(self, package_ids): self._requirements.update(package_ids) + def _update_cache_from_assignments(self, assignments): + changelog = assignments.consume_changelog() + for key in six.iterkeys(changelog): + for clause in self._id_to_clauses[key]: + if any(assignments.value(l) for l in clause.lits): + self._unsatisfied_clauses.discard(clause) + else: + self._unsatisfied_clauses.add(clause) + + def _build_id_to_clauses(self, clauses): + """ Return a mapping from package ids to a list of clauses containing + that id. + """ + table = defaultdict(list) + for c in clauses: + for l in c.lits: + table[abs(l)].append(c) + self._all_ids = set(six.iterkeys(table)) + return dict(table) + def get_next_package_id(self, assignments, clauses): """Get the next unassigned package. """ + if assignments.new_keys: + self._id_to_clauses = self._build_id_to_clauses(clauses) + self._refresh_decision_set(assignments) candidate_id = None best = self._best_candidate if self.prefer_installed: - candidate_id = best(self._installed_ids, assignments) + candidate_id = self._best_sorted_candidate( + self._installed_ids, assignments) - candidate_id = ( - candidate_id or - self._best_candidate(self._requirements, assignments) or - self._best_candidate(self._decision_set, assignments) - ) + if candidate_id is None: + candidate_id = best(self._requirements, assignments) + + if candidate_id is None: + candidate_id = best(self._decision_set, assignments, update=True) if candidate_id is None: - self._decision_set.clear() - candidate_id = \ - self._handle_empty_decision_set(assignments, clauses) + self._refresh_decision_set(assignments) + candidate_id = best(self._decision_set, assignments) if candidate_id is None: - candidate_id = self._best_candidate( - self._decision_set, - assignments - ) + candidate_id = best(self._all_ids, assignments) assert assignments.get(candidate_id) is None, \ "Trying to assign to a variable which is already assigned." @@ -70,14 +97,19 @@ def get_next_package_id(self, assignments, clauses): return candidate_id def _without_assigned(self, package_ids, assignments): - return set( - pkg_id for pkg_id in package_ids - if assignments.get(pkg_id) is None - ) + return package_ids.difference(assignments.assigned_ids) - def _best_candidate(self, package_ids, assignments): + def _best_sorted_candidate(self, package_ids, assignments): + for p_id in package_ids: + if p_id not in assignments.assigned_ids: + return p_id + + def _best_candidate(self, package_ids, assignments, update=False): by_version = six.functools.partial(pkg_id_to_version, self._pool) unassigned = self._without_assigned(package_ids, assignments) + if update: + package_ids.clear() + package_ids.update(unassigned) try: return max(unassigned, key=by_version) except ValueError: @@ -86,51 +118,26 @@ def _best_candidate(self, package_ids, assignments): def _group_packages_by_name(self, decision_set): installed_packages = [] new_package_map = DefaultOrderedDict(list) + installed_ids = set(self._installed_ids) for package_id in sorted(decision_set): package = self._pool.id_to_package(package_id) - if package_id in self._installed_ids: + if package_id in installed_ids: installed_packages.append(package) else: new_package_map[package.name].append(package) return installed_packages, new_package_map - def _handle_empty_decision_set(self, assignments, clauses): - # TODO inefficient and verbose - - unassigned_ids = set( - literal for literal, status in six.iteritems(assignments) - if status is None + def _refresh_decision_set(self, assignments): + self._update_cache_from_assignments(assignments) + self._decision_set.clear() + self._decision_set.update( + abs(lit) + for clause in self._unsatisfied_clauses + for lit in clause.lits ) - assigned_ids = set(assignments.keys()) - unassigned_ids - - signed_assignments = set() - for variable in assigned_ids: - if assignments[variable]: - signed_assignments.add(variable) - else: - signed_assignments.add(-variable) - - for clause in clauses: - # TODO Need clause.undecided_literals property - if not signed_assignments.isdisjoint(clause.lits): - # Clause is true - continue - - variables = map(abs, clause.lits) - undecided = unassigned_ids.intersection(variables) - self._decision_set.update(lit for lit in undecided) - - if len(self._decision_set) == 0: - # This will happen if the remaining packages are irrelevant for - # the set of rules that we're trying to satisfy. In that case, - # just return one of the undecided IDs. - - # We use min to ensure determinisism - return min(unassigned_ids) - else: - return None + self._decision_set.difference_update(assignments.assigned_ids) LoggedUndeterminedClausePolicy = LoggedPolicy(UndeterminedClausePolicy) diff --git a/simplesat/sat/tests/test_assignment_set.py b/simplesat/sat/tests/test_assignment_set.py index ca347b1..68e51df 100644 --- a/simplesat/sat/tests/test_assignment_set.py +++ b/simplesat/sat/tests/test_assignment_set.py @@ -3,7 +3,7 @@ import unittest -from ..assignment_set import AssignmentSet, MISSING +from ..assignment_set import AssignmentSet class TestAssignmentSet(unittest.TestCase): @@ -66,7 +66,7 @@ def test_container(self): del AS[5] self.assertNotIn(5, AS) - expected = [(1, True), (2, False), (4, None), (3, True)] + expected = [(1, True), (2, False), (3, True), (4, None)] manual_result = list(zip(AS.keys(), AS.values())) self.assertEqual(AS.items(), expected) @@ -94,9 +94,9 @@ def test_copy(self): copied = AS.copy() self.assertIsNot(copied._data, AS._data) - self.assertEqual(copied._data, expected) + self.assertEqual(copied.to_dict(), expected) - expected = {k: MISSING for k in expected} + expected = {k: None for k, v in expected.items() if v is not None} self.assertIsNot(copied._orig, AS._orig) self.assertEqual(copied._orig, expected) @@ -122,6 +122,22 @@ def test_value(self): self.assertIs(AS.value(3), None) self.assertIs(AS.value(-3), None) + del AS[2] + self.assertIs(AS.value(-2), None) + self.assertIs(AS.value(2), None) + + AS[1] = None + self.assertIs(AS.value(-1), None) + self.assertIs(AS.value(1), None) + + AS[3] = True + self.assertIs(AS.value(-3), False) + self.assertIs(AS.value(3), True) + + AS[3] = False + self.assertIs(AS.value(-3), True) + self.assertIs(AS.value(3), False) + def test_getitem(self): AS = AssignmentSet() @@ -147,15 +163,15 @@ def test_changelog(self): AS[1] = None - expected = {1: (MISSING, None)} + expected = {} self.assertEqual(AS.get_changelog(), expected) AS[2] = True - expected[2] = (MISSING, True) + expected[2] = (None, True) self.assertEqual(AS.get_changelog(), expected) AS[2] = False - expected[2] = (MISSING, False) + expected[2] = (None, False) self.assertEqual(AS.get_changelog(), expected) del AS[2] @@ -177,5 +193,5 @@ def test_changelog(self): self.assertEqual(AS.get_changelog(), expected) del AS[1] - expected = {1: (None, MISSING)} + expected = {} self.assertEqual(AS.get_changelog(), expected) diff --git a/simplesat/sat/tests/test_minisat.py b/simplesat/sat/tests/test_minisat.py index de4bd20..a18a923 100644 --- a/simplesat/sat/tests/test_minisat.py +++ b/simplesat/sat/tests/test_minisat.py @@ -278,7 +278,7 @@ def test_propagation_with_queue(self): # Then self.assertIsNone(conflict) - self.assertEqual(s.assignments._data, + self.assertEqual(s.assignments.to_dict(), {1: True, 2: False, 3: None, 4: None}) self.assertEqual(s.trail, [-2, 1]) six.assertCountEqual(self, s.watches[-1], [cl1, cl2]) @@ -302,7 +302,7 @@ def test_propagation_with_queue_multiple_implications(self): # Then self.assertIsNone(conflict) - self.assertEqual(s.assignments._data, + self.assertEqual(s.assignments.to_dict(), {1: False, 2: False, 3: False, 4: False}) self.assertEqual(s.trail, [-1, -2, -3, -4]) @@ -377,7 +377,7 @@ def test_undo_one(self): s.undo_one() # Then - self.assertEqual(s.assignments._data, {1: None, 2: None, 3: None}) + self.assertEqual(s.assignments.to_dict(), {1: None, 2: None, 3: None}) self.assertEqual(s.trail, [1, 2]) def test_cancel(self): diff --git a/simplesat/tests/epd_full_conflict.yaml b/simplesat/tests/epd_full_conflict.yaml index 975539b..99fbbfe 100644 --- a/simplesat/tests/epd_full_conflict.yaml +++ b/simplesat/tests/epd_full_conflict.yaml @@ -2023,9 +2023,11 @@ failure: requirements: ['EPD', 'numpy > 1.8.0-0'] raw: | Conflicting requirements: - Requirements: 'EPD' <- 'numpy == 1.6.0-5' - EPD-7.1-1 requires (+numpy-1.6.0-5) + Requirements: 'EPD' + Install command rule (+EPD-7.0-1 | +EPD-7.0-2 | +EPD-7.1-1 | +EPD-7.1-2 | +EPD-7.2-1 | +EPD-7.2-2 | +EPD-7.3-1 | +EPD-7.3-2) + Requirements: 'EPD' <- 'numpy == 1.6.1-2' + EPD-7.3-1 requires (+numpy-1.6.1-2) Requirements: 'EPD' <- 'SimPy == 2.1.0-2' <- 'numpy' <- 'numpy' - Can only install one of: (+numpy-1.8.1-1 | +numpy-1.6.0-5) + Can only install one of: (+numpy-1.8.1-1 | +numpy-1.6.1-2) Requirements: 'numpy > 1.8.0-0' Install command rule (+numpy-1.8.0-1 | +numpy-1.8.0-2 | +numpy-1.8.0-3 | +numpy-1.8.1-1) diff --git a/simplesat/utils/graph.py b/simplesat/utils/graph.py index 04f809d..e41d0d1 100644 --- a/simplesat/utils/graph.py +++ b/simplesat/utils/graph.py @@ -18,13 +18,13 @@ def toposort(nodes_to_edges): each subsequent set consists of items that depend upon items in the preceeding sets. - >>> print '\\n'.join(repr(sorted(x)) for x in toposort2({ + >>> print('\\n'.join(repr(sorted(x)) for x in toposort2({ ... 2: set([11]), ... 9: set([11,8]), ... 10: set([11,3]), ... 11: set([7,5]), ... 8: set([7,3]), - ... })) + ... }))) [3, 5, 7] [8, 11] [2, 9, 10]