I'm upgrading an existing middleware to the new-style middleware introduced in Django 1.10.

Previously it looked similar to this:

from threading import local

from django.utils.deprecation import MiddlewareMixin


class ThreadLocalMiddleware(MiddlewareMixin):
    """Simple middleware that stores the request object in thread-local storage."""

    _thread_locals = local()

    def process_request(self, request):
        # _thread_locals is a class attribute, so it is reached through self.
        self._thread_locals.request = request

    def process_response(self, request, response):
        if hasattr(self._thread_locals, 'request'):
            del self._thread_locals.request
        return response

    def process_exception(self, request, exception):
        if hasattr(self._thread_locals, 'request'):
            del self._thread_locals.request

After rewriting it in the new style:

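from threading import local


class ThreadLocalMiddleware:
    """New-style middleware that stores the request in thread-local storage."""

    _thread_locals = local()

    def __init__(self, get_response=None):
        self.get_response = get_response

    def __call__(self, request):
        # Pre-request work: make the request available for this thread.
        self._thread_locals.request = request

        response = self.get_response(request)

        # Post-response work: remove the request again.
        if hasattr(self._thread_locals, 'request'):
            del self._thread_locals.request
        return response

    def process_exception(self, request, exception):
        if hasattr(self._thread_locals, 'request'):
            del self._thread_locals.request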

My question is: how do I unit test that the middleware sets the request in _thread_locals? Previously it was as easy as calling the process_request method in my unit tests. Now I only have the __call__ method, which erases the variable before it returns.

Looking at Django's own middleware tests, they still use the backward-compatible class, which lets them keep the old tests. I wonder how they will test it once that class is removed.

1 Answer

In the end I used a test similar to this. The get_response_callback is called once the middleware has finished its pre-request work, so an assertion placed inside it sees the state the middleware has set up at that point.

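from django.contrib.auth.models import AnonymousUser
from django.http import HttpResponse
from django.test import RequestFactory


def test_middleware():
    request = RequestFactory().get('/')
    request.user = AnonymousUser()

    def get_response_callback(req):
        # Runs mid-call, after the middleware has stored the request.
        assert ThreadLocalMiddleware._thread_locals.request is request
        return HttpResponse()

    middleware = ThreadLocalMiddleware(get_response_callback)
    middleware(request)

    # The middleware must clean up once the response has been produced.
    assert not hasattr(ThreadLocalMiddleware._thread_locals, 'request')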
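
The same idea can be tried without a Django project. What follows is only a minimal sketch under some assumptions: `RequestStashMiddleware` is a hypothetical stand-in that mirrors the new-style middleware from the question, a plain `object()` takes the place of a real request, and `run_check` is an illustrative helper name. The stub `get_response` records what it finds in thread-local storage while the middleware is mid-call, so the checks can run afterwards.

```python
from threading import local


class RequestStashMiddleware:
    """Hypothetical stand-in mirroring the new-style middleware."""

    _thread_locals = local()

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        # Pre-request work: stash the request for the current thread.
        self._thread_locals.request = request
        response = self.get_response(request)
        # Post-response work: remove it again.
        if hasattr(self._thread_locals, "request"):
            del self._thread_locals.request
        return response


def run_check():
    seen = {}

    def fake_get_response(request):
        # Called while the middleware is mid-call; record what is stored now.
        seen["request"] = getattr(
            RequestStashMiddleware._thread_locals, "request", None
        )
        return "response"

    fake_request = object()  # stand-in for a real HttpRequest
    middleware = RequestStashMiddleware(fake_get_response)
    result = middleware(fake_request)
    return seen, fake_request, result


seen, fake_request, result = run_check()
assert seen["request"] is fake_request
assert result == "response"
assert not hasattr(RequestStashMiddleware._thread_locals, "request")
```

Recording the observed value in `seen`, rather than asserting inside the callback, means the check also fails (with a `KeyError`) if the middleware never invokes the callback at all.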