0

In the class I am testing, I want to mock the whole DataAccess class that is used as a member variable. The DataAccess class is just abstracting the SQLite database connection.

I have created a replacement MockDataAccess class that connects to a test database, but the main database still seems to be called - what am I doing wrong?

Edit: have updated where I am patching as recommended and originally did, but still isn't working?

Class to test:

from .data_access import DataAccess


class Query:
    def __init__(self):
        self._data_access = DataAccess()

    def get_all_data(self):
        queryset = self._data_access.execute('''
            SELECT
                Purchase_Product.purchase_id,
                Product.product_name,
                Purchase_Product.quantity,
                Customer.first_name,
                Customer.last_name,
                Customer.email,
                Address.line_one,
                Address.line_two,
                Address.city,
                Address.postcode,
                Status.status_description,
                Purchase.created_date,
                Purchase.dispatched_date,
                Purchase.completed_date,
                Postage.postage_description,
                Product.individual_price,
                Product.id,
                Product.aisle,
                Product.shelf
            FROM
                Product
                INNER JOIN Purchase_Product ON Purchase_Product.product_id = Product.id
                INNER JOIN Purchase ON Purchase.id = Purchase_Product.purchase_id
                INNER JOIN Status ON Status.id = Purchase.status_id
                INNER JOIN Postage ON Postage.id = Purchase.postage_id
                INNER JOIN Customer ON Customer.id = Purchase.customer_id
                INNER JOIN Customer_Address ON Customer_Address.customer_id = Customer.id
                INNER JOIN Address ON Address.id = Customer_Address.address_id
            ORDER BY
                Status.id
        ''', None)
        return queryset.fetchall()


My test class:

import os
import sqlite3
from sqlite3.dbapi2 import Connection, Cursor
import pytest
import mock
from src.main.order_management.data.query import Query

class MockDataAccess:
    def __init__(self):
        self._conn = self._create_connection()
        self._cur = self._conn.cursor()
    
    def _create_connection(self):
        THIS_DIR = os.path.dirname(__file__)
        TEST_DATABASE = os.path.join(THIS_DIR, 'database', 'TestOnlineStore.db')
        return sqlite3.connect(TEST_DATABASE)

    def execute(self, query, data):
        if data is None:
            self._cur.execute(query)
        else:
            self._cur.execute(query, data)
        self._conn.commit()
        return self._cur

@pytest.fixture
def data_access():
    return MockDataAccess()

@pytest.fixture
def query():
    return Query()

@pytest.fixture
def setup_database(data_access):
    data_access.execute('''
        CREATE TABLE IF NOT EXISTS Address (
            id integer PRIMARY KEY AUTOINCREMENT,
            line_one text NOT NULL,
            line_two text,
            city text NOT NULL,
            postcode text
    );''', None)
    data_access.execute('''
        CREATE TABLE IF NOT EXISTS Customer (
            id integer PRIMARY KEY AUTOINCREMENT,
            first_name text NOT NULL,
            last_name text NOT NULL,
            email text
    );''', None)
    data_access.execute('''
        CREATE TABLE IF NOT EXISTS Customer_Address (
            customer_id integer,
            address_id integer,

            FOREIGN KEY (customer_id)
                REFERENCES Customer (id)
                    ON DELETE CASCADE
            FOREIGN KEY (address_id)
                REFERENCES Address (id)
                    ON DELETE CASCADE
            PRIMARY KEY (customer_id, address_id)
    );''', None)
    data_access.execute('''
        CREATE TABLE IF NOT EXISTS Platform (
            id integer PRIMARY KEY AUTOINCREMENT,
            platform_name integer NOT NULL,
            user_token text NOT NULL
    );''', None)
    data_access.execute('''
        CREATE TABLE IF NOT EXISTS Status (
            id integer PRIMARY KEY AUTOINCREMENT,
            status_description text NOT NULL
    );''', None)
    data_access.execute('''
        CREATE TABLE IF NOT EXISTS Postage (
            id integer PRIMARY KEY AUTOINCREMENT,
            postage_description text NOT NULL
    );''', None)
    data_access.execute('''
        CREATE TABLE IF NOT EXISTS Purchase (
            id integer PRIMARY KEY AUTOINCREMENT,
            platform_id integer,
            customer_id integer,
            status_id integer,
            postage_id integer,
            created_date text NOT NULL,
            dispatched_date text,
            completed_date text,
            
            FOREIGN KEY (platform_id)
                REFERENCES Platform (id)
                    ON DELETE CASCADE
            FOREIGN KEY (customer_id)
                REFERENCES Customer (id)
                    ON DELETE CASCADE
            FOREIGN KEY (status_id)
                REFERENCES Status (id)
                    ON DELETE CASCADE
    );''', None)
    data_access.execute('''
        CREATE TABLE IF NOT EXISTS Product (
            id integer PRIMARY KEY AUTOINCREMENT,
            product_name integer,
            product_description integer,
            individual_price real,
            stock_count integer,
            aisle integer,
            shelf integer
    );''', None)
    data_access.execute('''
        CREATE TABLE IF NOT EXISTS Purchase_Product (
            purchase_id integer,
            product_id integer,
            quantity integer,

            FOREIGN KEY (purchase_id)
                REFERENCES Purchase (id)
                    ON DELETE CASCADE
            FOREIGN KEY (product_id)
                REFERENCES Product (id)
                    ON DELETE CASCADE
            PRIMARY KEY (purchase_id, product_id)
    );''', None)
    
@pytest.fixture
def clean_database(data_access):
    data_access.execute('''DROP TABLE IF EXISTS Address''', None)
    data_access.execute('''DROP TABLE IF EXISTS Customer''', None)
    data_access.execute('''DROP TABLE IF EXISTS Customer_Address''', None)
    data_access.execute('''DROP TABLE IF EXISTS Platform''', None)
    data_access.execute('''DROP TABLE IF EXISTS Postage''', None)
    data_access.execute('''DROP TABLE IF EXISTS Product''', None)
    data_access.execute('''DROP TABLE IF EXISTS Purchase''', None)
    data_access.execute('''DROP TABLE IF EXISTS Purchase_Product''', None)
    data_access.execute('''DROP TABLE IF EXISTS Status''', None)
    

@pytest.fixture
def setup_test_data1(setup_database, data_access):
    data_access.execute('''
        INSERT INTO Address (line_one, line_two, city) 
            VALUES('Test Line One', 'Test Line Two', 'Test City')
    ''', None)
    data_access.execute('''
        INSERT INTO Customer (first_name, last_name, email) 
            VALUES('Test First Name', 'Test Last Name', 'Test Email')
    ''', None)
    data_access.execute('''
        INSERT INTO Customer_Address (customer_id, address_id)
            VALUES(1, 1)
    ''', None)
    data_access.execute('''
        INSERT INTO Postage (postage_description)
            VALUES('1st Class')
    ''', None)
    data_access.execute('''
        INSERT INTO Product (product_name, individual_price, stock_count, aisle, shelf)
            VALUES('Test Product', 100.00, 2, 3, 4)
    ''', None)
    data_access.execute('''
        INSERT INTO Status (status_description)
            VALUES('Awaiting')
    ''', None)
    data_access.execute('''
        INSERT INTO Purchase (customer_id, status_id, postage_id, created_date)
            VALUES(1, 1, 1, 'Date Now')
    ''', None)
    data_access.execute('''
        INSERT INTO Purchase_Product (purchase_id, product_id, quantity)
            VALUES(1, 1, 1)
    ''', None)

@mock.patch('src.main.order_management.data.query.DataAccess', new_callable=MockDataAccess)
def test_get_all_data(mock_data_access, query, clean_database, setup_database, setup_test_data1, data_access):
    all_data = query.get_all_data()
    assert all_data == ("jdi", "fmn")
Charlie Clarke
  • 177
  • 1
  • 9
  • You are mocking the wrong object, see [where to patch](https://docs.python.org/3/library/unittest.mock.html#id6). You have to mock the `DataAccess` object used in your class, e.g. something like `@mock.patch('src.main.order_management.data.query.DataAccess'`. – MrBean Bremen Mar 21 '21 at 14:49
  • I originally had patched it like this, but in either way, the main database is still being called? – Charlie Clarke Mar 21 '21 at 16:03
  • Have you checked that `self._data_access` is indeed a mock if you mock it the way you first did? – MrBean Bremen Mar 21 '21 at 16:38
  • It isn't a mock object - it is still referencing the main DataAccess class, that's what I don't understand – Charlie Clarke Mar 21 '21 at 16:58
  • I just used your code (with the correct patch), and it works as expected (e.g. tried to connect to the test database as implemented in your mock). Maybe you can adapt your question with the correct patching you tried before to see what you actually tried, as the current patch is obviously incorrect. – MrBean Bremen Mar 21 '21 at 17:48
  • So what is the right way to patch? @mock.patch('src.main.order_management.data.query.DataAccess', new_callable=MockDataAccess) or @mock.patch('src.main.order_management.data.data_access.DataAccess', new_callable=MockDataAccess) – Charlie Clarke Mar 21 '21 at 19:01
  • As I wrote in the first comment - the first one. – MrBean Bremen Mar 21 '21 at 20:07
  • @MrBeanBremen I'm at a bit of a loss with this - it's the last thing I need to test also. Did you have any other suggestions for how I could resolve this? – Charlie Clarke Mar 23 '21 at 09:14
  • Sorry, I don't know - I have tested it locally (your code with only the module paths adapted, and a dummy DataAccess implementation), and it worked fine for me. – MrBean Bremen Mar 23 '21 at 18:38
  • Thanks for your efforts - I did also try this using vanilla unittest patching and had the same result – Charlie Clarke Mar 23 '21 at 20:04

2 Answers2

0

I found that my problem was that it wasn't patching data_access, because data_access had already been declared when I created my test instance using a pytest fixture.

Also, I found that new_callable wasn't in fact behaving as I thought it would, so I have used return_value instead and passed an instance of MockDataAccess. Now my test database is being called as expected.

New test_query.py (only bits changed):


mock_data_access = MockDataAccess()

@mock.patch('src.main.order_management.data.query.DataAccess', return_value=mock_data_access)
def test_get_all_data(mock_data_access, clean_database,
                      setup_database, setup_test_data1):
    query = Query()
    all_data = query.get_all_data()
    assert all_data == [
                        (1,
                         'Test Product',
                         1,
                         'Test First Name',
                         'Test Last Name',
                         'Test Email',
                         'Test Line One',
                         'Test Line Two',
                         'Test City',
                         None,
                         'Awaiting',
                         'Date Now',
                         None,
                         None,
                         '1st Class',
                         100.0,
                         1,
                         3,
                         4)
                       ]

Charlie Clarke
  • 177
  • 1
  • 9
0

I have also thought that it is cleaner and more logical to be creating my MockDataAccess using inheritance.

So, rather than copy and pasting my MockDataAccess and changing the database it is pointing to, I now create a child class of this and set the parent class' (DataAccess) connection attribute.

Main DataAccess class:

import sqlite3
from sqlite3.dbapi2 import Connection, Cursor
from ..config import DATABASE


class DataAccess:
    def __init__(self):
        self._connection = self._create_connection()
        self._cursor = self._connection.cursor()

    def _create_connection(self):
        return sqlite3.connect(DATABASE)

    def execute(self, query, data):
        if data is None:
            self._cursor.execute(query)
        else:
            self._cursor.execute(query, data)
        self._connection.commit()
        return self._cursor

MockDataAccess:

class MockDataAccess(DataAccess):
    def __init__(self):
        super(MockDataAccess, self).__init__()
        self._connection = self._create_connection()

    def _create_connection(self):
        return sqlite3.connect(TEST_DATABASE)

I'm new to testing, so don't know if this would've been obvious - thought I'd share in case this does help someone.

Charlie Clarke
  • 177
  • 1
  • 9