0

I am playing around with using pydantic/SQLModel models to specify query parameters for FastAPI endpoints. The model has all-optional fields and gets passed to the endpoint using query_params: QuerySchema = Depends(QuerySchema). This works fine, but I am having trouble distinguishing between values that are unset and values that the user has explicitly set to None so that I can query for NULL values.

I have tried using .dict(exclude_unset=True) on the model, but it seems FastAPI is doing some magic behind the scenes to set optional params to None, which means exclude_unset doesn't exclude anything from the returned dictionary.

Does anyone have any ideas on how to approach this? I've included an example of what I'm trying below.

crud/base.py

import logging
from datetime import datetime
from typing import Any, Callable, Generic, Optional, TypeVar, Union
from uuid import UUID

from app.models.base_models import _SQLModel
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import Select

# Type parameters of CRUDBase: the table model class plus the pydantic schemas
# it is parameterized with (presumably the create/update payloads, going by
# their names).
ModelType = TypeVar("ModelType", bound=_SQLModel)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
# SchemaType = TypeVar("SchemaType", bound=BaseModel)
# Parameterizes the ``Select`` statement accepted by CRUDBase.get_multi.
T = TypeVar("T", bound=SQLModel)

# Union of str, UUID and int; presumably the accepted record-identifier types.
IdType = Union[str, UUID, int]


class SessionNotFoundException(Exception):
    """Raised when a database session is needed but none is available.

    Derives from ``Exception`` rather than ``BaseException`` so that ordinary
    ``except Exception`` handlers (and framework error handling built on them)
    can catch it; ``BaseException`` is reserved for interpreter-level signals
    such as ``KeyboardInterrupt`` and ``SystemExit``.
    """


class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Generic async CRUD helper bound to a single model class.

    A session is either passed explicitly to each method or, when omitted,
    obtained from the optional ``session_factory`` given at construction.
    """

    def __init__(
        self,
        model: type[ModelType],
        session_factory: Optional[Callable[[], AsyncSession]] = None,
    ) -> None:
        """Store the model class and the optional session factory.

        Args:
            model: The model class this instance queries.
            session_factory: Zero-argument callable returning an
                ``AsyncSession``; used when a method is called without an
                explicit ``db_session``.
        """
        self.model: type[ModelType] = model
        self.session_factory = session_factory
        # One logger per table, named after the model's ``__tablename__``.
        self._logger = logging.getLogger(str(self.model.__tablename__))

    @property
    def _local_session(self) -> AsyncSession:
        """Return the result of calling ``session_factory``.

        Raises:
            SessionNotFoundException: If no ``session_factory`` was configured.
        """
        if not self.session_factory:
            raise SessionNotFoundException(
                "No session provided or session_factory initialized."
            )
        return self.session_factory()

    ### OMITTED FOR LENGTH

    async def get_multi(
        self,
        *,
        offset: int = 0,
        limit: Optional[int] = 100,
        query: Optional[Select[T]] = None,
        db_session: Optional[AsyncSession] = None,
    ) -> list[ModelType]:
        """Fetch a page of rows ordered by the model's ``created_at``.

        Args:
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.
            query: Base statement to paginate; defaults to
                ``select(self.model)``.
            db_session: Session to run the statement on. When omitted, a
                session is obtained from ``session_factory`` and closed
                again before returning.

        Returns:
            The scalar results of the statement as a list.

        Raises:
            SessionNotFoundException: If neither ``db_session`` nor a
                ``session_factory`` is available.
        """
        # Only a session created here is ours to close; a caller-supplied
        # session stays open so the caller keeps control of its lifecycle.
        owns_session = db_session is None
        session = self._local_session if owns_session else db_session
        try:
            statement = query if query is not None else select(self.model)
            statement = (
                statement.offset(offset).limit(limit).order_by(self.model.created_at)
            )
            response = await session.execute(statement)
            # Materialize as a real list to match the declared return type.
            return list(response.scalars().all())
        finally:
            if owns_session:
                # Release the connection instead of leaking it until GC.
                await session.close()


### OMITTED FOR LENGTH


Usage example. This doesn't work, I think because FastAPI sets unset optional params to None, so the exclude_unset option doesn't exclude anything.

crud/sports_crud.py

from typing import Optional, Union

from app.models import Sport
from app.schemas.sports import ISportCreate, ISportUpdate
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.db.session import SessionLocal
from .base import CRUDBase, IdType


class CRUDSports(CRUDBase[Sport, ISportCreate, ISportUpdate]):
    """CRUDBase specialized for Sport with its create and update schemas.

    Adds no behavior of its own.
    """


# Module-level CRUD instance for Sport; calls made without an explicit
# db_session obtain their session from SessionLocal.
sports = CRUDSports(Sport, session_factory=SessionLocal)

router

# List sports, filtered by equality on every query parameter the model reports
# as set and paginated by ``page_params``.
# (Kept as comments: FastAPI would publish a docstring in the OpenAPI schema.)
@sports_router_v1.get("", response_model=list[ISportRead])
async def get_sports(
    *,
    query_params: ISportQuery = Depends(ISportQuery),
    page_params: PageParams = Depends(PageParams),
) -> list[Sport]:
    # Intended to hold only the explicitly supplied params.
    # NOTE: the debugger shows this includes all keys with None values.
    filters = query_params.dict(exclude_unset=True)

    # Start from an unfiltered select and add one equality clause per filter;
    # with no filters the loop body never runs and the select stays as-is.
    stmt = select(Sport)
    for column_name, wanted in filters.items():
        stmt = stmt.where(getattr(Sport, column_name) == wanted)

    return await crud.sports.get_multi(query=stmt, **page_params.dict())

schemas/sports.py

def optional(*fields):
    """Class decorator that marks pydantic model fields as not required.

    Usable in two forms:

    * ``@optional`` applied directly to a ``BaseModel`` subclass marks every
      field in the class's ``__fields__`` as not required.
    * ``@optional("a", "b")`` returns a decorator that marks only the named
      fields as not required.

    In both forms the class is modified in place and returned.
    """
    # https://github.com/pydantic/pydantic/issues/1223
    # https://github.com/pydantic/pydantic/pull/3179

    def _relax(model_cls, names):
        # Flip ``required`` off on each named field of the model.
        for name in names:
            model_cls.__fields__[name].required = False
        return model_cls

    def _decorate(_cls):
        return _relax(_cls, fields)

    # Bare usage: the only positional argument is the model class itself.
    used_bare = (
        fields and inspect.isclass(fields[0]) and issubclass(fields[0], BaseModel)
    )
    if not used_bare:
        return _decorate

    target = fields[0]
    return _relax(target, target.__fields__)


# Schema used as the response model of the sports listing endpoint: everything
# declared on SportBase plus a required ``id``.
class ISportRead(SportBase):
    id: UUID  # required UUID identifier


# Query-parameter schema: same fields as ISportRead. The bare ``@optional``
# decorator marks every inherited field as not required (assuming ISportRead
# ultimately derives from pydantic's BaseModel).
@optional
class ISportQuery(ISportRead):
    pass

Andrew L.
  • 23
  • 8
  • 1
    Your question and code are quite lengthy, you're more likely to get positive engagement and answers if you [edit] it and reduce it to a simple explanation and [minimal reproducible example](https://stackoverflow.com/help/minimal-reproducible-example). – ljmc Dec 30 '22 at 10:49

0 Answers0