
I've been using the torchdata library (v0.6.0) to construct datapipes for my machine learning model, but I can't seem to figure out how torchdata expects its users to make a train/test split.

Supposing I have a datapipe dp, my first attempt was to use the Sampler datapipe along with a torch.utils.data.SubsetRandomSampler (which is what I expected from this part of the documentation), but this doesn't work the way I expected:

>>> dp = SequenceWrapper(range(5))
>>> Sampler(dp,SubsetRandomSampler([0, 1, 2]))
Traceback (most recent call last):
...
TypeError: 'SubsetRandomSampler' object is not callable

Maybe torchdata has its own samplers I'm not familiar with.

The only other way I can think of doing this would be to use a Demultiplexer, but this feels unclean to me, because we have to enumerate then "de-enumerate":

>>> train_len = len(dp) * 0.8
>>> dp1, dp2 = dp.enumerate().demux(num_instances=2, classifier_fn=lambda x: x[0] >= train_len)
>>> dp1, dp2 = (d.map(lambda x: x[1]) for d in (dp1, dp2))
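For reference, here is what that demux pipeline computes, written as ordinary Python over a list (demux_by_position is a hypothetical helper, not a torchdata API). It also makes visible that this approach splits by position, not at random, so the source would need shuffling first to get a random split:

```python
def demux_by_position(items, train_len):
    """Plain-Python mirror of dp.enumerate().demux(...) followed by the two maps."""
    outputs = ([], [])  # index 0 -> first pipe (train), index 1 -> second pipe (test)
    for idx, value in enumerate(items):   # like dp.enumerate()
        branch = int(idx >= train_len)    # classifier_fn: False -> 0, True -> 1
        outputs[branch].append((idx, value))
    # "de-enumerate": keep only the value, like d.map(lambda x: x[1])
    return tuple([value for _, value in out] for out in outputs)


data = range(5)
train, test = demux_by_position(data, train_len=len(data) * 0.8)
print(train, test)  # [0, 1, 2, 3] [4]
```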

Is there an "intended" way of doing this with torchdata which I'm missing?

user3002473

1 Answer


PyTorch's tutorial on using DataPipes answers the question:

import torchdata.datapipes.iter as pipes
from torch.utils.data import DataLoader, random_split

# initialize DataPipe with dummy values
dp = pipes.IterableWrapper(range(5))

# compute train/test sizes for an 80/20 split
train_size = int(len(dp) * 0.8)
test_size = len(dp) - train_size

# split dataset into train/test sets
train_dataset, test_dataset = random_split(dp, [train_size, test_size])

# create batch sizes for train and test dataloaders
# (loading everything into memory, no minibatches)
batch_train, batch_test = len(train_dataset), len(test_dataset)

# create train and test dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=batch_train, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_test)

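# train model (model and loss_fn are placeholders defined elsewhere;
# this loop assumes each sample is an (input, target) pair)
for i, j in train_dataloader:
    ...
    preds = model(i)
    loss = loss_fn(preds, j)
    ...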
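torch.utils.data.random_split does not copy any data: it permutes the indices 0..len(dp)-1 and gives each split a slice of that permutation, which is why the source needs len() and integer indexing. The plain-Python sketch below mirrors that index bookkeeping; index_random_split is a hypothetical helper, and it uses random.Random instead of PyTorch's generator, so the actual permutation will differ from what PyTorch produces:

```python
import random


def index_random_split(n, lengths, seed):
    """Shuffle indices 0..n-1 once, then cut them into consecutive chunks."""
    if sum(lengths) != n:
        raise ValueError("lengths must sum to n")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # one shared permutation
    chunks, offset = [], 0
    for length in lengths:
        chunks.append(indices[offset:offset + length])
        offset += length
    return chunks


data = list(range(5))
train_size = int(len(data) * 0.8)   # 4
test_size = len(data) - train_size  # 1
train_idx, test_idx = index_random_split(len(data), [train_size, test_size], seed=42)

# each split only stores indices; values are looked up in the original data
train = [data[i] for i in train_idx]
test = [data[i] for i in test_idx]
```

Because both chunks come from one permutation, the splits are disjoint and together cover every element exactly once.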

If you want to use the built-in random_split() method of an iterable-style DataPipe instead:

train_dataset, test_dataset = dp.random_split(total_length=len(dp), weights={"train": 0.8, "test": 0.2}, seed=42)

train_dataloader = DataLoader(train_dataset, batch_size=batch_train, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_test)
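The seed argument matters because each output pipe has to agree on which elements belong to it. The plain-Python sketch below illustrates that idea under a simple assumption (one seeded, weighted draw per element, replayed for every output); weighted_split is a hypothetical helper and this is not torchdata's actual RandomSplitter implementation:

```python
import random


def weighted_split(items, weights, seed, target):
    """Yield the items assigned to `target` under a seeded weighted draw.

    Every call replays the same random stream, so calling it once per key
    yields disjoint outputs that together cover every item.
    """
    rng = random.Random(seed)  # fresh generator per pass, same seed each time
    keys = list(weights)
    probs = [weights[k] for k in keys]
    for item in items:
        # one draw per item; identical across passes thanks to the shared seed
        if rng.choices(keys, weights=probs)[0] == target:
            yield item


items = range(10)
weights = {"train": 0.8, "test": 0.2}
train = list(weighted_split(items, weights, seed=42, target="train"))
test = list(weighted_split(items, weights, seed=42, target="test"))
```

Since every item gets an independent draw in this sketch, the split sizes only approximate the 80/20 weights on a small input.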

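Edit: You can also reach the underlying DataPipe through the split dataset's .dataset attribute (this works with both IterDataPipe and MapDataPipe). Note that .dataset is the original, unsplit DataPipe; the positions that make up each split are stored in .indices:

train_dp = train_dataset.dataset    # full source DataPipe, not just the train split
test_dp = test_dataset.dataset      # same object as train_dp
train_idx = train_dataset.indices   # positions of the train samples in the source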
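random_split returns torch.utils.data.Subset objects, which behave as views: they keep a reference to the whole source plus a list of selected positions. A plain-Python stand-in makes that concrete (ToySubset is a hypothetical class modeled on that behavior, not the real Subset):

```python
class ToySubset:
    """Hypothetical stand-in for torch.utils.data.Subset: a view, not a copy."""

    def __init__(self, dataset, indices):
        self.dataset = dataset  # the full, unsplit source
        self.indices = indices  # positions that belong to this split

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # translate a split-local position into a position in the source
        return self.dataset[self.indices[idx]]


source = ["a", "b", "c", "d", "e"]
train_view = ToySubset(source, [3, 0, 4, 1])

print(len(train_view.dataset))                          # 5: the whole source
print([train_view[i] for i in range(len(train_view))])  # ['d', 'a', 'e', 'b']
```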

If you want the output of the random_split() function to be a MapDataPipe, you can always wrap the outputs in SequenceWrapper():

from torchdata.datapipes.map import SequenceWrapper

train_dataset, test_dataset = random_split(dp, [train_size, test_size])
train_mdp = SequenceWrapper(train_dataset)
test_mdp = SequenceWrapper(test_dataset)

And same idea with IterDataPipe:

train_dataset, test_dataset = random_split(dp, [train_size, test_size])
train_idp = pipes.IterableWrapper(train_dataset)
test_idp = pipes.IterableWrapper(test_dataset)
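Iterating over the split works even though Subset defines only __len__ and __getitem__ and no __iter__: Python's iter() falls back to calling __getitem__ with 0, 1, 2 and so on until an IndexError is raised, and looking up a position past the end of the split's index list raises exactly that. A plain-Python sketch of the fallback, with a hypothetical IndexOnly class standing in for Subset:

```python
class IndexOnly:
    """Hypothetical stand-in for Subset: supports len() and indexing, no __iter__."""

    def __init__(self, items, indices):
        self.items = items
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # self.indices[idx] raises IndexError once idx runs past the end,
        # which is the signal iter() uses to stop
        return self.items[self.indices[idx]]


view = IndexOnly(list(range(10)), [7, 2, 5])
print(list(view))  # [7, 2, 5]
```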
Djinn
  • sorry to deselect your answer, but I just realized the `random_split` method only exists for `IterDataPipe`s, not `MapDataPipe`s, the latter of which I work with more often. – user3002473 May 05 '23 at 00:10
  • It's the exact same solution. Did you not try to use the answer? I've tried it with `IterDataPipe` and with `MapDataPipe`. If you meant to ask a different question, then ask another question, but my solution (which is based on the official PyTorch solution, using the minimal reproducible example that /you/ provided) answers your question. Respectfully, mark that answer as accepted. – Djinn May 05 '23 at 06:22
  • ah I see, so the first solution does indeed work even for MapDataPipes, whereas they lack a `random_split` method themselves. I was just hoping that the result of the random_split would again be two datapipes rather than torch `Dataset`s. Thank you for the help! I didn't mean to offend by temporarily deselecting your answer, I was merely looking to open a discussion :) – user3002473 May 05 '23 at 13:39
  • A discussion is perfectly fine, I'd love to have one :) it's just that deselecting the answer that actually offered the solution could cause people with that exact question to skip over it, thinking it may not work. A discussion and an accepted answer aren't mutually exclusive. No offense received over here :) – Djinn May 05 '23 at 14:34
  • If you want the outputs of the random split to be MapDataPipes, just wrap the outputs in `SequenceWrapper()`. I've updated the answer to include the code. – Djinn May 05 '23 at 14:35