
How can I properly serialize MetPy units (based on Pint) so they work with Dask Distributed? As far as I understand, Dask Distributed automatically pickles data for ease of transfer, but it fails to pickle the MetPy units needed for the computation. Error produced: `TypeError: cannot pickle 'weakref' object`. MWE below.

import metpy.calc as mpcalc
from metpy.units import units
from dask.distributed import Client, LocalCluster

def calculate_dewpoint(vapor_pressure):
    dewpoint = mpcalc.dewpoint(vapor_pressure * units('hPa'))
    return dewpoint


cluster = LocalCluster()
client = Client(cluster)

## works 
vapor_pressure = 5
dp = calculate_dewpoint(vapor_pressure)
print(dp)

## doesn't work
vapor_pressure = 5
dp_future = client.submit(calculate_dewpoint, vapor_pressure)
dp = dp_future.result()

EDIT: Added full traceback.

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
/glade/work/cbecker/miniconda3/envs/risk/lib/python3.8/site-packages/distributed/worker.py in dumps_function(func)
   4271         with _cache_lock:
-> 4272             result = cache_dumps[func]
   4273     except KeyError:

/glade/work/cbecker/miniconda3/envs/risk/lib/python3.8/site-packages/distributed/utils.py in __getitem__(self, key)
   1362     def __getitem__(self, key):
-> 1363         value = super().__getitem__(key)
   1364         self.data.move_to_end(key)

/glade/work/cbecker/miniconda3/envs/risk/lib/python3.8/collections/__init__.py in __getitem__(self, key)
   1009             return self.__class__.__missing__(self, key)
-> 1010         raise KeyError(key)
   1011     def __setitem__(self, key, item): self.data[key] = item

KeyError: <function calculate_dewpoint at 0x2ad5e010f0d0>

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
/glade/work/cbecker/miniconda3/envs/risk/lib/python3.8/site-packages/distributed/protocol/pickle.py in dumps(x, buffer_callback, protocol)
     52                 buffers.clear()
---> 53                 result = cloudpickle.dumps(x, **dump_kwargs)
     54         elif not _always_use_pickle_for(x) and b"__main__" in result:

/glade/work/cbecker/miniconda3/envs/risk/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py in dumps(obj, protocol, buffer_callback)
     72             )
---> 73             cp.dump(obj)
     74             return file.getvalue()

/glade/work/cbecker/miniconda3/envs/risk/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py in dump(self, obj)
    601         try:
--> 602             return Pickler.dump(self, obj)
    603         except RuntimeError as e:

TypeError: cannot pickle 'weakref' object

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
/glade/scratch/cbecker/ipykernel_272346/952144406.py in <module>
     20 ## doesn't work
     21 vapor_pressure = 5
---> 22 dp_future = client.submit(calculate_dewpoint, vapor_pressure)
     23 dp = dp_future.result()

/glade/work/cbecker/miniconda3/envs/risk/lib/python3.8/site-packages/distributed/client.py in submit(self, func, key, workers, resources, retries, priority, fifo_timeout, allow_other_workers, actor, actors, pure, *args, **kwargs)
   1577             dsk = {skey: (func,) + tuple(args)}
   1578 
-> 1579         futures = self._graph_to_futures(
   1580             dsk,
   1581             [skey],

/glade/work/cbecker/miniconda3/envs/risk/lib/python3.8/site-packages/distributed/client.py in _graph_to_futures(self, dsk, keys, workers, allow_other_workers, priority, user_priority, resources, retries, fifo_timeout, actors)
   2628             # Pack the high level graph before sending it to the scheduler
   2629             keyset = set(keys)
-> 2630             dsk = dsk.__dask_distributed_pack__(self, keyset, annotations)
   2631 
   2632             # Create futures before sending graph (helps avoid contention)

/glade/work/cbecker/miniconda3/envs/risk/lib/python3.8/site-packages/dask/highlevelgraph.py in __dask_distributed_pack__(self, client, client_keys, annotations)
   1074                     "__module__": layer.__module__,
   1075                     "__name__": type(layer).__name__,
-> 1076                     "state": layer.__dask_distributed_pack__(
   1077                         self.get_all_external_keys(),
   1078                         self.key_dependencies,

/glade/work/cbecker/miniconda3/envs/risk/lib/python3.8/site-packages/dask/highlevelgraph.py in __dask_distributed_pack__(self, all_hlg_keys, known_key_dependencies, client, client_keys)
    432             for k, v in dsk.items()
    433         }
--> 434         dsk = toolz.valmap(dumps_task, dsk)
    435         return {"dsk": dsk, "dependencies": dependencies}
    436 

/glade/work/cbecker/miniconda3/envs/risk/lib/python3.8/site-packages/cytoolz/dicttoolz.pyx in cytoolz.dicttoolz.valmap()

/glade/work/cbecker/miniconda3/envs/risk/lib/python3.8/site-packages/cytoolz/dicttoolz.pyx in cytoolz.dicttoolz.valmap()

/glade/work/cbecker/miniconda3/envs/risk/lib/python3.8/site-packages/distributed/worker.py in dumps_task(task)
   4308             return d
   4309         elif not any(map(_maybe_complex, task[1:])):
-> 4310             return {"function": dumps_function(task[0]), "args": warn_dumps(task[1:])}
   4311     return to_serialize(task)
   4312 

/glade/work/cbecker/miniconda3/envs/risk/lib/python3.8/site-packages/distributed/worker.py in dumps_function(func)
   4272             result = cache_dumps[func]
   4273     except KeyError:
-> 4274         result = pickle.dumps(func, protocol=4)
   4275         if len(result) < 100000:
   4276             with _cache_lock:

/glade/work/cbecker/miniconda3/envs/risk/lib/python3.8/site-packages/distributed/protocol/pickle.py in dumps(x, buffer_callback, protocol)
     58         try:
     59             buffers.clear()
---> 60             result = cloudpickle.dumps(x, **dump_kwargs)
     61         except Exception as e:
     62             logger.info("Failed to serialize %s. Exception: %s", x, e)

/glade/work/cbecker/miniconda3/envs/risk/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py in dumps(obj, protocol, buffer_callback)
     71                 file, protocol=protocol, buffer_callback=buffer_callback
     72             )
---> 73             cp.dump(obj)
     74             return file.getvalue()
     75 

/glade/work/cbecker/miniconda3/envs/risk/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py in dump(self, obj)
    600     def dump(self, obj):
    601         try:
--> 602             return Pickler.dump(self, obj)
    603         except RuntimeError as e:
    604             if "recursion" in e.args[0]:

TypeError: cannot pickle 'weakref' object
  • MetPy doesn't support Dask yet; this issue is relevant: https://github.com/Unidata/MetPy/issues/1479 – Michael Delgado Aug 13 '22 at 02:37
  • MetPy is interested in working with Dask, so I'm curious how this is actually failing. Can you please post the full error traceback so we can see where it's failing? From your sample code, it's unclear where `weakref` is even involved here. – DopplerShift Aug 15 '22 at 21:53
  • @DopplerShift Added traceback. I found somewhat of a hack to get around it: importing `metpy.units` inside each function, which works as long as you only send magnitudes from function to function and never units (which dask can't pickle properly). – bwc Aug 15 '22 at 22:12

1 Answer


So there's an issue where (I think) cloudpickle is trying to serialize the unit registry, or the units themselves, in order to transfer them between processes. To work around this, try moving the import of `units` inside the function (though this may cause some other problems):

def calculate_dewpoint(vapor_pressure):
    # Importing here means the unit registry is created in the worker
    # process rather than being pickled into the task closure
    from metpy.units import units
    return mpcalc.dewpoint(vapor_pressure * units('hPa'))