1

I am writing unit tests for my Flask web app with Selenium, using Flask-Testing's LiveServerTestCase.

When running my test, I want to have one browser for all the tests instead of opening a new browser instance for each test, so I am using unittest's setUpClass.

class TestApp(LiveServerTestCase, unittest.TestCase):
    """Selenium test case from the question; its setUpClass raises the TypeError shown below."""
    def create_app(self):
        """Return the app built by the module-level create_app(), with TESTING on and LIVESERVER_PORT 9898."""
        app = create_app()
        app.config['TESTING'] = True
        app.config.update(LIVESERVER_PORT=9898)
        return app

    @classmethod
    def setUpClass(cls):
        """Open a single Chrome browser stored on the class so all tests can share it."""
        cls.chrome_browser = webdriver.Chrome()
        # NOTE: this is the failing line. get_server_url is looked up on the
        # class, so no instance is bound and the call is missing `self`
        # (the TypeError reported for this snippet).
        cls.chrome_browser.get(cls.get_server_url())

    def test_main_page(self):
        """Placeholder test that always passes."""
        self.assertEqual(1, 1)

When running my test, I am getting the following:

TypeError: get_server_url() missing 1 required positional argument: 'self'

How can I set up the browser in setUpClass?

davidism
  • 121,510
  • 29
  • 395
  • 339
AcroTom
  • 71
  • 6

1 Answer

0

You either have to use Flask-Testing the way it was designed by its developers - start the Flask server and the Selenium driver when the __call__ method runs ...

OR you can override that logic (in this case the Selenium driver is created once in setUpClass and a brand new Flask server is started for each test run)

import multiprocessing
import socket
import socketserver
import time
from urllib.parse import urlparse, urljoin

from flask import Flask
from flask_testing import LiveServerTestCase
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.support.wait import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC


class MyTest(LiveServerTestCase):
    """Selenium tests that share one browser per class and get a fresh server per test.

    The browser and the Flask app are created once in ``setUpClass``; the
    live server process is spawned and terminated around every single test
    by ``__call__``.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Create the app, push a request context and start the shared browser."""
        # Get the app
        cls.app = cls.create_app()

        cls._configured_port = cls.app.config.get('LIVESERVER_PORT', 5000)
        # Shared integer through which the server process reports the port
        # it actually bound (see _spawn_live_server).
        cls._port_value = multiprocessing.Value('i', cls._configured_port)

        # We need to create a context in order for extensions to catch up
        cls._ctx = cls.app.test_request_context()
        cls._ctx.push()

        try:
            cls.driver = webdriver.Firefox()
        except Exception:
            # unittest does not run tearDownClass when setUpClass raises, so
            # the context pushed above has to be released here.
            cls._ctx.pop()
            del cls._ctx
            raise

    @classmethod
    def tearDownClass(cls) -> None:
        """Release the request context and close the shared browser."""
        cls._post_teardown()

    @classmethod
    def get_server_url(cls):
        """
        Return the url of the test server
        """
        return 'http://localhost:%s' % cls._port_value.value

    @classmethod
    def create_app(cls):
        """Build the Flask app under test: one ``/`` route plus live-server settings."""
        app = Flask(__name__)

        @app.route('/')
        def hello_world():
            return 'Hello, World!'

        app.config['TESTING'] = True
        app.config['DEBUG'] = True
        app.config['ENV'] = "development"
        # Default port is 5000
        app.config['LIVESERVER_PORT'] = 8943
        # Default timeout is 5 seconds
        app.config['LIVESERVER_TIMEOUT'] = 10
        return app

    def __call__(self, *args, **kwargs):
        """
        Does the required setup, doing it here means you don't have to
        call super.setUp in subclasses.
        """
        try:
            self._spawn_live_server()
            # Start the MRO lookup *after* LiveServerTestCase so that only
            # this override manages the server process for the test run.
            super(LiveServerTestCase, self).__call__(*args, **kwargs)
        finally:
            self._terminate_live_server()

    @classmethod
    def _post_teardown(cls):
        """Pop the request context (if any) and quit the browser (if any)."""
        try:
            if getattr(cls, '_ctx', None) is not None:
                cls._ctx.pop()
                del cls._ctx
        finally:
            # Quit the browser even if popping the context failed, otherwise
            # the browser process would be left running.
            driver = getattr(cls, 'driver', None)
            if driver is not None:
                driver.quit()

    @classmethod
    def _terminate_live_server(cls):
        """Stop the live server process and wait for it to exit."""
        process = getattr(cls, '_process', None)
        if process is None:
            return

        process.terminate()
        # Every test spawns a new server on the same configured port, so wait
        # (bounded) for the old process to exit and release it before the
        # next test starts.
        process.join(cls.app.config.get('LIVESERVER_TIMEOUT', 5))
        cls._process = None

    @classmethod
    def _spawn_live_server(cls):
        """Start the Flask app in a child process and wait until it accepts connections.

        Raises:
            RuntimeError: if the server is not reachable within
                ``LIVESERVER_TIMEOUT`` seconds (default 5).
        """
        cls._process = None
        port_value = cls._port_value

        def worker(app, port):
            # Based on solution: http://stackoverflow.com/a/27598916
            # Monkey-patch the server_bind so we can determine the port bound by Flask.
            # This handles the case where the port specified is `0`, which means that
            # the OS chooses the port. This is the only known way (currently) of getting
            # the port out of Flask once we call `run`.
            original_socket_bind = socketserver.TCPServer.server_bind

            def socket_bind_wrapper(self):
                ret = original_socket_bind(self)

                # Get the port and save it into the port_value, so the parent process
                # can read it.
                (_, bound_port) = self.socket.getsockname()
                port_value.value = bound_port
                socketserver.TCPServer.server_bind = original_socket_bind
                return ret

            socketserver.TCPServer.server_bind = socket_bind_wrapper
            app.run(port=port, use_reloader=False)

        cls._process = multiprocessing.Process(
            target=worker, args=(cls.app, cls._configured_port)
        )

        cls._process.start()

        # We must wait for the server to start listening, but give up
        # after a specified maximum timeout
        timeout = cls.app.config.get('LIVESERVER_TIMEOUT', 5)
        # Monotonic clock: the deadline is unaffected by wall-clock changes.
        start_time = time.monotonic()

        while True:
            elapsed_time = (time.monotonic() - start_time)
            if elapsed_time > timeout:
                raise RuntimeError(
                    "Failed to start the server after %d seconds. " % timeout
                )

            if cls._can_ping_server():
                break

            # Back off briefly instead of spinning a CPU core while the
            # child process is still starting up.
            time.sleep(0.05)

    @classmethod
    def _get_server_address(cls):
        """
        Gets the server address used to test the connection with a socket.
        Respects both the LIVESERVER_PORT config value and overriding
        get_server_url()
        """
        parts = urlparse(cls.get_server_url())

        host = parts.hostname
        port = parts.port

        if port is None:
            if parts.scheme == 'http':
                port = 80
            elif parts.scheme == 'https':
                port = 443
            else:
                raise RuntimeError(
                    "Unsupported server url scheme: %s" % parts.scheme
                )

        return host, port

    @classmethod
    def _can_ping_server(cls):
        """Return True if a TCP connection to the server address succeeds."""
        host, port = cls._get_server_address()
        if port == 0:
            # Port specified by the user was 0, and the OS has not yet assigned
            # the proper port.
            return False

        # The context manager closes the socket on every path.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.connect((host, port))
            except OSError:
                return False
        return True

    def test_main_page(self):
        """The index page renders the greeting."""
        self.driver.get(self.get_server_url())
        body = WebDriverWait(self.driver, 10).until(
            EC.visibility_of_element_located((By.TAG_NAME, "body"))
        )
        self.assertEqual(body.text, "Hello, World!")

    def test_main_page_once_more_time(self):
        """An unknown path renders a "Not Found" page."""
        self.driver.get(urljoin(self.get_server_url(), "some/wrong/path"))
        body = WebDriverWait(self.driver, 10).until(
            EC.visibility_of_element_located((By.TAG_NAME, "body"))
        )
        self.assertTrue(body.text.startswith("Not Found"))

madzohan
  • 11,488
  • 9
  • 40
  • 67
  • Thank you! I am still not sure though how to implement these changes you suggested – AcroTom Nov 14 '20 at 21:05
  • You're welcome :) What do you mean by "how to implement these changes"? it is implemented and tested :) you only need to change Firefox to Chrome and run unittests ... if you have other troubles you could ask here in comments or in another question – madzohan Nov 15 '20 at 00:02
  • I am not sure how I should tweak my existing code with the snippet you provided so that it will work as I expect it to. You also mentioned using __call__ and I couldn't find any reference for it online. – AcroTom Nov 15 '20 at 10:43
  • You don't need to find references; simply go to the source of the LiveServerTestCase class, look at what happens there, and override what you need if something doesn't fit your needs (which is what I actually did while writing this snippet) ... Also, you could use a GUI IDE debugger like PyCharm, set breakpoints, etc.; it is pretty easy, though – madzohan Nov 15 '20 at 11:20