I am building a computation graph with Dask. Some of the intermediate values are used multiple times, but I would like each of those calculations to run only once. I must be making a trivial mistake, because that's not what happens. Here is a minimal example:
```python
import dask
print(dask.__version__)  # '1.0.0'


class SumGenerator(object):
    def __init__(self):
        self.sources = []

    def register(self, source):
        self.sources += [source]

    def generate(self):
        # Call every registered delayed function and sum the results lazily
        return dask.delayed(sum)([s() for s in self.sources])


sg = SumGenerator()


@dask.delayed
def source1():
    return 1.


@dask.delayed
def source2():
    return 2.


@dask.delayed
def source3():
    return 3.


sg.register(source1)
sg.register(source1)  # registered twice on purpose
sg.register(source2)
sg.register(source3)

sg.generate().visualize()  # requires graphviz
```
Unfortunately I can't post the resulting graph image, but it shows two separate nodes for `source1`, which was registered twice, so the function is called twice. I would like it to be called once, with the result remembered and added to the sum twice. What is the correct way to do that?
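Since the graph image can't be shown, the duplication can also be checked by comparing task keys. The sketch below is a hedged illustration, assuming a reasonably recent Dask release (it was not checked against 1.0.0 specifically). A function wrapped with plain `@dask.delayed` is impure by default, so each call gets a fresh random key and therefore its own node. Calling the function once and reusing the resulting `Delayed` object, or declaring the function pure with `dask.delayed(pure=True)`, should make repeated uses share a single key. The helper `count_tasks` and the function `pure_source1` are hypothetical names introduced only for this illustration.

```python
import dask


@dask.delayed
def source1():
    return 1.


# Hypothetical pure variant: identical calls should hash to the same key.
@dask.delayed(pure=True)
def pure_source1():
    return 1.


def count_tasks(delayed_obj, prefix):
    """Hypothetical helper: count graph keys whose name starts with `prefix-`."""
    graph = dict(delayed_obj.__dask_graph__())
    return sum(1 for key in graph if str(key).startswith(prefix + "-"))


# Impure by default: every call gets a new random key, hence a new node.
a, b = source1(), source1()
print(a.key == b.key)  # False

duplicated = dask.delayed(sum)([source1(), source1()])
print(count_tasks(duplicated, "source1"))  # 2

# Option 1: call once and reuse the same Delayed object.
shared = source1()
reused = dask.delayed(sum)([shared, shared])
print(count_tasks(reused, "source1"))  # 1

# Option 2: mark the function as pure so identical calls share one key.
deduped = dask.delayed(sum)([pure_source1(), pure_source1()])
print(count_tasks(deduped, "pure_source1"))  # 1

print(reused.compute(), deduped.compute())  # 2.0 2.0
```

Which option fits better likely depends on whether the sources are really free of side effects: `pure=True` tells Dask it may treat repeated calls with the same arguments as one task.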