diff --git a/waffle/middleware.py b/waffle/middleware.py index ec31b1a7..c7859d0c 100644 --- a/waffle/middleware.py +++ b/waffle/middleware.py @@ -1,11 +1,28 @@ from django.http import HttpRequest, HttpResponse from django.utils.deprecation import MiddlewareMixin from django.utils.encoding import smart_str - from waffle.utils import get_setting +from waffle import get_waffle_flag_model + +WaffleFlag = get_waffle_flag_model() class WaffleMiddleware(MiddlewareMixin): + + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + """Ensure testing cookie are always set, even if Waffle isn't used""" + for flag in WaffleFlag.objects.filter(testing=True): + tc = get_setting("TEST_COOKIE") % flag.name + if tc in request.GET: + on = request.GET[tc] == "1" + if not hasattr(request, "waffle_tests"): + request.waffle_tests = {} + request.waffle_tests[flag.name] = on + return self.process_response(request, self.get_response(request)) + def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse: secure = get_setting('SECURE') max_age = get_setting('MAX_AGE') diff --git a/waffle/tests/test_waffle.py b/waffle/tests/test_waffle.py index 87f3e69d..58123edc 100644 --- a/waffle/tests/test_waffle.py +++ b/waffle/tests/test_waffle.py @@ -374,19 +374,16 @@ def test_testing_flag_header(self): def test_set_then_unset_testing_flag(self): waffle.get_waffle_flag_model().objects.create(name='myflag', testing=True) - response = self.client.get('/flag_in_view?dwft_myflag=1') - self.assertEqual(b'on', response.content) - + self.client.get('/foo?dwft_myflag=1') response = self.client.get('/flag_in_view') self.assertEqual(b'on', response.content) - response = self.client.get('/flag_in_view?dwft_myflag=0') - self.assertEqual(b'off', response.content) - + self.client.get('/foo?dwft_myflag=0') response = self.client.get('/flag_in_view') self.assertEqual(b'off', response.content) - response = self.client.get('/flag_in_view?dwft_myflag=1') + self.client.get('/foo?dwft_myflag=1') + response = self.client.get('/flag_in_view') self.assertEqual(b'on', response.content) @override_settings(DATABASE_ROUTERS=['waffle.tests.base.ReplicationRouter'])