Skip to content

Commit

Permalink
Merge branch 'fix-cython-pickle' into release/0.5.1
Browse files Browse the repository at this point in the history
  • Loading branch information
szymonlopaciuk committed Dec 5, 2023
2 parents 27b8937 + e7f6171 commit 9ff27fc
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 22 deletions.
8 changes: 7 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,10 @@
*.swp
*.egg-info
.DS_Store

*.so
.pymon
.ipynb_checkpoints
**/.coverage
**/.coverage.*
**/checkpoint_restart.dat
build/
30 changes: 29 additions & 1 deletion tests/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
# copyright ############################### #
# This file is part of the Xdeps Package. #
# Copyright (c) CERN, 2023. #
# ######################################### #
import numpy as np
import pickle
import pytest

import xdeps as xd
Expand Down Expand Up @@ -84,7 +89,7 @@ def example_manager():
# │
# ╭ [d[0]]─┐ │
# container2 │ │ │
# ╰ [d[1]]─┴───►(e[0]=d[0]+d[1])───►[e[0]]─┴─►(j=h+e[0])─►[j]
# ╰ [d[1]]─┴───►(e[0]=d[0]*d[1])───►[e[0]]─┴─►(j=h+e[0])─►[j]

ref1['f'] = ref1['a'] + ref1['b']

Expand Down Expand Up @@ -248,6 +253,29 @@ def test_manager_dump_and_load(example_manager):
assert not new_manager.tasks[new_ref2['j']].expr == old_expr


def test_manager_pickle(example_manager):
manager, _, original_containers = example_manager
original1, original2 = original_containers

pickled = pickle.dumps(manager)
new_manager = pickle.loads(pickled)

ref1 = new_manager.containers['ref1']
container1 = ref1._owner
ref2 = new_manager.containers['ref2']
container2 = ref2._owner

ref1['a'] = 32
ref2['d'][1] = 5

assert container1 is not original1
assert container2 is not original2
assert container1['f'] == 40
assert container1['h'] == 6
assert container2['e'][0] == 10
assert container2['j'] == 16


def test_ref_count():
manager = xd.Manager()
container = {'a': 3, 'b': [None], 'c': None}
Expand Down
1 change: 1 addition & 0 deletions tests/test_xdeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# This file is part of the Xdeps Package. #
# Copyright (c) CERN, 2021. #
# ######################################### #
import pickle

import xdeps
from xdeps.tasks import AttrDict
Expand Down
85 changes: 65 additions & 20 deletions xdeps/refs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
'__dict__',
'__getstate__',
'__setstate__',
'__reduce__',
'__reduce_cython__',
'__wrapped__',
'__array_ufunc__',
Expand Down Expand Up @@ -112,11 +111,20 @@ class BaseRef:
_hash = cython.declare(int, visibility='private')

def __init__(self, *args, **kwargs):
raise TypeError("Cannot instantiate abstract class BaseRef")
# To keep compatibility with pure Python (useful for debugging simpler
# issues), we simulate Cython __cinit__ behaviour with this __init__:
if not is_cythonized():
for base in type(self).__mro__:
cinit = getattr(base, '__cinit__', None)
if cinit:
cinit(self, *args, **kwargs)

def __hash__(self):
return self._hash

def __reduce__(self):
raise TypeError("Cannot pickle an abstract class")

def __eq__(self, other):
"""Check equality of the expressions `self` and `other`.
Expand Down Expand Up @@ -319,11 +327,13 @@ class MutableRef(BaseRef):
_owner = cython.declare(object, visibility='public', value=None)
_key = cython.declare(object, visibility='public', value=None)

def __init__(self, _owner, _key, _manager):
def __cinit__(self, _owner, _key, _manager):
self._owner = _owner
self._key = _key
self._manager = _manager
self._hash = hash((self.__class__.__name__, _owner, _key))
# _hash will depend on the particularities of the subclass, for now
# it is None, which does not matter, as this class should never be
# instantiated.

def __setitem__(self, key, value):
ref = ItemRef(self, key, self._manager)
Expand All @@ -348,14 +358,22 @@ def __setattr__(self, attr, value):
# The above way of setting attributes does not work in Cython,
# as the object does not have a __dict__. We do not really need
# a setter for those though, as the only time we need to
# set a "built-in" attribute is during __init__ or when
# set a "built-in" attribute is during __cinit__ or when
# unpickling, and both of those cases are handled by Cython
# without the usual pythonic call to __setattr__.
raise AttributeError(f"Attribute {attr} is read-only.")

ref = AttrRef(self, attr, self._manager)
self._manager.set_value(ref, value)

def __reduce__(self):
"""Do not store the hash when pickling.
The hash is only guaranteed to be the same for the 'same' refs within
the same python instance, therefore serialising hashes makes no sense.
"""
return type(self), (self._owner, self._key, self._manager)

def _get_dependencies(self, out=None):
if out is None:
out = set()
Expand Down Expand Up @@ -510,11 +528,9 @@ class Ref(MutableRef):
"""
A reference in the top-level container.
"""
def __init__(self, _owner, _key, _manager):
self._owner = _owner
self._key = _key
self._manager = _manager
self._hash = hash((self.__class__.__name__, _key))
def __cinit__(self, _owner, _key, _manager):
# Cython automatically calls __cinit__ in the base classes
self._hash = hash((type(self).__name__, _key))

def __repr__(self):
return self._key
Expand Down Expand Up @@ -556,6 +572,10 @@ def __setattr__(self, attr, value):

@cython.cclass
class AttrRef(MutableRef):
def __cinit__(self, _owner, _key, _manager):
# Cython automatically calls __cinit__ in the base classes
self._hash = hash((type(self).__name__, _owner, _key))

def _get_value(self):
owner = BaseRef._mk_value(self._owner)
attr = BaseRef._mk_value(self._key)
Expand All @@ -567,11 +587,17 @@ def _set_value(self, value):
setattr(owner, attr, value)

def __repr__(self):
assert self._owner is not None
assert self._key is not None
return f"{self._owner}.{self._key}"


@cython.cclass
class ItemRef(MutableRef):
def __cinit__(self, _owner, _key, _manager):
# Cython automatically calls __cinit__ in the base classes
self._hash = hash((type(self).__name__, _owner, _key))

def _get_value(self):
owner = BaseRef._mk_value(self._owner)
item = BaseRef._mk_value(self._key)
Expand All @@ -583,6 +609,7 @@ def _set_value(self, value):
owner[item] = value

def __repr__(self):
assert self._owner is not None
return f"{self._owner}[{repr(self._key)}]"


Expand All @@ -603,7 +630,7 @@ class BinOpExpr(BaseRef):
_lhs = cython.declare(object, visibility='public')
_rhs = cython.declare(object, visibility='public')

def __init__(self, lhs, rhs):
def __cinit__(self, lhs, rhs):
self._lhs = lhs
self._rhs = rhs
self._hash = hash((self.__class__, lhs, rhs))
Expand All @@ -620,6 +647,10 @@ def _get_dependencies(self, out=None):
self._rhs._get_dependencies(out)
return out

def __reduce__(self):
"""Instruct pickle to not pickle the hash."""
return type(self), (self._lhs, self._rhs)

def __repr__(self):
return f"({self._lhs} {self._op_str} {self._rhs})"

Expand All @@ -640,7 +671,7 @@ class UnaryOpExpr(BaseRef):
"""
_arg = cython.declare(object, visibility='public')

def __init__(self, arg):
def __cinit__(self, arg):
self._arg = arg
self._hash = hash((self.__class__, self._arg))

Expand All @@ -655,6 +686,10 @@ def _get_dependencies(self, out=None):
# performance reasons we skip the check.
return self._arg._get_dependencies(out)

def __reduce__(self):
"""Instruct pickle to not pickle the hash."""
return type(self), (self._arg,)

def __repr__(self):
return f"({self._op_str}{self._arg})"

Expand Down Expand Up @@ -891,7 +926,7 @@ class BuiltinRef(BaseRef):
_op = cython.declare(object, visibility='public')
_params = cython.declare(tuple, visibility='public')

def __init__(self, arg, op, params=()):
def __cinit__(self, arg, op, params=()):
self._arg = arg
self._op = op
self._params = params
Expand All @@ -912,6 +947,10 @@ def _get_dependencies(self, out=None):
arg._get_dependencies(out)
return out

def __reduce__(self):
"""Instruct pickle to not pickle the hash."""
return type(self), (self._op, self._args)

def __repr__(self):
op_symbol = OPERATOR_SYMBOLS.get(self._op, self._op.__name__)
return f"{op_symbol}({self._arg})"
Expand All @@ -923,10 +962,13 @@ class CallRef(BaseRef):
_args = cython.declare(tuple, visibility='public')
_kwargs = cython.declare(tuple, visibility='public')

def __init__(self, func, args, kwargs):
def __cinit__(self, func, args, kwargs):
self._func = func
self._args = args
self._kwargs = tuple(kwargs.items())
if isinstance(kwargs, dict):
self._kwargs = tuple(kwargs.items())
else:
self._kwargs = tuple(kwargs)
self._hash = hash((self._func, self._args, self._kwargs))

def _get_value(self):
Expand All @@ -948,17 +990,20 @@ def _get_dependencies(self, out=None):
arg._get_dependencies(out)
return out

def __reduce__(self):
"""Instruct pickle to not pickle the hash."""
return type(self), (self._func, self._args, self._kwargs)

def __repr__(self):
args = []
for aa in self._args:
args.append(repr(aa))
for k, v in self._kwargs:
args.append(f"{k}={v}")
args = [repr(arg) for arg in self._args]
args += [f"{k}={v}" for k, v in self._kwargs]
args = ", ".join(args)

if isinstance(self._func, BaseRef):
fname = repr(self._func)
else:
fname = self._func.__name__

return f"{fname}({args})"


Expand Down

0 comments on commit 9ff27fc

Please sign in to comment.