I have a client and a server that communicate with GraphQL over a WebSocket transport. Everything is written in Python: on the client side I use the gql package, and on the server side I use the strawberry package.
I want to add a custom header to the response of a GQL endpoint. The goal is to send the client a deprecation message at run time that depends on the arguments the endpoint receives. The endpoint itself is not deprecated, but I want to change its API slightly. Therefore I want to keep supporting old clients that call that endpoint, while warning the users of old clients that they should update their client.
On the client side (the old one), my code looks like this:

@websocket_exception_handling
async def list_metadata_by_version(self,
    returned_fields: List[str],
    version: Union[Versions, str] = LATEST_VERSION,
    skip: int = 0,
    limit: Optional[int] = None
    ) -> list:
    """Get all the metadata objects of a specific version.
    If no version is supplied, the default is the latest.
    This function is asynchronous and must be awaited (or run with another suitable library for asynchronous code).

    Args:
        returned_fields (List[str]): the desired fields from the model, for example ["session_name", "clip_name"].
        version (Union[Versions, str], optional): The version of the artifact. Note that `local_latest` is not a valid value in this context.
        skip (int, optional): Number of artifacts to skip at the start of the DB results.
        limit (int, optional): Maximum number of artifacts to return from the DB. In that case you get the first `limit` artifacts from the catalog.

    Returns:
        list: List of metadata objects.
    """
    assert bool(self.NESTED_DSL_SCHEMAS_NAMES), "The object doesn't have a defined schemas"
    assert limit is None or limit > 0, "Limit must be None or positive integer"
    batch_size = limit if limit is not None and limit < self.BATCH_SIZE_THRESHOLD else self.BATCH_SIZE_THRESHOLD
    final_list = []
    async with self.gql_client as gql_session:
        while True:
            kwargs = {"version": version, "skip": skip, "limit": batch_size}
            if getattr(self, "OBJECT_TYPE", None):
                kwargs["gtoType"] = getattr(self, "OBJECT_TYPE", None)
            where_clause = self.gql_dsl_schema.Query.listMetadata(**kwargs)
            sdl_schemas = [getattr(self.gql_dsl_schema, scheme_name) for scheme_name in self.NESTED_DSL_SCHEMAS_NAMES]
            select_clause = where_clause.select(*[field_name_to_DSLfield(field.split("."), sdl_schemas) for field in returned_fields])
            query = dsl_gql(DSLQuery(select_clause))
            query_results = await gql_session.execute(query)
            query_results = transform_dict_keys(query_results['listMetadata'], camel_to_snake)
            if not query_results:
                break
            skip = skip + batch_size
            final_list.extend([self._serialize_metadata_from_dict(doc) for doc in query_results])
            if limit is not None and len(final_list) >= limit:
                break
    if not final_list:
        raise DataNotExits("There is no data that meet those parameters")
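    # The last batch can overshoot `limit`, so trim the result.
    return final_list[:limit] if limit is not None else final_list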

Here, self.gql_client is of type gql.client.Client and gql_session is of type gql.client.AsyncClientSession. The GraphQL schema is:

type Query {
  """Gets gto metadata by version"""
  listMetadata(gtoType: GTOType!, version: String!, skip: Int!, limit: Int!): [GTOMetadataWithArtifactsType!]!
}

The change we made in the server API, for our own purposes, is renaming the argument gtoType (gto_type on the Python side) to dataType (data_type). The schemas of GTOMetadataWithArtifactsType and GTOType are less relevant here.
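The rename boils down to one decision per call: pick the effective type and note whether the caller used the legacy argument. Below is a minimal, framework-free sketch of that logic. It assumes a stand-in `GTOType` enum with hypothetical members and a hypothetical `DEPRECATION_MSG` string; `resolve_data_type` is a hypothetical helper that mirrors the `data_type = data_type or gto_type` line in the resolver, but returns the warning text instead of writing a header. Raising `ValueError` when neither argument is given is a choice made for this sketch only.

```python
from enum import Enum
from typing import Optional, Tuple


class GTOType(Enum):
    # Hypothetical members; the real enum lives in core.schemas.gto_metadata.
    CLIP = "clip"
    SESSION = "session"


# Hypothetical wording for the message old clients should see.
DEPRECATION_MSG = "The gtoType argument is deprecated; please update your client to use dataType."


def resolve_data_type(
    data_type: Optional[GTOType], gto_type: Optional[GTOType]
) -> Tuple[GTOType, Optional[str]]:
    """Return the effective type and a deprecation message (or None)."""
    if data_type is None and gto_type is None:
        raise ValueError("Either dataType or gtoType must be provided")
    # The new argument wins when both are sent.
    effective = data_type if data_type is not None else gto_type
    # Any call that still uses the legacy argument gets flagged.
    message = DEPRECATION_MSG if gto_type is not None else None
    return effective, message
```

Keeping this decision separate from the transport means the same message can later be attached wherever the client can actually read it.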
On the server side (the new one), my code is:

from typing import List, Optional
import strawberry
from core.models.mongodb.gto_metadata import GTOMongoClient
from core.models.redis.rate_limit import user_rate_limit
from core.schemas.gto_metadata import (GTOArtifact, GTOMetadataWithArtifacts,
                                       GTOType, GTOSource)
from core.schemas.base_metadata import Versions as GTOVersions
from strawberry.fastapi import GraphQLRouter
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
from strawberry.types import Info
from starlette.responses import Response


async def get_context():
    return {
        "db_client": GTOMongoClient(),
    }


@strawberry.experimental.pydantic.type(model=GTOArtifact, all_fields=True)
class GTOArtifactType:
    pass

# https://github.com/strawberry-graphql/strawberry/issues/1598
strawberry.enum(GTOType)
strawberry.enum(GTOSource)

@strawberry.experimental.pydantic.type(
    model=GTOMetadataWithArtifacts, all_fields=True
)
class GTOMetadataWithArtifactsType:
    pass


@strawberry.type
class Query:
    @strawberry.field(description="Gets gto metadata by version")
    @user_rate_limit()
    async def list_metadata(
        self,
        info: Info,
        version: str,
        skip: int,
        limit: int,
        data_type: Optional[GTOType] = None,
        gto_type: Optional[GTOType] = None,
    ) -> List[GTOMetadataWithArtifactsType]:
        if gto_type:  # TODO find how to extract this message in the client side
            info.context["response"].headers["Deprecated"] = "deprecation msg"
            # response: Response = info.context["response"]
            # response.set_cookie(key="Deprecated", value="deprecation msg")
        data_type = data_type or gto_type
        db_client: GTOMongoClient = info.context['db_client']
        if version == GTOVersions.latest:
            version = await db_client.get_latest_version(data_type)
        results = await db_client.list_metadata_by_version(data_type, version, skip, limit)
        return results


gto_gql_schema = strawberry.Schema(query=Query)
gto_graphql_route = GraphQLRouter(
    gto_gql_schema, context_getter=get_context, subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL])

References:

  1. gql package - https://gql.readthedocs.io/en/latest/index.html
  2. strawberry package - https://strawberry.rocks/
  3. gql.client.Client - https://gql.readthedocs.io/en/latest/modules/client.html#gql.client.Client

As you can see in the server-side code, I tried to add a "Deprecated" header to the response through the Info object that strawberry injects out of the box, which holds information about the current execution context (https://strawberry.rocks/docs/types/resolvers#accessing-execution-information). I expected this header to appear in gql_session.transport.response_headers, gql_session.client.transport.response_headers, gql_session.transport.websocket.response_headers or gql_session.client.transport.websocket.response_headers, but all of them contain only:

Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: 6bpnsJ5ilB8zqX1Mvyd386KwSD4=
Sec-WebSocket-Extensions: permessage-deflate
Sec-WebSocket-Protocol: graphql-transport-ws
Date: Thu, 02 Mar 2023 08:18:47 GMT
Server: Python/3.9 websockets/10.4
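These are the headers of the HTTP upgrade handshake, which the server sends once when the WebSocket is opened, before any query runs. After that, every operation is exchanged as a JSON message over the socket: under the graphql-transport-ws protocol a result arrives as a `next` message whose `payload` is a GraphQL response (`data`, `errors`, and optionally `extensions`), so there is no per-operation header for a resolver to set. The sketch below builds such a message with only the standard library. `build_next_message` is a hypothetical helper, and placing the text under a `deprecation` key in `extensions` is an assumption; whether strawberry and gql forward `extensions` over this transport in the installed versions should be verified.

```python
import json
from typing import Any, Dict, Optional


def build_next_message(
    op_id: str, data: Dict[str, Any], deprecation: Optional[str] = None
) -> str:
    """Serialize one graphql-transport-ws "next" message."""
    payload: Dict[str, Any] = {"data": data}
    if deprecation is not None:
        # "extensions" is the GraphQL response slot for extra metadata;
        # the "deprecation" key inside it is a hypothetical name.
        payload["extensions"] = {"deprecation": deprecation}
    return json.dumps({"id": op_id, "type": "next", "payload": payload})


# The handshake headers are sent once; each result afterwards is just a
# message like this one, with no headers of its own.
print(build_next_message("1", {"listMetadata": []}, "please update your client"))
```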

I can change the server's under-the-hood implementation, but I cannot change the server API, since that might break support for old clients. I know the change should probably be on the client side, so it won't affect the old clients anyway, but that is fine because I want to develop a tool that will help me with future changes like this one. I assume that the changes I make to the Info object on the server side don't affect the client side at all, which seems strange, since the Info object is supposed to be an out-of-the-box way to reach the request and response objects. The other thing I suspect is that I'm using await gql_session.execute(query) wrong, and that I should pass it another argument that makes it return a response object (right now it returns a dict with the results only).
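On the client side, `gql_session.execute(query)` returning only the `data` dict matches the behavior described above. Recent gql 3.x releases document a `get_execution_result=True` argument to `execute` that returns the full `ExecutionResult`, including its `extensions`; treat that as something to confirm against the installed gql version. The helper below, `warn_if_deprecated`, is hypothetical: it turns such an extension into a Python warning, assuming the server puts the message under a hypothetical `deprecation` key in `extensions`.

```python
import warnings
from typing import Any, Mapping, Optional


def warn_if_deprecated(extensions: Optional[Mapping[str, Any]]) -> Optional[str]:
    """Emit a FutureWarning if the server flagged this call as deprecated.

    Returns the message, or None when there is nothing to report.
    """
    if not extensions:
        return None
    # Hypothetical key; it must match whatever the server writes.
    message = extensions.get("deprecation")
    if not message:
        return None
    text = str(message)
    # FutureWarning is shown by default, unlike DeprecationWarning,
    # so end users of the client actually see it.
    warnings.warn(text, FutureWarning, stacklevel=2)
    return text
```

Inside list_metadata_by_version this could look like `result = await gql_session.execute(query, get_execution_result=True)` followed by `warn_if_deprecated(result.extensions)` before reading `result.data['listMetadata']`. FutureWarning is chosen over DeprecationWarning because Python's default filters hide DeprecationWarning outside `__main__`, and the audience here is end users of the client.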