I've been using the torchdata library (v0.6.0) to construct datapipes for my machine learning model, but I can't seem to figure out how torchdata expects its users to make a train/test split.
Supposing I have a datapipe dp, my first attempt was to use the Sampler datapipe along with a torch.utils.data.SubsetRandomSampler (which is what I expected from this part of the documentation), but this doesn't work the way I would've thought:
>>> from torch.utils.data import SubsetRandomSampler
>>> from torchdata.datapipes.iter import Sampler
>>> from torchdata.datapipes.map import SequenceWrapper
>>> dp = SequenceWrapper(range(5))
>>> Sampler(dp, SubsetRandomSampler([0, 1, 2]))
Traceback (most recent call last):
...
TypeError: 'SubsetRandomSampler' object is not callable
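The traceback makes more sense after peeking at the source: Sampler takes a sampler class (default SequentialSampler) plus sampler_args/sampler_kwargs, and constructs the sampler itself, roughly like this (my paraphrase of torch.utils.data.datapipes.iter.combinatorics, so possibly inexact):

# approximately what Sampler.__init__ does with the sampler you pass it:
self.sampler = sampler(*sampler_args, data_source=datapipe, **sampler_kwargs)

So an already-constructed SubsetRandomSampler instance isn't callable, and even passing the class doesn't look viable, since SubsetRandomSampler wants indices rather than a data_source.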
Maybe torchdata has its own samplers I'm not familiar with.
The only other way I can think of doing this would be to use a Demultiplexer, but this feels unclean to me, because we have to enumerate and then "de-enumerate":
>>> train_len = int(len(dp) * 0.8)  # cast to int so it's a proper index boundary
>>> dp1, dp2 = dp.enumerate().demux(num_instances=2, classifier_fn=lambda x: x[0] >= train_len)
>>> dp1, dp2 = (d.map(lambda x: x[1]) for d in (dp1, dp2))
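For what it's worth, this does give the 80/20 split I expect; here dp needs to be an IterDataPipe (e.g. IterableWrapper(range(5))) rather than the SequenceWrapper above, since enumerate and demux are IterDataPipe operations:

>>> list(dp1), list(dp2)
([0, 1, 2, 3], [4])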
Is there an "intended" way of doing this with torchdata which I'm missing?
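As a point of reference, plain torch.utils.data.random_split does run on a map-style pipe like the SequenceWrapper above, but it hands back Subset objects rather than datapipes, which is why I'm hoping there's a native option:

>>> import torch
>>> from torch.utils.data import random_split
>>> train, test = random_split(dp, [4, 1], generator=torch.Generator().manual_seed(0))
>>> type(train)
<class 'torch.utils.data.dataset.Subset'>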