From a09395a8bb1c4a14173f19c34393f3bad2ccd2ab Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Mon, 27 Aug 2018 13:38:15 +0200 Subject: [PATCH] Refactor+type entities --- kn/base/conf.py | 5 +- kn/fotos/entities.py | 2 + kn/leden/api.py | 2 +- kn/leden/date.py | 6 +- kn/leden/entities.py | 477 +++++++++++++++++++---------- kn/leden/mongo.py | 35 ++- kn/leden/views.py | 12 +- kn/utils/giedo/db.py | 2 +- utils/leden/quarterlyMembership.py | 3 +- 9 files changed, 348 insertions(+), 196 deletions(-) diff --git a/kn/base/conf.py b/kn/base/conf.py index fef6e8971..8b6fc27e2 100644 --- a/kn/base/conf.py +++ b/kn/base/conf.py @@ -1,6 +1,7 @@ from django.conf import settings +from datetime import datetime -DT_MIN = settings.DT_MIN -DT_MAX = settings.DT_MAX +DT_MIN: datetime = settings.DT_MIN +DT_MAX: datetime = settings.DT_MAX # vim: et:sta:bs=2:sw=4: diff --git a/kn/fotos/entities.py b/kn/fotos/entities.py index b6232a6b4..6edeb2714 100644 --- a/kn/fotos/entities.py +++ b/kn/fotos/entities.py @@ -6,6 +6,7 @@ import random import re import subprocess +import typing from collections import namedtuple import PIL.Image @@ -102,6 +103,7 @@ def actual_visibility(visibility): class FotoEntity(SONWrapper): + CACHES:typing.Dict[str, cache_tuple] CACHES = {} def __init__(self, data): diff --git a/kn/leden/api.py b/kn/leden/api.py index 004186cff..8b3d6a495 100644 --- a/kn/leden/api.py +++ b/kn/leden/api.py @@ -74,7 +74,7 @@ def delete_note(data, request): ( << {ok: false, error: "Note not found"} ) """ if 'secretariaat' not in request.user.cached_groups_names: return {'ok': False, 'error': 'Permission denied'} - note = Es.note_by_id(_id(data.get('id'))) + note = Es.Note.by_id(_id(data.get('id'))) if note is None: return {'ok': False, 'error': 'Note not found'} note.delete() diff --git a/kn/leden/date.py b/kn/leden/date.py index b47afcda2..43b09766a 100644 --- a/kn/leden/date.py +++ b/kn/leden/date.py @@ -1,15 +1,15 @@ import datetime -def now(): +def now() -> datetime.datetime: return datetime.datetime.now() -def date_to_dt(d): +def date_to_dt(d: datetime.date) -> datetime.datetime: return datetime.datetime.combine(d, datetime.time()) -def date_to_midnight(d): +def date_to_midnight(d: datetime.date) -> datetime.datetime: return datetime.datetime(d.year, d.month, d.day) # vim: et:sta:bs=2:sw=4: diff --git a/kn/leden/entities.py b/kn/leden/entities.py index 376520f2e..a7d89da91 100644 --- a/kn/leden/entities.py +++ b/kn/leden/entities.py @@ -4,8 +4,11 @@ import hashlib import os import re +import typing +from typing import Optional, Tuple, Set, List, Dict, Iterator, Sequence, Iterable, Any, TypeVar, Union, Callable import PIL.Image +import pymongo.cursor from django.conf import settings from django.contrib.auth.hashers import check_password, make_password @@ -27,6 +30,8 @@ # ###################################################################### ecol = db['entities'] # entities: users, group, tags, studies, ... +EntityID = typing.NewType('EntityID', str) +PermalinkType = Tuple[str, Any, Dict[str, str]] # Example of a user # ---------------------------------------------------------------------- # {"_id" : ObjectId("4e6fcc85e60edf3dc0000270"), @@ -141,18 +146,19 @@ # TODO add example -def get_hexdigest(algorithm, salt, raw_password): +def get_hexdigest(algorithm: str, salt: bytes, raw_password: bytes) -> str: """ Used to check old-style passwords. """ assert algorithm == 'sha1' return hashlib.sha1(salt + raw_password).hexdigest() -def ensure_indices(): +def ensure_indices() -> None: """ Ensures that the indices we need on the collections are set """ # entities # NOTE On some versions of Mongo, a unique sparse index does not allow # more than one entitity without a name. # ecol.ensure_index('names', unique=True, sparse=True) + # XXX: ensure_index is deprecated, use create_index ecol.ensure_index('names', sparse=True) ecol.ensure_index('types') ecol.ensure_index('tags', sparse=True) @@ -185,21 +191,50 @@ class EntityException(Exception): ''' pass - -def entity(d): +AnyEntityType = typing.Union["Group", "User", "Study", "Institute", "Tag", "Brand", "Entity"] +@typing.overload +def entity(d: None) -> None: ... +@typing.overload +def entity(d: Dict[str, typing.Any]) -> AnyEntityType: ... +def entity(d: Optional[Dict[str, typing.Any]]) -> Optional[AnyEntityType]: """ Given a dictionary, returns an Entity object wrapping it """ if d is None: return None - return TYPE_MAP[d['types'][0]](d) - - -def of_type(t): + t = d['types'][0] + if t in TYPE_MAP: + return TYPE_MAP[t](d) + raise TypeError("unknown type", t) + +I = TypeVar('I') +R = TypeVar('R') +class CursorMapper(typing.Generic[I, R]): + """ Wrap a mongo cursor to map the results to an entity type """ + def __init__(self, cursor: pymongo.cursor.Cursor, mapper: typing.Callable[[I], R]) -> None: + self._cursor = cursor + self._mapper = mapper + def __getattr__(self, attr: str) -> Any: + return getattr(self._cursor, attr) + def __getitem__(self, index: Union[int, slice]) -> Union["CursorMapper[I, R]", R]: + if type(index) is slice: + return CursorMapper(self._cursor.__getitem__(index), self._mapper) + else: + return self._mapper(self._cursor.__getitem__(index)) + def __next__(self) -> R: + return self._mapper(self._cursor.__next__()) + def sort(self, *arg: List[Any], **kwarg: Dict[str, Any]) -> "CursorMapper[I, R]": + self._cursor.sort(*arg, **kwarg) + return self + def clone(self) -> "CursorMapper[I, R]": + return CursorMapper(self._cursor.clone(), self._mapper) + def next(self) -> R: + return self._mapper(self._cursor.next()) + +def of_type(t: str) -> Iterator["Entity"]: # AnyEntityType """ Returns all entities of type @t """ - for m in ecol.find({'types': t}): - yield TYPE_MAP[t](m) + return CursorMapper(ecol.find({'types': t}), TYPE_MAP[t]) -def of_type_by_name(t): +def of_type_by_name(t: str) -> Dict[str, AnyEntityType]: """ Returns a `name -> entity' dictionary for the entities of tyoe @t """ ret = {} @@ -218,7 +253,7 @@ def of_type_by_name(t): brands = functools.partial(of_type, 'brand') -def by_ids(ns): +def by_ids(ns: Sequence[EntityID]) -> Dict[EntityID, AnyEntityType]: """ Find entities by a list of _ids """ ret = {} for m in ecol.find({'_id': {'$in': ns}}): @@ -226,10 +261,11 @@ def by_ids(ns): return ret +__id2name_cache: Dict[str, EntityID] __id2name_cache = {} -def id_by_name(n, use_cache=False): +def id_by_name(n: str, use_cache: bool=False) -> Optional[EntityID]: """ Find the _id of entity with name @n """ ret = None if use_cache: @@ -240,26 +276,26 @@ def id_by_name(n, use_cache=False): if obj is None: return None ret = obj['_id'] - if use_cache: + if use_cache and ret is not None: __id2name_cache[n] = ret return ret -def ids_by_names(ns=None, use_cache=False): +def ids_by_names(ns: Optional[List[str]]=None, use_cache:bool=False) -> Dict[str, EntityID]: """ Finds _ids of entities by a list of names """ ret = {} nss = None if ns is None else frozenset(ns) - if use_cache and ns is not None: + if use_cache and nss is not None: nss2 = set(nss) for n in nss: if n in __id2name_cache: ret[n] = __id2name_cache[n] nss2.remove(n) nss = frozenset(nss2) - for m in ecol.find({} if ns is None else + for m in ecol.find({} if nss is None else {'names': {'$in': tuple(nss)}}, {'names': 1}): for n in m.get('names', ()): - if ns is None or n in nss: + if nss is None or n in nss: ret[n] = m['_id'] if use_cache and ns is not None: __id2name_cache[n] = m['_id'] @@ -267,7 +303,7 @@ def ids_by_names(ns=None, use_cache=False): return ret -def by_names(ns): +def by_names(ns: List[str]) -> Dict[str, 'Entity']: """ Finds entities by a list of names """ ret = {} nss = frozenset(ns) @@ -279,31 +315,31 @@ def by_names(ns): return ret -def by_name(n): +def by_name(n: str) -> typing.Optional[AnyEntityType]: """ Finds an entity by name """ return entity(ecol.find_one({'names': n})) -def by_id(n): +def by_id(n: EntityID) -> typing.Optional[AnyEntityType]: """ Finds an entity by id """ if n is None: return None return entity(ecol.find_one({'_id': _id(n)})) -def by_study(study): +def by_study(study: "Study") -> Iterator[AnyEntityType]: """ Finds entities by studies.study """ for m in ecol.find({'studies.study': _id(study)}): yield entity(m) -def by_institute(institute): +def by_institute(institute: "Institute") -> Iterator["Entity"]: """ Finds entities by studies.insitute """ for m in ecol.find({'studies.institute': _id(institute)}): yield entity(m) -def get_years_of_birth(): +def get_years_of_birth() -> range: """ Returns the years of birth. NOTE Currently, simply queries for the minimum and maximum date of @@ -319,7 +355,7 @@ def get_years_of_birth(): return range(start, end + 1) -def by_year_of_birth(year): +def by_year_of_birth(year: int) -> Iterator["Entity"]: """ Finds entities by year of birth """ for m in ecol.find({'types': 'user', 'person.dateOfBirth': { @@ -328,13 +364,13 @@ def by_year_of_birth(year): yield entity(m) -def all(): +def all() -> Iterator[AnyEntityType]: """ Finds all entities """ for m in ecol.find(): yield entity(m) -def names_by_ids(ids=None): +def names_by_ids(ids: Optional[EntityID]=None) -> Dict[EntityID, Optional[str]]: """ Returns an `_id => primary name' dictionary for entities with _id in @ids or all if @ids is None """ ret = {} @@ -347,17 +383,17 @@ def names_by_ids(ids=None): return ret -def ids(): +def ids() -> Set[EntityID]: """ Returns a set of all ids """ - ret = set() + ret: Set[EntityID] = set() for e in ecol.find({}, {'_id': True}): ret.add(e['_id']) return ret -def names(): +def names() -> Set[str]: """ Returns a set of all names """ - ret = set() + ret: Set[str] = set() for e in ecol.find({}, {'names': True}): ret.update(e.get('names', ())) return ret @@ -366,7 +402,7 @@ def names(): # ###################################################################### -def by_keyword(keyword, limit=20, _type=None): +def by_keyword(keyword: str, limit: int=20, _type: Optional[str]=None) -> List["AnyEntityType"]: """ Searches for entities by a keyword. """ # TODO The current method does not use indices. It will search # through every single entity. At the moment, it is fast enough. @@ -379,7 +415,7 @@ def by_keyword(keyword, limit=20, _type=None): # TODO We might want to create an index, for when searching on type too regex = '.*%s.*' % '.*'.join([ re.escape(bit) for bit in keyword.split(' ') if bit]) - query_dict = {'humanNames.human': { + query_dict: Dict[str, typing.Any] = {'humanNames.human': { '$regex': regex, '$options': 'i'}} if _type: query_dict['types'] = _type @@ -390,19 +426,19 @@ def by_keyword(keyword, limit=20, _type=None): # Specialized functions to work with entities. # ###################################################################### - -def bearers_by_tag_id(tag_id, _as=entity): +T = typing.TypeVar('T') +def bearers_by_tag_id(tag_id: EntityID, _as: typing.Callable[..., T]=entity) -> List[T]: """ Find the bearers of the tag with @tag_id """ return [_as(x) for x in ecol.find({'tags': tag_id})] -def year_to_range(year): +def year_to_range(year: int) -> Tuple[datetime.datetime, datetime.datetime]: """ Returns (start_date, end_date) for the given year """ return (datetime.datetime(2003 + year, 9, 1), datetime.datetime(2004 + year, 8, 31)) -def date_to_year(dt): +def date_to_year(dt: datetime.datetime) -> int: """ Returns the `verenigingsjaar' at the date """ year = dt.year - 2004 if dt.month >= 9: @@ -411,8 +447,7 @@ def date_to_year(dt): year = 1 return year - -def quarter_to_range(quarter): +def quarter_to_range(quarter: int) -> Tuple[datetime.datetime, datetime.datetime]: """ Translates a quarter to a start and end datetime. The quarters of the first year are [1,2,3,4]. """ startCMonths = (quarter - 1) * 3 + 8 @@ -427,26 +462,72 @@ def quarter_to_range(quarter): # Functions to work with relations # ###################################################################### +Relation = typing.NewType("Relation", Dict[str, typing.Any]) +class Relation_(object): + from_: Optional[datetime.datetime] + until: Optional[datetime.datetime] + with_: Optional[Union["Entity", EntityID]] + how: Optional[Union["Entity", EntityID]] + who: Union["Entity", EntityID] + + def is_active_at(dt: datetime.datetime) -> bool: + """ Returns whether @rel is active at @dt """ + return ((self.until is None or self.until >= dt) + and (self.from_ is None or self.from_ <= dt)) + + is_virtual: bool + @property + def is_virtual(self) -> bool: + """ Returns whether @rel is "virtual". + + Requires rel['with'] to be deref'd """ + return bool(self.with_.is_group and self.with_.as_group().is_virtual) + + is_active: bool + @property + def is_active(self) -> bool: + return self.is_active_at(now()) + + def may_end(user: "User") -> bool: + """ Returns whether @user may end @rel """ + if self.is_virtual: + return False + if not self.is_active: + return False + if 'secretariaat' in user.cached_groups_names: + return True + if _id(self.who) == _id(user) and \ + self.how is None and \ + by_id(self.with_).has_tag_name('!free-to-join'): + return True + return False -def relation_is_active_at(rel, dt): + def __init__(self, dct: Dict[str, Any]): + self.from_ = dct['from'] + self.until = dct['until'] + self.with_ = dct['with'] + self.how = dct['how'] + self.who = dct['who'] + +def relation_is_active_at(rel: Relation, dt: datetime.datetime) -> bool: """ Returns whether @rel is active at @dt """ return ((rel['until'] is None or rel['until'] >= dt) and (rel['from'] is None or rel['from'] <= dt)) -def relation_is_active(rel): +def relation_is_active(rel: Relation) -> bool: """ Returns whether @rel is active now """ return relation_is_active_at(rel, now()) -def relation_is_virtual(rel): +def relation_is_virtual(rel: Relation) -> bool: """ Returns whether @rel is "virtual". Requires rel['with'] to be deref'd """ - return rel['with'].is_group and rel['with'].as_group().is_virtual + return bool(rel['with'].is_group and rel['with'].as_group().is_virtual) -def user_may_end_relation(user, rel): +def user_may_end_relation(user: "User", rel: Relation) -> bool: """ Returns whether @user may end @rel """ if relation_is_virtual(rel): return False @@ -456,17 +537,17 @@ def user_may_end_relation(user, rel): return True if _id(rel['who']) == _id(user) and \ rel['how'] is None and \ - by_id(rel['with']).has_tag(id_by_name('!free-to-join', True)): + by_id(rel['with']).has_tag_name('!free-to-join'): return True return False -def end_relation(__id): +def end_relation(__id: EntityID) -> None: dt = now() rcol.update({'_id': _id(__id)}, {'$set': {'until': dt}}) -def user_may_begin_relation(user, who, _with, how): +def user_may_begin_relation(user: "EntityID", who: "EntityID", _with: "EntityID", how: Optional[str]) -> bool: """ Returns whether @user may begin a @how-relation between @who and @_with """ _with_e = by_id(_with) @@ -475,34 +556,33 @@ def user_may_begin_relation(user, who, _with, how): return False if 'secretariaat' in user_e.cached_groups_names: return True - if _with_e.has_tag(id_by_name('!free-to-join', True)): + if _with_e.has_tag_name('!free-to-join'): if _id(user) == _id(who) and how is None: return True return False -def add_relation(who, _with, how=None, _from=None, until=None): +def add_relation(who: "Entity", _with: "Entity", how: Optional[str]=None, _from: Optional[datetime.datetime]=None, until: Optional[datetime.datetime]=None) -> Relation: if _from is None: _from = DT_MIN if until is None: until = DT_MAX - return rcol.insert({'who': _id(who), + return Relation(rcol.insert({'who': _id(who), 'with': _id(_with), 'how': None if how is None else _id(how), 'from': _from, - 'until': until}) + 'until': until})) -def user_may_tag(user, group, tag): - return 'secretariaat' in user.cached_groups_names +def user_may_tag(user: "Entity", group: "Entity", tag: "Entity") -> bool: + return bool('secretariaat' in user.cached_groups_names) -def user_may_untag(user, group, tag): - return 'secretariaat' in user.cached_groups_names +def user_may_untag(user: "Entity", group: "Entity", tag: "Entity") -> bool: + return bool('secretariaat' in user.cached_groups_names) -def disj_query_relations(queries, deref_who=False, deref_with=False, - deref_how=False): +def disj_query_relations(queries:Iterable[Any], deref: set) -> Iterable[Relation]: """ Find relations matching any one of @queries. See @query_relations. """ if not queries: @@ -559,13 +639,15 @@ def disj_query_relations(queries, deref_who=False, deref_with=False, # NOTE If bits is a one-element-array the `$or' query does not return # anything, even if it should. Bug in MongoDB? cursor = rcol.find({'$or': bits} if len(bits) != 1 else bits[0]) - if not deref_how and not deref_who and not deref_with: + if deref == set(): return cursor - return __derefence_relations(cursor, deref_who, deref_with, deref_how) + return __derefence_relations(cursor, deref) -def query_relations(who=-1, _with=-1, how=-1, _from=None, until=None, - deref_who=False, deref_with=False, deref_how=False): +RelQuery = typing.Union[int, Sequence[EntityID], EntityID, 'Entity'] +def query_relations(who: RelQuery=-1, _with: RelQuery=-1, how: RelQuery=-1, + _from: Optional[datetime.datetime]=None, until: Optional[datetime.datetime]=None, + deref: set = set()) -> Iterable[Relation]: """ Find matching relations. For each of {who, _with, how}: @@ -576,7 +658,7 @@ def query_relations(who=-1, _with=-1, how=-1, _from=None, until=None, and form an interval. Only relations intersecting this interval are matched. """ - query = {} + query: Dict[str, Any] = {} if who != -1: query['who'] = who if _with != -1: @@ -587,109 +669,93 @@ def query_relations(who=-1, _with=-1, how=-1, _from=None, until=None, query['from'] = _from if until is not None: query['until'] = until - return disj_query_relations([query], deref_who, deref_with, deref_how) + return disj_query_relations([query], deref) -def __derefence_relations(cursor, deref_who, deref_with, deref_how): +def __derefence_relations(cursor: Iterable[Relation], deref: set) -> Iterator[Relation]: # Dereference. First collect the ids of the entities we want to # dereference - e_lut = dict() + e_lut: Dict[EntityID, AnyEntityType] = dict() ids = set() ret = list() for rel in cursor: ret.append(rel) - if deref_with: - ids.add(rel['with']) - if deref_how and rel['how']: - ids.add(rel['how']) - if deref_who: - ids.add(rel['who']) + for what in deref: + if rel[what]: + ids.add(rel[what]) e_lut = by_ids(tuple(ids)) # Dereference! for rel in ret: - if deref_who: - rel['who'] = e_lut[rel['who']] - if deref_how and rel['how']: - rel['how'] = e_lut[rel['how']] - if deref_with: - rel['with'] = e_lut[rel['with']] + for what in deref: + if rel[what]: + rel[what] = e_lut[rel[what]] if rel['from'] == DT_MIN: rel['from'] = None if rel['until'] == DT_MAX: rel['until'] = None - yield rel + yield Relation(rel) -def relation_by_id(__id, deref_who=True, deref_with=True, deref_how=True): +def relation_by_id(__id, deref=set(("who", "how", "what"))): cursor = rcol.find({'_id': _id(__id)}) try: - if not deref_how and not deref_who and not deref_with: + if not deref: return next(cursor) - return next(__derefence_relations(cursor, deref_who, - deref_with, deref_how)) + return next(__derefence_relations(cursor, deref)) except StopIteration: return None - -def entity_humanName(x): +@typing.overload +def entity_humanName(x: None) -> None: ... +@typing.overload +def entity_humanName(x: "Entity") -> str: ... +def entity_humanName(x: Optional["Entity"]) -> Optional[str]: """ Returns the human name of an entity or None if None is given. Useful for the key argument in a sort function. """ return None if x is None else six.text_type(x.humanName) -def dt_until(x): +def dt_until(x: Optional[datetime.datetime]) -> datetime.datetime: """ Treat a datetime from the db as one which is used as an until-date: None is interpreted as DT_MAX. Useful for sorting. """ return x if x else DT_MAX -def dt_from(x): +def dt_from(x: Optional[datetime.datetime]) -> datetime.datetime: """ Treat a datetime from the db as one which is used as an from-date: None is interpreted as DT_MIN. Useful for sorting. """ return x if x else DT_MIN -def relation_until(x): +def relation_until(x: Relation) -> datetime.datetime: """ Returns the datetime until the given relations holds. Useful as `key' argument for sort functions. """ return dt_until(x['until']) -def relation_from(x): +def relation_from(x: Relation) -> datetime.datetime: """ Returns the datetime from which the given relations holds. Useful as `key' argument for sort functions. """ return dt_from(x['from']) -def remove_relation(who, _with, how, _from, until): +def remove_relation(who, _with, how, _from, until) -> None: if _from is None: _from = DT_MIN if until is None: until = DT_MAX + # XXX: .remove is deprecated rcol.remove({'who': _id(who), 'with': _id(_with), 'how': None if how is None else _id(how), 'from': _from, 'until': until}) -# Functions to work with notes -# ###################################################################### - - -def note_by_id(the_id): - tmp = ncol.find_one({'_id': the_id}) - return None if tmp is None else Note(tmp) - - -def get_notes(): - for d in ncol.find({}, sort=[('at', 1)]): - yield Note(d) - # Functions to work with informacie-notifications # ###################################################################### -def notify_informacie(event, user, **props): +def notify_informacie(event: str, user: "User", **props) -> None: data = {'when': now(), 'event': event} data['user'] = _id(user) for key, value in props.items(): @@ -697,7 +763,7 @@ def notify_informacie(event, user, **props): incol.insert(data) -def pop_all_informacie_notifications(): +def pop_all_informacie_notifications() -> typing.List['InformacieNotification']: ntfs = list(incol.find({}, sort=[('when', 1)])) incol.remove({'_id': {'$in': [m['_id'] for m in ntfs]}}) return [InformacieNotification(d) for d in ntfs] @@ -705,21 +771,22 @@ def pop_all_informacie_notifications(): # Models # ###################################################################### - class EntityName(object): """ Wrapper object for a name of an entity """ - def __init__(self, entity, name): + def __init__(self, entity: "Entity", name: str) -> None: self._entity = entity self._name = name + humanNames: Iterator['EntityHumanName'] @property def humanNames(self): - for n in self.entity._data.get('humanNames', ()): - if n['name'] == self.name: + for n in self._entity._data.get('humanNames', ()): + if n['name'] == self._name: yield EntityHumanName(self._entity, n) + primary_humanName: typing.Optional['EntityHumanName'] @property def primary_humanName(self): try: @@ -727,7 +794,12 @@ def primary_humanName(self): except StopIteration: return None - def __str__(self): + name: str + @property + def name(self) -> str: + return self._name + + def __str__(self) -> str: return self._name def __repr__(self): @@ -739,26 +811,31 @@ class EntityHumanName(object): """ Wrapper object for a humanName of an entity """ - def __init__(self, entity, data): + def __init__(self, entity: "Entity", data: Dict[str, str]) -> None: self._entity = entity self._data = data + name: EntityName @property def name(self): return EntityName(self._entity, self._data.get('name')) + humanName: str @property def humanName(self): return self._data['human'] + genitive_prefix: str @property def genitive_prefix(self): return self._data.get('genitive_prefix', 'van de') + genitive: str @property def genitive(self): return self.genitive_prefix + ' ' + six.text_type(self) + definite_article: str @property def definite_article(self): return {'van de': 'de', @@ -778,10 +855,23 @@ class Entity(SONWrapper): """ Base object for every Entity """ - def __init__(self, data): + @classmethod + def all(cls): + if getattr(cls, 'db_typename'): + return of_type(cls.db_typename) + return all() + + @classmethod + def by_name(cls, name): + if getattr(cls, 'db_typename'): + raise NotImplementedError + # return of_type_by_name(cls.db_typename, name) + return by_name(name) + + def __init__(self, data: Dict[str, typing.Any]) -> None: super(Entity, self).__init__(data, ecol) - def is_related_with(self, whom, how=None): + def is_related_with(self, whom: 'Entity', how: Optional[str]=None) -> bool: dt = now() how = None if how is None else _id(how) return rcol.find_one( @@ -791,6 +881,7 @@ def is_related_with(self, whom, how=None): 'until': {'$gte': dt}, 'with': _id(whom)}, {'_id': True}) is not None + cached_groups: typing.List['Group'] @property def cached_groups(self): """ The list of entities this user is None-related with. @@ -800,30 +891,27 @@ def cached_groups(self): dt = now() self._groups_cache = [rel['with'] for rel in self.get_related( - None, dt, dt, False, True, False)] + None, dt, dt, deref=set(("with",)))] return self._groups_cache + cached_groups_names: typing.Set[str] @property def cached_groups_names(self): if not hasattr(self, '_groups_names_cache'): - self._groups_names_cache = set() + self._groups_names_cache: Set[str] = set() for g in self.cached_groups: self._groups_names_cache.update([ str(n) for n in g.names]) return self._groups_names_cache # get reverse-related - def get_rrelated(self, how=-1, _from=None, until=None, deref_who=True, - deref_with=True, deref_how=True): - return query_relations(-1, self, how, _from, until, deref_who, - deref_with, deref_how) + def get_rrelated(self, how=-1, _from: Optional[datetime.datetime]=None, until=None, deref: set = set(("who", "how", "with"))) -> Iterable[Relation]: + return query_relations(-1, self, how, _from, until, deref) - def get_related(self, how=-1, _from=None, until=None, deref_who=True, - deref_with=True, deref_how=True): - return query_relations(self, -1, how, _from, until, deref_who, - deref_with, deref_how) + def get_related(self, how=-1, _from=None, until=None, deref: set = set(("who", "how", "with"))) -> Iterable[Relation]: + return query_relations(self, -1, how, _from, until, deref) - def get_tags(self): + def get_tags(self) -> Iterator['Tag']: for m in ecol.find({'_id': {'$in': self._data.get('tags', ())}} ).sort('humanNames.human', 1): yield Tag(m) @@ -832,22 +920,28 @@ def get_tags(self): def type(self): return self._data['types'][0] + id: EntityID @property def id(self): - return str(self._id) + return EntityID(str(self._id)) + tag_ids: Iterator[EntityID] @property def tag_ids(self): return self._data.get('tags', ()) + tags: Iterator['Tag'] @property - def tags(self): - for m in ecol.find({'_id': { - '$in': self._data.get('tags', ())}}): - yield Tag(m) + def tags(self) -> Iterable['Tag']: + return CursorMapper(ecol.find({'_id': { + '$in': self._data.get('tags', ())}}), Tag) - def has_tag(self, tag): - return _id(tag) in self._data.get('tags', ()) + def has_tag(self, tag: typing.Union["Tag", EntityID]) -> bool: + return bool(_id(tag) in self._data.get('tags', ())) + + def has_tag_name(self, tag: str) -> bool: + t = id_by_name(tag, True) + return self.has_tag(t) if t else False def tag(self, tag, save=True): if self.has_tag(tag): @@ -865,31 +959,37 @@ def untag(self, tag, save=True): if save: self.save() + names: Iterator[EntityName] @property def names(self): for n in self._data.get('names', ()): yield EntityName(self, n) + name: typing.Optional[EntityName] @property def name(self): nms = self._data.get('names', ()) nm = nms[0] if len(nms) >= 1 else None return nm if nm is None else EntityName(self, nm) + description: typing.Optional[str] @property def description(self): return self._data.get('description', None) + other_names: Iterator[EntityName] @property def other_names(self): for n in self._data.get('names', ())[1:]: yield EntityName(self, n) + humanNames: Iterator[EntityHumanName] @property def humanNames(self): for n in self._data.get('humanNames', ()): yield EntityHumanName(self, n) + humanName: typing.Optional[EntityHumanName] @property def humanName(self): try: @@ -898,7 +998,7 @@ def humanName(self): return None @permalink - def get_absolute_url(self): + def get_absolute_url(self) -> PermalinkType: if self.name: return ('entity-by-name', (), {'name': self.name}) @@ -930,19 +1030,19 @@ def is_study(self): return 'study' in self._data['types'] @property def is_institute(self): return 'institute' in self._data['types'] - def as_user(self): return User(self._data) + def as_user(self) -> "User": return User(self._data) - def as_group(self): return Group(self._data) + def as_group(self) -> "Group": return Group(self._data) - def as_brand(self): return Brand(self._data) + def as_brand(self) -> "Brand": return Brand(self._data) - def as_tag(self): return Tag(self._data) + def as_tag(self) -> "Tag": return Tag(self._data) - def as_study(self): return Study(self._data) + def as_study(self) -> "Study": return Study(self._data) - def as_institute(self): return Institute(self._data) + def as_institute(self) -> "Institute": return Institute(self._data) - def as_primary_type(self): + def as_primary_type(self) -> AnyEntityType: return TYPE_MAP[self.type](self._data) def update_address(self, street, number, _zip, city, save=True): @@ -1044,6 +1144,7 @@ def photo_size(self): return width, height + canonical_full_email: typing.Optional[str] @property def canonical_full_email(self): """ Returns the string @@ -1055,6 +1156,7 @@ def canonical_full_email(self): return None return email.utils.formataddr((six.text_type(self.humanName), addr)) + canonical_email: typing.Optional[str] @property def canonical_email(self): if self.type in ('institute', 'study', 'brand', 'tag'): @@ -1062,6 +1164,7 @@ def canonical_email(self): name = str(self.name if self.name else self.id) return "%s@%s" % (name, settings.MAILDOMAIN) + got_mailman_list: bool @property def got_mailman_list(self): if 'use_mailman_list' in self._data: @@ -1072,6 +1175,7 @@ def got_mailman_list(self): return False return True + got_unix_group: bool @property def got_unix_group(self): if 'has_unix_group' in self._data: @@ -1079,7 +1183,7 @@ def got_unix_group(self): else: return True - def add_note(self, what, by=None): + def add_note(self, what: str, by: Optional['Entity'] = None) -> 'Note': dt = now() note = Note({'note': what, 'on': self._id, @@ -1088,7 +1192,7 @@ def add_note(self, what, by=None): note.save() return note - def get_notes(self): + def get_notes(self) -> Iterator['Note']: for d in ncol.find({'on': self._id}, sort=[('at', 1)]): yield Note(d) @@ -1107,6 +1211,7 @@ def __hash__(self): class Group(Entity): + db_typename: typing.ClassVar[str] = "group" @permalink def get_absolute_url(self): @@ -1115,27 +1220,30 @@ def get_absolute_url(self): {'name': self.name}) return ('group-by-id', (), {'_id': self.id}) - def get_current_and_old_members(self): + def get_current_and_old_members(self) -> \ + typing.Tuple[typing.Set[Entity], typing.Set[Entity]]: dt = now() cur, _all = set(), set() - for rel in self.get_rrelated(how=None, deref_with=False): + for rel in self.get_rrelated(how=None, deref=set(("how", "who"))): _all.add(rel['who']) if ((rel['until'] is None or rel['until'] >= dt) and (rel['from'] is None or rel['from'] <= dt)): cur.add(rel['who']) return (cur, _all - cur) - def get_members(self): + def get_members(self) -> typing.List[Entity]: dt = now() return [r['who'] for r in self.get_rrelated( - how=None, _from=dt, until=dt)] + how=None, _from=dt, until=dt, deref=set(("who",)))] + is_virtual: bool @property def is_virtual(self): return 'virtual' in self._data class User(Entity): + db_typename: typing.ClassVar[str] = "user" class _Meta(object): """ Django expects a user object to have a _meta instance. @@ -1161,7 +1269,7 @@ def __init__(self, user): email = son_property(('email',)) pk = son_property(('_id'),) # primary key for Django - def __init__(self, data): + def __init__(self, data: Dict[str, typing.Any]) -> None: super(User, self).__init__(data) self._primary_study = -1 self._meta = User._Meta(self) @@ -1213,20 +1321,23 @@ def check_password(self, pwd): @property def humanName(self): + # XXX: this is a different type return self.full_name def set_humanName(self): raise NotImplemented('setting humanName for users is not implemented') + password: Optional[bytes] @property - def password(self): + def password(self) -> Optional[bytes]: return self._data.get('password', None) + is_active: bool @property - def is_active(self): + def is_active(self) -> bool: return self._data.get('is_active', True) - def is_authenticated(self): + def is_authenticated(self) -> bool: # required by django's auth return True # Required by Django's auth. framework @@ -1240,8 +1351,9 @@ def may_upload_smoel_for(self, user): 'secretariaat' in self.cached_groups_names or \ 'bestuur' in self.cached_groups_names + full_name: str @property - def full_name(self): + def full_name(self) -> str: if ('person' not in self._data or 'family' not in self._data['person'] or 'nick' not in self._data['person']): @@ -1252,16 +1364,19 @@ def full_name(self): + self._data['person']['family'] return self._data['person']['nick'] + bits[1] + ' ' + bits[0] + first_name: str @property - def first_name(self): + def first_name(self) -> str: return self._data.get('person', {}).get('nick') + last_name: str @property - def last_name(self): + def last_name(self) -> str: return self._data.get('person', {}).get('family') + preferred_language: str @property - def preferred_language(self): + def preferred_language(self) -> str: return self._data.get('preferred_language', settings.LANGUAGE_CODE) @property @@ -1289,6 +1404,7 @@ def studies(self): @property def primary_study(self): + # TODO: will crash if study does not exist if self._primary_study == -1: self._primary_study = ( None if not self._data.get('studies', ()) @@ -1343,6 +1459,7 @@ def studentNumber(self): study = self.proper_primary_study return study['number'] if self.proper_primary_study else None + dateOfBirth: datetime.datetime @property def dateOfBirth(self): return self._data.get('person', {}).get('dateOfBirth') @@ -1366,6 +1483,7 @@ def remove_dateOfBirth(self, save=True): if save: self.save() + age: int @property def age(self): # age is a little difficult to calculate because of leap years @@ -1377,12 +1495,14 @@ def age(self): return (today.year - date.year - ((today.month, today.day) < (date.month, date.day))) + is_underage: bool @property def is_underage(self): if self.age is not None: return self.age < 18 return self._data['is_underage'] + got_unix_user: bool @property def got_unix_user(self): if 'has_unix_user' in self._data: @@ -1407,6 +1527,7 @@ def set_locale_on_logon(sender, request, user, **kwargs): class Tag(Entity): + db_typename: typing.ClassVar[str] = "tag" @permalink def get_absolute_url(self): @@ -1415,12 +1536,13 @@ def get_absolute_url(self): {'name': self.name}) return ('tag-by-id', (), {'_id': self.id}) - def get_bearers(self): + def get_bearers(self) -> typing.List[Entity]: return [entity(m) for m in ecol.find({ 'tags': self._id})] class Study(Entity): + db_typename: typing.ClassVar[str] = "study" @permalink def get_absolute_url(self): @@ -1431,6 +1553,7 @@ def get_absolute_url(self): class Institute(Entity): + db_typename: typing.ClassVar[str] = "institute" @permalink def get_absolute_url(self): @@ -1441,60 +1564,80 @@ def get_absolute_url(self): class Brand(Entity): + db_typename: typing.ClassVar[str] = "brand" @permalink - def get_absolute_url(self): + def get_absolute_url(self) -> Tuple[str, Any, Dict[str, EntityID]]: if self.name: return ('brand-by-name', (), {'name': self.name}) return ('brand-by-id', (), {'_id': self.id}) @property - def sofa_suffix(self): + def sofa_suffix(self) -> str: return self._data.get('sofa_suffix', None) class Note(SONWrapper): + """ Notes set on an entity as a todo for the secretary or webcie """ + @classmethod + def by_id(cls, the_id: EntityID) -> Optional["Note"]: + tmp = ncol.find_one({'_id': the_id}) + return None if tmp is None else Note(tmp) + @classmethod + def all(cls) -> Iterable["Note"]: + return CursorMapper(ncol.find({}, sort=[('at', 1)]), Note) + + at: datetime.datetime at = son_property(('at',)) + note: str note = son_property(('note',)) + by_id: EntityID by_id = son_property(('by',)) + on_id: EntityID on_id = son_property(('on',)) - def __init__(self, data): + def __init__(self, data: Dict[str, typing.Any]) -> None: + # todo: typecheck data super(Note, self).__init__(data, ncol) + id: EntityID @property - def id(self): + def id(self) -> str: return str(_id(self)) + on: Entity @property - def on(self): + def on(self) -> Optional[Entity]: return by_id(self._data['on']) + by: Optional[Entity] @property - def by(self): + def by(self) -> Optional[Entity]: return by_id(self._data['by']) + messageId: str @property - def messageId(self): + def messageId(self) -> str: return '' % (self.id, settings.MAILDOMAIN) class InformacieNotification(SONWrapper): - def __init__(self, data): + def __init__(self, data: Dict[str, typing.Any]) -> None: super(InformacieNotification, self).__init__(data, incol) - def user(self): - return by_id(self._data['user']) + def user(self) -> User: + # XXX: typechecker is right to complain + return typing.cast(User, by_id(self._data['user'])) def relation(self): return relation_by_id(self._data['relation']) - def tag(self): - return by_id(self._data['tag']) + def tag(self) -> typing.Optional[Tag]: + return typing.cast(typing.Optional[Tag], by_id(self._data['tag'])) - def entity(self): + def entity(self) -> typing.Optional[Entity]: return by_id(self._data['entity']) def fotoEvent(self): @@ -1506,11 +1649,13 @@ def fotoAlbum(self): return fEs.by_id(self._data['fotoAlbum']) event = son_property(('event', )) + when: datetime.datetime when = son_property(('when', )) # List of type of entities # ###################################################################### +TYPE_MAP: Dict[str, typing.Callable[..., Entity]] TYPE_MAP = { 'group': Group, 'user': User, diff --git a/kn/leden/mongo.py b/kn/leden/mongo.py index e6d2f8f75..5ef1000d6 100644 --- a/kn/leden/mongo.py +++ b/kn/leden/mongo.py @@ -2,6 +2,8 @@ from django.conf import settings from django.utils import six +import typing +from typing import Optional, TypeVar, Any, Union, Mapping try: from pymongo.objectid import ObjectId @@ -17,19 +19,21 @@ class RaceCondition(Exception): pass -def _id(obj): +def _id(obj: typing.Union[ObjectId, str, typing.Any]) -> ObjectId: if isinstance(obj, ObjectId): return obj - if isinstance(obj, six.string_types): + if isinstance(obj, str): return ObjectId(obj) - if hasattr(obj, '_id'): + elif isinstance(obj, six.string_types): + return ObjectId(obj) + elif hasattr(obj, '_id'): return obj._id raise ValueError("Don't know how to turn {!r} into an _id".format(obj)) class SONWrapper(object): - def __init__(self, data, collection, parent=None, detect_race=False): + def __init__(self, data: Mapping[str, Any], collection: Any, parent: Optional["SONWrapper"]=None, detect_race: bool=False) -> None: ''' parent: SONWrapper can be nested. This is the parent. detect_race: Add a _version field which is incremented with each @@ -40,15 +44,16 @@ def __init__(self, data, collection, parent=None, detect_race=False): self._parent = parent self._detect_race = detect_race - def delete(self): + def delete(self) -> None: assert self._data['_id'] is not None # TODO check version + # XXX: remove is deprecated self._collection.remove({ '_id': self._data['_id']}) # We take the keyword argument update_fields to be compatible with # Django's Model.save. However, we do not use it, yet. - def save(self, update_fields=NotImplemented): + def save(self, update_fields: None=NotImplemented) -> None: if self._parent is None: if '_id' in self._data: if self._detect_race: @@ -73,33 +78,33 @@ def save(self, update_fields=NotImplemented): self._parent.save() @property - def _id(self): + def _id(self) -> ObjectId: if self._parent is None: return self._data['_id'] return self._parent._id @property - def _version(self): + def _version(self) -> int: if self._parent is None: - return self._data['_version'] - return self._parent._version + return int(self._data['_version']) + return int(self._parent._version) - def __repr__(self): + def __repr__(self) -> str: return "" % self._id - -def son_property(path, default=None): +T = typing.TypeVar('T') +def son_property(path: typing.Sequence[str], default: Optional[T]=None) -> property: """ A convenience shortcut to create properties on SONWrapper subclasses. Will return a getter/setter property that gets/sets self._data[path[0]]...[path[-1]] verbatim. """ - def __getter(self): + def __getter(self: SONWrapper) -> Optional[T]: obj = self._data for bit in path[:-1]: obj = obj.get(bit, {}) return obj.get(path[-1], default) - def __setter(self, x): + def __setter(self: SONWrapper, x: T) -> None: obj = self._data for bit in path[:-1]: if bit not in obj: diff --git a/kn/leden/views.py b/kn/leden/views.py index 6bba5b968..8bbdd11a0 100644 --- a/kn/leden/views.py +++ b/kn/leden/views.py @@ -38,14 +38,14 @@ @login_required def user_list(request, page): - pr = Paginator(Es.ecol.find({'types': 'user'}).sort( + pr = Paginator(Es.User.all().sort( 'humanNames.human', 1), 20) try: p = pr.page(1 if page is None else page) except EmptyPage: raise Http404 return render(request, 'leden/user_list.html', - {'users': [Es.User(m) for m in p.object_list], + {'users': p.object_list, 'page_obj': p, 'paginator': pr}) @@ -178,7 +178,7 @@ def _user_detail(request, user): def _group_detail(request, group): ctx = _entity_detail(request, group) - isFreeToJoin = group.has_tag(Es.id_by_name('!free-to-join', True)) + isFreeToJoin = group.has_tag_name('!free-to-join') rel_id = None if isFreeToJoin: dt = now() @@ -187,7 +187,7 @@ def _group_detail(request, group): assert len(rel) <= 1 for r in rel: rel_id = r['_id'] - ctx.update({'isFreeToJoin': group.has_tag(Es.by_name('!free-to-join')), + ctx.update({'isFreeToJoin': isFreeToJoin, 'request': request, 'relation_with_group': rel_id}) return render(request, 'leden/group_detail.html', ctx) @@ -269,7 +269,7 @@ def years_of_birth(request): @login_required def users_underage(request): - users = Es.users() + users = Es.User.all() users = filter(lambda u: u.is_active, users) users = filter(lambda u: u.is_underage, users) users = sorted(users, key=lambda x: x.dateOfBirth) @@ -467,7 +467,7 @@ def secr_notes(request): if 'secretariaat' not in request.user.cached_groups_names: raise PermissionDenied return render(request, 'leden/secr_notes.html', - {'notes': Es.get_notes()}) + {'notes': Es.Note.all()}) @login_required diff --git a/kn/utils/giedo/db.py b/kn/utils/giedo/db.py index 9c7a32761..74abda130 100644 --- a/kn/utils/giedo/db.py +++ b/kn/utils/giedo/db.py @@ -189,7 +189,7 @@ def relkey(rel): # Set is_active on Users if and only if they are not in the `leden' group. # TODO We might optimize this by including it in a more generic process active_users = [rel['who'] for rel in Es.by_name('leden').get_rrelated( - None, dt_now, dt_now, False, False, False)] + None, dt_now, dt_now, deref=set())] for u in Es.users(): is_active = u._id in active_users if u.is_active == is_active: diff --git a/utils/leden/quarterlyMembership.py b/utils/leden/quarterlyMembership.py index eae8dcee9..092a666a9 100644 --- a/utils/leden/quarterlyMembership.py +++ b/utils/leden/quarterlyMembership.py @@ -21,8 +21,7 @@ def main(): for q in range(1, max_q + 1): start, end = Es.quarter_to_range(q) for m in leden.get_rrelated(_from=start, until=end, how=None, - deref_who=False, deref_with=False, - deref_how=False): + deref=set()): lut[id2name[m['who']]].add(q) for i, name in enumerate(sorted(six.itervalues(id2name))): if i % 20 == 0: