I am playing around with using pydantic/SQLModel models to specify query parameters for FastAPI endpoints. The model would have all optional types, and then get passed to the endpoint using `query_params: QuerySchema = Depends(QuerySchema)`. This works fine, but I am having issues distinguishing between values that are unset and values that the user has explicitly set to `None` so as to query for NULL values.
I have tried using `.dict(exclude_unset=True)` on the model, but it seems FastAPI is doing some magic behind the scenes to set optional params to `None`, which means `exclude_unset` doesn't exclude anything from the returned dictionary.
Does anyone have any ideas on how to approach this? I've included an example of what I'm trying below.
crud/base.py
import logging
from datetime import datetime
from typing import Any, Callable, Generic, Optional, TypeVar, Union
from uuid import UUID
from app.models.base_models import _SQLModel
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import Select
# Type variable for the model class a CRUDBase instance is constructed with
# (the ``model`` argument of ``CRUDBase.__init__``).
ModelType = TypeVar("ModelType", bound=_SQLModel)
# Type variables for the two schema parameters of ``CRUDBase``; presumably the
# payload schemas of the create/update methods, which are omitted from this
# excerpt.
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
# SchemaType = TypeVar("SchemaType", bound=BaseModel)
# Type variable parameterising the ``Select`` accepted by ``CRUDBase.get_multi``.
T = TypeVar("T", bound=SQLModel)
# Alias for str / UUID / int; presumably the accepted record-identifier types
# (its usage is not shown in this excerpt).
IdType = Union[str, UUID, int]
class SessionNotFoundException(Exception):
    """Raised when a database session is needed but none can be obtained.

    ``CRUDBase._local_session`` raises this when no ``session_factory`` was
    configured. It derives from ``Exception`` rather than ``BaseException``
    so that ordinary ``except Exception`` handlers can catch it;
    ``BaseException`` is meant for interpreter-level exits such as
    ``KeyboardInterrupt`` and ``SystemExit``.
    """
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Generic async CRUD helper bound to a single model class.

    Keeps the model and an optional session factory. Methods take an
    explicit ``db_session`` and otherwise fall back to a session produced
    by the factory.
    """

    def __init__(
        self,
        model: type[ModelType],
        session_factory: Optional[Callable[[], AsyncSession]] = None,
    ) -> None:
        """Store the model and session factory; name the logger after the table.

        Args:
            model: Model class this helper queries.
            session_factory: Optional zero-argument callable returning an
                ``AsyncSession``, used when a method gets no ``db_session``.
        """
        self.session_factory = session_factory
        self.model: type[ModelType] = model
        self._logger = logging.getLogger(str(self.model.__tablename__))

    @property
    def _local_session(self) -> AsyncSession:
        """Return a session from ``session_factory``.

        Raises:
            SessionNotFoundException: If no session factory is configured.
        """
        factory = self.session_factory
        if factory:
            return factory()
        raise SessionNotFoundException(
            "No session provided or session_factory initialized."
        )

    ### OMITTED FOR LENGTH

    async def get_multi(
        self,
        *,
        offset: int = 0,
        limit: Optional[int] = 100,
        query: Optional[Select[T]] = None,
        db_session: Optional[AsyncSession] = None,
    ) -> list[ModelType]:
        """Run a paginated query and return the resulting rows.

        Args:
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.
            query: Statement to paginate; defaults to ``select(self.model)``.
            db_session: Session to execute on; when falsy, one is taken
                from ``_local_session``.

        Returns:
            The scalar results of the statement, ordered by the model's
            ``created_at`` column.
        """
        if db_session:
            active_session = db_session
        else:
            active_session = self._local_session

        if query is None:
            stmt = select(self.model)
        else:
            stmt = query

        # Apply pagination first, then ordering, in the same call order as before.
        stmt = stmt.offset(offset)
        stmt = stmt.limit(limit)
        stmt = stmt.order_by(self.model.created_at)

        result = await active_session.execute(stmt)
        return result.scalars().all()

    ### OMITTED FOR LENGTH
Usage example. This doesn't work, I think because FastAPI sets unset optional params to `None`, so the `exclude_unset` option doesn't exclude anything.
crud/sports_crud.py
from typing import Optional, Union
from app.models import Sport
from app.schemas.sports import ISportCreate, ISportUpdate
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.db.session import SessionLocal
from .base import CRUDBase, IdType
class CRUDSports(CRUDBase[Sport, ISportCreate, ISportUpdate]):
    """CRUD helper for ``Sport``; adds nothing on top of ``CRUDBase`` here."""


# Module-level instance, constructed with the application's session factory.
sports = CRUDSports(Sport, session_factory=SessionLocal)
router
@sports_router_v1.get("", response_model=list[ISportRead])
async def get_sports(
    *,
    query_params: ISportQuery = Depends(ISportQuery),
    page_params: PageParams = Depends(PageParams),
) -> list[Sport]:
    """List sports, filtered by equality on each query parameter.

    Every key/value pair taken from ``query_params`` becomes a
    ``WHERE column == value`` clause; pagination arguments come from
    ``page_params``.
    """
    # Intended to hold only the explicitly set query params; the debugger
    # shows it actually includes all keys, with None values.
    filters = query_params.dict(exclude_unset=True)

    # Build the select statement, one equality filter per entry.
    stmt = select(Sport)
    for column_name, wanted in filters.items():
        stmt = stmt.where(getattr(Sport, column_name) == wanted)

    return await crud.sports.get_multi(query=stmt, **page_params.dict())
schemas/sports.py
def optional(*fields):
    """Class decorator that marks pydantic model fields as not required.

    Two call forms are supported:

    * ``@optional`` applied directly to a ``BaseModel`` subclass relaxes
      every field of that class.
    * ``@optional("a", "b")`` returns a decorator that relaxes only the
      named fields.

    The class is modified in place (``__fields__[name].required = False``)
    and returned.
    """
    # https://github.com/pydantic/pydantic/issues/1223
    # https://github.com/pydantic/pydantic/pull/3179

    def _relax(model_cls):
        # Reads ``fields`` from the enclosing scope at call time, so it sees
        # the rebinding done below for the bare-decorator form.
        for name in fields:
            model_cls.__fields__[name].required = False
        return model_cls

    used_bare = (
        bool(fields)
        and inspect.isclass(fields[0])
        and issubclass(fields[0], BaseModel)
    )
    if not used_bare:
        return _relax

    target = fields[0]
    fields = target.__fields__
    return _relax(target)
# Schema used as the response model of ``get_sports``: ``SportBase`` plus ``id``.
class ISportRead(SportBase):
    id: UUID


# Query-parameter schema: the same fields as ``ISportRead``, but the bare
# ``@optional`` decorator marks every one of them as not required.
@optional
class ISportQuery(ISportRead):
    pass