
I want to serve an application that processes data with Google's JAX framework, using Flask and Gunicorn.

When I run it with Flask alone, everything works fine. As soon as I run the application under Gunicorn, every JAX-related part causes the worker process to die without any exception being raised. I tried both the sync and the gthread worker classes, with the same result.
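For reference, here is a minimal sketch of a Gunicorn configuration that switches between the two worker classes mentioned above. The bind address, the worker and thread counts, the timeout value and the app:app module name are assumptions for illustration, not the original setup.

```python
# gunicorn.conf.py, start with: gunicorn -c gunicorn.conf.py app:app
# "app:app" is a hypothetical module:variable name for the Flask app.

bind = "127.0.0.1:8000"
workers = 2

# "sync" handles one request at a time per worker process;
# "gthread" runs `threads` request threads inside each worker process.
worker_class = "gthread"
threads = 4

# Assumption: the first JAX call includes JIT compilation, which may take
# longer than Gunicorn's default 30 s worker timeout. The master kills and
# restarts a worker that exceeds this timeout.
timeout = 120

# Verbose logging makes it easier to see why a worker exits.
loglevel = "debug"
```

Setting worker_class = "sync" and dropping threads gives the other configuration that was tried.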

To check whether JAX can handle multiprocessing and multithreading, I wrapped the same calls in a ThreadPoolExecutor and in a ProcessPoolExecutor, and both work flawlessly:

import logging
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed

import jax

from fit.optimization.vectorize import BatchNumpyInterface, batch_calculate_fit

logging.basicConfig(format="%(asctime)s | %(name)12.12s | %(message)s")
logger = logging.getLogger("Main")
logger.setLevel(logging.DEBUG)


def warmup():
    # Trigger the first (compiling) JAX call before the real workload.
    logger.debug("Warmup")
    data = BatchNumpyInterface.generate_dummy()
    batch_calculate_fit(data)
    logger.debug("Warmed up")


def run_fn():
    logger.debug("Creating data")
    data = BatchNumpyInterface.generate_dummy(100)

    logger.debug("Predicting %s in batches", 100)
    result = batch_calculate_fit(data)

    logger.debug("Done")
    return float(result[0][0]), float(result[1][0])


# The guard keeps child processes from re-running the pool setup
# when the "spawn" start method is used.
if __name__ == "__main__":
    # with ThreadPoolExecutor(max_workers=4) as executor:
    with ProcessPoolExecutor(max_workers=4) as executor:
        # Run the warmup task four times and wait for all of them;
        # result() re-raises any exception from a worker.
        warmups = [executor.submit(warmup) for _ in range(4)]
        for future in as_completed(warmups):
            future.result()

        futures = [executor.submit(run_fn) for _ in range(10)]
        for future in as_completed(futures):
            print(future.result())
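Gunicorn creates its workers by forking a master process, so one way to narrow this down is to pin the multiprocessing start method in the same harness and compare "fork" with "spawn" (a fresh interpreter per worker). A minimal sketch: fake_fit is a hypothetical stand-in for batch_calculate_fit so the snippet runs without the fit package; swap the real JAX call back in to reproduce the crash.

```python
import multiprocessing
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed


def fake_fit(n):
    # Hypothetical stand-in for batch_calculate_fit(); replace with the
    # real JAX call when reproducing the problem.
    return float(sum(i * i for i in range(n)))


def run_harness(executor, n_tasks=10, size=100):
    # Submit the same work n_tasks times and collect results as they finish.
    futures = [executor.submit(fake_fit, size) for _ in range(n_tasks)]
    return sorted(f.result() for f in as_completed(futures))


def run_in_processes(start_method, n_tasks=10, size=100):
    # "fork" copies the parent process, similar to how Gunicorn creates
    # workers; "spawn" starts each worker from a fresh interpreter.
    ctx = multiprocessing.get_context(start_method)
    with ProcessPoolExecutor(max_workers=4, mp_context=ctx) as executor:
        return run_harness(executor, n_tasks, size)


if __name__ == "__main__":
    # The guard is required for "spawn", which re-imports this module
    # in every child process.
    for method in ("fork", "spawn"):
        if method in multiprocessing.get_all_start_methods():
            print(method, run_in_processes(method)[0])
    with ThreadPoolExecutor(max_workers=4) as executor:
        print("threads", run_harness(executor)[0])
```

If the real JAX call fails only with "fork", and only after JAX has already run something in the parent before the pool is created, that would suggest state inherited across the fork rather than the objects being passed to the workers.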


While debugging, the application crashes every time I inspect a JAX DeviceArray. The same happens when I step over the first JAX calculation.

Any help would be much appreciated!

Flo Win
  • https://github.com/google/jax/issues/3691 might be relevant. – jakevdp Nov 02 '20 at 15:14
  • Thanks, but I think that only applies to forking objects. That is not the case for me: the objects are passed into the forked processes. The only thing I can think of is that some XLA-related state initialized beforehand gets forked. – Flo Win Nov 02 '20 at 19:47
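The hypothesis in the last comment can be checked from Gunicorn itself with server hooks. A minimal diagnostic sketch, assuming a gunicorn.conf.py config file: pre_fork runs in the master just before each worker is forked, and post_fork runs in the new worker. Importing JAX does not necessarily initialize the XLA backend, so finding jax in sys.modules only shows that the master imported it, not that XLA state was created there.

```python
# gunicorn.conf.py (diagnostic sketch)
import os
import sys

# Import the application in each worker after the fork (Gunicorn's default);
# with preload_app = True the app module is imported in the master first.
preload_app = False


def pre_fork(server, worker):
    # Runs in the master process just before a worker is forked.
    if "jax" in sys.modules:
        server.log.warning("jax already imported in master (pid %s) before fork", os.getpid())
    else:
        server.log.info("jax not imported in master (pid %s)", os.getpid())


def post_fork(server, worker):
    # Runs in the new worker process right after the fork.
    server.log.info("Worker spawned (pid: %s)", worker.pid)
```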
