Say I have an SQLalchemy model like:
from app.data_structures.base import (
Base
class User(Base):
__tablename__ = "users"
user_name: Mapped[str] = mapped_column(primary_key=True, nullable=True)
flag = Column(Boolean)
def __init__(
self,
user_name: str = None,
flag: bool = false(),
)-> None:
self.user_name = user_name
self.flag = flag
where app.data_structures.base.py:
from contextlib import contextmanager
from os import environ
from os.path import join, realpath
from sqlalchemy import Column, ForeignKey, Table, create_engine
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
db_name = environ.get("DB_NAME")
ROOT_DIR =environ.get("ROOT_DIR")
db_path = realpath(join(ROOT_DIR, "data", db_name))
engine = create_engine(f"sqlite:///{db_path}", connect_args={"timeout": 120})
session_factory = sessionmaker(bind=engine)
sql_session = scoped_session(session_factory)
@contextmanager
def Session():
session = sql_session()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
I then have a function defined (app.helpers.db_helper.get_users_to_query), elsewhere that does something like:
from app.data_structures.base import Session
def get_usernames_to_query(flag: bool = False) -> List[User]:
logger = getLogger(__name__)
try:
with Session() as session:
usernames_to_query = [
{"user_name": user.user_name}
for user in session.query(User).filter(
User.flag == flag
)
]
except Exception as err:
logger.exception(f"Exception thrown in get_users_to_query {' '.join(err.args)}")
usernames_to_query = []
return usernames_to_query
I am trying to unit test this function, by having a separate session in memory, rather than using the production DB.
I patch the Session() context manager where the function imports it, and then mock the __enter__
return value to be the in memory session:
@patch("app.helpers.db_helper.Session")
def test_get_users_to_query(self, mock_session) -> None:
self.engine = create_engine("sqlite:///:memory:")
self.session = Session(self.engine)
print(type(mock_session))
print("session id", id(self.session))
mock_session.__enter__ = Mock(return_value=self.session)
mock_session.__exit__ = Mock(return_value=None)
Base.metadata.create_all(self.engine)
(patcher, environ_dict, environ_mock_get) = self.environ_mock_get_factory()
with patcher():
fake_user1 = User(
display_name="Jane Doe", user_name="jdb1", flag=False
)
fake_user2 = User(
display_name="John Doe", user_name="jdb2", flag=True
)
with mock_session as session:
session.add(fake_user1)
session.add(fake_user2)
session.commit()
users = session.query(User).filter(User.flag == True).all()
print([user.user_name for user in users])
print(id(session), type(session))
users_to_query = get_users_to_query()
print("Users:", users_to_query)
In the test above the mocked_session context manger get replaced with self.session, as I would expect. So print([user.user_name for user in users])
prints the fake_user2's username as expected.
But in the call to get_users_to_query
it is not and it doesn't query the in memory DB, and get_users_to_query
returns an empty list.
Can anyone explain why this is, and how to get the session in get_users_to_query
to be the mocked session with the in memory DB (i.e. self.session)?
Thanks!