I am using Dask with deep learning models and large arrays. In particular, I am trying to get predictions using map_blocks:

array.map_blocks(model.predict)

Since serialization takes time, is there a way to avoid the model being serialized for each call?

1 Answer

At least two broad options:

  1. Scatter the model to the workers once (as suggested by @Nick Becker in the comments above) and pass the resulting future to map_blocks; rough pseudocode (a runnable sketch follows this list):
fut_model = client.scatter(model, broadcast=True)
array.map_blocks(lambda chunk, m: m.predict(chunk), fut_model, dtype=array.dtype)
  2. Wrap the prediction in a function that loads the model inside each task; here's the pseudocode (runnable sketch also below):
import pickle

def model_predict(X_chunk):
    with open(mypath, "rb") as f:  # or another loading method
        model = pickle.load(f)
    return model.predict(X_chunk)

array.map_blocks(model_predict)
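
Putting option 1 together, a minimal runnable sketch, assuming a local cluster and a hypothetical DummyModel standing in for a real deep learning model. Note that passing the future as an argument (rather than calling .predict on the future directly) lets Dask resolve it to the actual model on the worker, so the model travels to each worker once rather than once per chunk:

import dask.array as da
from dask.distributed import Client

class DummyModel:
    # hypothetical stand-in for a trained model; predict preserves shape
    def predict(self, X):
        return X * 2

if __name__ == "__main__":
    client = Client()  # local cluster, for illustration

    # ship the model to every worker once
    fut_model = client.scatter(DummyModel(), broadcast=True)

    X = da.random.random((10_000, 20), chunks=(1_000, 20))

    # Dask resolves fut_model to the actual model on the worker
    preds = X.map_blocks(lambda chunk, m: m.predict(chunk), fut_model, dtype=X.dtype)
    print(preds.compute().shape)  # (10000, 20)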

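And a runnable sketch of option 2, again with a hypothetical DummyModel and a made-up model.pkl path; on a real cluster the file must live on storage the workers can reach, and the model's class must be importable there:

import pickle
import dask.array as da

class DummyModel:
    # hypothetical stand-in for a trained model
    def predict(self, X):
        return X * 2

model_path = "model.pkl"  # hypothetical path to the saved model
with open(model_path, "wb") as f:
    pickle.dump(DummyModel(), f)

def model_predict(chunk):
    # each task loads the model from disk instead of receiving it
    # via serialization from the client
    with open(model_path, "rb") as f:
        model = pickle.load(f)
    return model.predict(chunk)

X = da.random.random((4_000, 10), chunks=(1_000, 10))
preds = X.map_blocks(model_predict, dtype=X.dtype)
print(preds.compute().shape)  # (4000, 10)

This trades serialization cost for repeated disk I/O; caching the loaded model on each worker (e.g. a functools.lru_cache-decorated loader) can amortize that.
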
You might also find these answers relevant.

SultanOrazbayev