From a97d9568de89a27e43305547f1036cc4a65d19e4 Mon Sep 17 00:00:00 2001 From: Nicolas Delaby Date: Wed, 16 Jan 2019 15:22:13 +0100 Subject: [PATCH] Add optimized ToOne sqlalchemy aware field. This version of fields.ToOne() is able to read the local foreign key value to serialize the reference of the remote object. It can be useful under some circumstances to avoid performing additional db lookups. --- flask_potion/contrib/alchemy/fields.py | 29 +++++++++++- flask_potion/natural_keys.py | 14 ++++-- tests/__init__.py | 53 +++++++++++++++++++++- tests/contrib/alchemy/test_fields.py | 62 ++++++++++++++++++++++++++ 4 files changed, 152 insertions(+), 6 deletions(-) create mode 100644 tests/contrib/alchemy/test_fields.py diff --git a/flask_potion/contrib/alchemy/fields.py b/flask_potion/contrib/alchemy/fields.py index 1e02955..d7de7db 100644 --- a/flask_potion/contrib/alchemy/fields.py +++ b/flask_potion/contrib/alchemy/fields.py @@ -1,4 +1,7 @@ -from flask_potion.fields import Object +from werkzeug.utils import cached_property + +from flask_potion.fields import Object, ToOne as GenericToOne +from flask_potion.utils import get_value, route_from class InlineModel(Object): @@ -15,3 +18,27 @@ def converter(self, instance): if instance is not None: instance = self.model(**instance) return instance + + +class ToOne(GenericToOne): + """ + Same as flask_potion.fields.ToOne + except it will use the local id stored on the ForeignKey field to serialize the field. + This is an optimisation to avoid additional lookups to the database, + in order to prevent fetching the remote object, just to obtain its id, + that we already have. + Limitations: + - It works only if the foreign key is made of a single field. + - It works only if the serialization is using the ForeignKey as source of information to Identify the remote resource. + - `attribute` parameter is ignored. + """ + def output(self, key, obj): + column = getattr(obj.__class__, key) + local_columns = column.property.local_columns + assert len(local_columns) == 1 + local_column = list(local_columns)[0] + key = local_column.key + return self.format(get_value(key, obj, self.default)) + + def formatter(self, item): + return self.formatter_key.format(item, is_local=True) diff --git a/flask_potion/natural_keys.py b/flask_potion/natural_keys.py index 1556b03..d7e26cc 100644 --- a/flask_potion/natural_keys.py +++ b/flask_potion/natural_keys.py @@ -11,6 +11,7 @@ class Key(Schema, ResourceBound): + is_local = False def matcher_type(self): type_ = self.response['type'] @@ -43,11 +44,16 @@ def schema(self): "additionalProperties": False } + def _id_uri(self, resource, id_): + return '{}/{}'.format(resource.route_prefix, id_) + def _item_uri(self, resource, item): # return url_for('{}.instance'.format(self.resource.meta.id_attribute, item, None), get_value(self.resource.meta.id_attribute, item, None)) return '{}/{}'.format(resource.route_prefix, get_value(resource.manager.id_attribute, item, None)) - def format(self, item): + def format(self, item, is_local=False): + if is_local: + return {'$ref': self._id_uri(self.resource, item)} return {"$ref": self._item_uri(self.resource, item)} def convert(self, value): @@ -71,7 +77,7 @@ def rebind(self, resource): def schema(self): return self.resource.schema.fields[self.property].request - def format(self, item): + def format(self, item, is_local=False): return self.resource.schema.fields[self.property].output(self.property, item) @cached_property @@ -101,7 +107,7 @@ def schema(self): "additionalItems": False } - def format(self, item): + def format(self, item, is_local=False): return [self.resource.schema.fields[p].output(p, item) for p in self.properties] @cached_property @@ -123,7 +129,7 @@ def _on_bind(self, resource): def schema(self): return self.id_field.request - def format(self, item): + def format(self, item, is_local=False): return self.id_field.output(self.resource.manager.id_attribute, item) def convert(self, value): diff --git a/tests/__init__.py b/tests/__init__.py index fa27a4d..d08873f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,6 +1,9 @@ +from pprint import pformat + from flask import json, Flask from flask.testing import FlaskClient from flask_testing import TestCase +import sqlalchemy class ApiClient(FlaskClient): @@ -49,4 +52,52 @@ def create_app(self): return app def pp(self, obj): - print(json.dumps(obj, sort_keys=True, indent=4, separators=(',', ': '))) \ No newline at end of file + print(json.dumps(obj, sort_keys=True, indent=4, separators=(',', ': '))) + + +class DBQueryCounter: + """ + Use as a context manager to count the number of execute()'s performed + against the given sqlalchemy connection. + + Usage: + with DBQueryCounter(db.session) as ctr: + db.session.execute("SELECT 1") + db.session.execute("SELECT 1") + ctr.assert_count(2) + """ + + def __init__(self, session, reset=True): + self.session = session + self.reset = reset + self.statements = [] + + def __enter__(self): + if self.reset: + self.session.expire_all() + sqlalchemy.event.listen( + self.session.get_bind(), 'after_execute', self._callback + ) + return self + + def __exit__(self, *_): + sqlalchemy.event.remove( + self.session.get_bind(), 'after_execute', self._callback + ) + + def get_count(self): + return len(self.statements) + + def _callback(self, conn, clause_element, multiparams, params, result): + self.statements.append((clause_element, multiparams, params)) + + def display_all(self): + for clause, multiparams, params in self.statements: + print(pformat(str(clause)), multiparams, params) + print('\n') + count = self.get_count() + return 'Counted: {count}'.format(count=count) + + def assert_count(self, expected): + count = self.get_count() + assert count == expected, self.display_all() diff --git a/tests/contrib/alchemy/test_fields.py b/tests/contrib/alchemy/test_fields.py new file mode 100644 index 0000000..af54b4d --- /dev/null +++ b/tests/contrib/alchemy/test_fields.py @@ -0,0 +1,62 @@ +from flask_sqlalchemy import SQLAlchemy + +from flask_potion import Api, fields +from flask_potion.resource import ModelResource +from flask_potion.contrib.alchemy.fields import ToOne as SAToOne +from tests import BaseTestCase, DBQueryCounter + + +class SQLAlchemyToOneRemainNoPrefetchTestCase(BaseTestCase): + """ + """ + + def setUp(self): + super(SQLAlchemyToOneRemainNoPrefetchTestCase, self).setUp() + self.app.config['SQLALCHEMY_ENGINE'] = 'sqlite://' + self.api = Api(self.app) + self.sa = sa = SQLAlchemy( + self.app, session_options={"autoflush": False}) + + class Type(sa.Model): + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String(60), nullable=False) + + class Machine(sa.Model): + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String(60), nullable=False) + + type_id = sa.Column(sa.Integer, sa.ForeignKey(Type.id)) + type = sa.relationship(Type, foreign_keys=[type_id]) + + sa.create_all() + + class MachineResource(ModelResource): + class Meta: + model = Machine + + class Schema: + type = SAToOne('type') + + class TypeResource(ModelResource): + class Meta: + model = Type + + self.MachineResource = MachineResource + self.TypeResource = TypeResource + + self.api.add_resource(MachineResource) + self.api.add_resource(TypeResource) + + def test_relation_serialization_does_not_load_remote_object(self): + response = self.client.post('/type', data={"name": "aaa"}) + aaa_uri = response.json["$uri"] + self.client.post( + '/machine', data={"name": "foo", "type": {"$ref": aaa_uri}}) + with DBQueryCounter(self.sa.session) as counter: + response = self.client.get('/machine') + self.assert200(response) + self.assertJSONEqual( + [{'$uri': '/machine/1', 'type': {'$ref': aaa_uri}, 'name': 'foo'}], + response.json) + counter.assert_count(1) +