0

I am working on a hook that will process a batch of http requests asynchronously.

The hook should be able to do the following:

  • Send requests concurrently up to a max concurrency
  • Process requests in the order of the original request list
  • Raise an exception when an error threshold is reached, i.e. if error_threshold = 1, raise exception after a single failed request, block future requests

I am struggling to make my hook both asynchronous and processing requests in order, but stop processing tasks when there are errors. I've tried various approaches like asyncio.as_completed, asyncio.gather, asyncio.wait, etc. but I always have to sacrifice one requirement.

Below is my best stab so far. I use asyncio.as_completed and create_task, so max_concurrency and order are working fine...but error handling is broken, and the hook continues to process requests after an exception is supposedly raised.

I'd appreciate some guidance here, as I'm brand new to async programming in python.

import asyncio
import logging
from typing import Callable, List

from airflow.hooks.base import BaseHook
from httpx import AsyncClient, Auth, Headers, HTTPStatusError, Response
from tenacity import AsyncRetrying, stop_after_attempt, wait_fixed
from tqdm import tqdm


class AsyncHttpHook(BaseHook):
    """
    Hook for sending async HTTP requests to a REST API using the httpx library.
    """

    def __init__(
        self,
        method: str = None,
        headers: Headers = None,
        auth: Auth = None,
        base_url: str = None,
        attempts: int = 3,
        retry_wait: float = 1,
        error_threshold: int = None,
        max_concurrency: int = 1,
    ):
        self.method = method
        self.headers = headers
        self.auth = auth
        self.base_url = base_url
        self.attempts = attempts
        self.retry_wait = retry_wait
        self.error_threshold = error_threshold
        self.max_concurrency = max_concurrency
        self.error_count = asyncio.Lock()
        self.semaphore = asyncio.Semaphore(max_concurrency)

    async def send_async_request(
        self,
        client: AsyncClient,
        endpoint: str,
        method: str,
        error_limit_callback: Callable,
    ) -> Response:
        """
        Send an asynchronous request.
        If an error occurs, the method will retry the request up to max_error_count times.
        If the error count is equal to max_error_count, the method will call error_limit_callback.
        If error_limit_callback is None, the method will raise an exception.

        :param client: The HTTP client
        :type client: httpx.AsyncClient
        :param endpoint: The request URL
        :type endpoint: str
        :param method: The HTTP method
        :type method: str
        :param error_limit_callback: The callback function to call when the error count reaches the error threshold
        :type error_limit_callback: Callable

        :return: The response
        :rtype: httpx.Response
        """
        async with self.semaphore:
            try:
                async for attempt in AsyncRetrying(
                    wait=wait_fixed(self.retry_wait),
                    stop=stop_after_attempt(self.attempts),
                    reraise=True,
                ):
                    with attempt:
                        print(endpoint)
                        response = await client.request(
                            method=method,
                            url=endpoint,
                        )
                        response.raise_for_status()
                        return response
            except HTTPStatusError as e:
                logging.error(
                    'Received %s status: "%s" for url %s',
                    e.response.status_code,
                    e.response.text,
                    e.response.url,
                )
                async with self.error_count:
                    self.error_count._value += 1
                    if (
                        self.error_threshold
                        and self.error_count._value >= self.error_threshold
                    ):
                        if error_limit_callback:
                            error_limit_callback(e)
                        logging.error(
                            f"Error count reached max error count: {self.error_count._value}"
                        )
                        raise e
                # If error threshold is not reached, return the failed response
                return response

    async def prepare_async_requests(
        self,
        endpoints: List,
        method: str,
        headers: dict,
        error_limit_callback: Callable,
    ) -> List[Response]:
        """
        Prepare a list of requests asynchronously. This method will not send the requests.

        :param endpoints: The list of request URLs
        :type endpoints: List[str]
        :param method: The HTTP method
        :type method: str
        :param headers: The HTTP headers
        :type headers: dict
        :param error_limit_callback: The callback function to call when the error count reaches the error threshold
        :type error_limit_callback: Callable

        :return: The list of responses
        :rtype: List[Response]
        """
        async with AsyncClient(
            timeout=None, auth=self.auth, headers=headers, base_url=self.base_url
        ) as client:
            task_list = [
                asyncio.create_task(
                    self.send_async_request(
                        client, endpoint, method, error_limit_callback
                    )
                )
                for endpoint in endpoints
            ]
            responses = [
                await task
                for task in tqdm(
                    asyncio.as_completed(task_list),
                    total=len(task_list),
                    desc=f"Sending requests using max concurrency: {self.max_concurrency}...",
                )
            ]
            return responses

    def run(
        self,
        endpoints: List[str],
        method: str = None,
        headers: dict = None,
        error_limit_callback: Callable = None,
    ) -> List[Response]:
        """
        Send a list of requests asynchronously.
        Resets error count to 0, so be wary of calling this method in different threads.

        :param endpoints: The list of endpoints
        :type endpoints: List[str]
        :param method: The HTTP method
        :type method: str
        :param headers: The HTTP headers
        :type headers: dict
        :param error_limit_callback: The callback function to call when the error count reaches the error threshold
        :type error_limit_callback: Callable

        :return: The list of responses
        :rtype: List[httpx.Response]
        """
        self.error_count._value = 0
        if not method:
            method = self.method
        if not headers:
            headers = self.headers
        loop = asyncio.get_event_loop()
        responses = loop.run_until_complete(
            self.prepare_async_requests(
                endpoints,
                method,
                headers,
                error_limit_callback,
            )
        )
        return responses

if __name == '__main__':
    """
    Example run
    """
    hook = AsyncHttpHook(
        base_url="https://example.com",
        max_concurrency=5,
        error_threshold=1,
        )
    hook.run(['test/1', 'test/2', 'test/3', 'test/4', 'test/5'])
JTa
  • 181
  • 1
  • 12

0 Answers0