5

I'm working with text and use torchtext.data.Dataset. Creating the dataset takes a considerable amount of time. For just running the program this is still acceptable, but I would like to debug the torch code for the neural network. If Python is started in debug mode, the dataset creation takes roughly 20 minutes (!!), and that's just to get a working environment where I can step through the neural network code in the debugger.

I would like to save the Dataset, for example with pickle. This sample code is taken from here, but I removed everything that is not necessary for this example:

from torchtext import data
from fastai.nlp import *

PATH = 'data/aclImdb/'

TRN_PATH = 'train/all/'
VAL_PATH = 'test/all/'
TRN = f'{PATH}{TRN_PATH}'
VAL = f'{PATH}{VAL_PATH}'

TEXT = data.Field(lower=True, tokenize="spacy")

bs = 64;
bptt = 70

FILES = dict(train=TRN_PATH, validation=VAL_PATH, test=VAL_PATH)
md = LanguageModelData.from_text_files(PATH, TEXT, **FILES, bs=bs, bptt=bptt, min_freq=10)

with open("md.pkl", "wb") as file:
    pickle.dump(md, file)

To run the code, you need the aclImdb dataset, which can be downloaded from here. Extract it into a data/ folder next to this code snippet. The code produces an error in the last line, where pickle is used:

Traceback (most recent call last):
  File "/home/lhk/programming/fastai_sandbox/lesson4-imdb2.py", line 27, in <module>
    pickle.dump(md, file)
TypeError: 'generator' object is not callable

The samples from fastai often use dill instead of pickle. But that doesn't work for me either.

lhk

4 Answers

2

I came up with the following functions for myself:

import dill
from pathlib import Path

import torch
from torchtext.data import Dataset

def save_dataset(dataset, path):
    if not isinstance(path, Path):
        path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    torch.save(dataset.examples, path/"examples.pkl", pickle_module=dill)
    torch.save(dataset.fields, path/"fields.pkl", pickle_module=dill)

def load_dataset(path):
    if not isinstance(path, Path):
        path = Path(path)
    examples = torch.load(path/"examples.pkl", pickle_module=dill)
    fields = torch.load(path/"fields.pkl", pickle_module=dill)
    return Dataset(examples, fields)

Note that the actual objects could be a bit different: for example, if you save a TabularDataset, load_dataset returns an instance of the Dataset class. This is unlikely to affect the data pipeline but may require extra diligence in tests. If you use a custom tokenizer, it must be serializable as well (e.g. no lambda functions).
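To illustrate the point about serializable tokenizers, here is a minimal sketch using only the standard pickle module. The names `whitespace_tokenizer` and `lambda_tokenizer` are made up for this example, and behavior under dill may differ:

```python
import pickle


def whitespace_tokenizer(text):
    """Module-level function: pickle stores it by its qualified name."""
    return text.split()


# Anonymous function: there is no importable name for pickle to record.
lambda_tokenizer = lambda text: text.split()

# A named, module-level function survives a round trip.
restored = pickle.loads(pickle.dumps(whitespace_tokenizer))
print(restored("save the dataset"))

# The standard pickle module refuses the lambda.
try:
    pickle.dumps(lambda_tokenizer)
    lambda_ok = True
except (pickle.PicklingError, AttributeError) as err:
    lambda_ok = False
    print(f"pickle refused the lambda: {type(err).__name__}")
```

Defining the tokenizer as a plain named function in an importable module is the simplest way to avoid this class of problem.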

Nikita
1

You can use dill instead of pickle. It works for me. You can save a torchtext Field like

import dill
from torchtext import data

TEXT = data.Field(sequential=True, tokenize=tokenizer, lower=True, fix_length=200, batch_first=True)
with open("model/TEXT.Field", "wb") as f:
    dill.dump(TEXT, f)

And load a Field like

with open("model/TEXT.Field", "rb") as f:
    TEXT = dill.load(f)

Official support is under development; you can follow https://github.com/pytorch/text/issues/451 and https://github.com/pytorch/text/issues/73.

chj
1

You can always use pickle to dump the objects, but keep in mind that dumping a list of dictionaries or Field objects is not handled by the module, so it is best to decompose the list first.

To store the Dataset object in a pickle file for easy loading later:

import pickle

def save_to_pickle(dataSetObject, PATH):
    with open(PATH, 'wb') as output:
        # Dump each example's attribute dict as a separate pickle
        for i in dataSetObject:
            pickle.dump(vars(i), output, pickle.HIGHEST_PROTOCOL)

The toughest part is yet to come: loading the pickle file ;)

First, try to look for all field names and field attributes and then go for the kill

To load the pickle file into the DataSetObject

import pickle
from torchtext.data import Dataset, Example

def load_pickle(PATH, FIELDNAMES, FIELD):
    dataList = []
    with open(PATH, "rb") as input_file:
        while True:
            try:
                # Each pickle in the file is one example's attribute dictionary
                inputInstance = pickle.load(input_file)
                # Collect the two field values in the order given by FIELDNAMES
                dataInstance = [inputInstance[FIELDNAMES[0]], inputInstance[FIELDNAMES[1]]]
                # Build an Example from the values and append it to the list
                dataList.append(Example.fromlist(dataInstance, fields=FIELD))
            except EOFError:
                # No more pickles left in the file
                break

    # Finally, create the Dataset object from the examples and the fields
    exampleListObject = Dataset(dataList, fields=FIELD)
    return exampleListObject

This hackish solution has worked in my case, hope you will find it useful in your case too.
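
For anyone who wants to try the write-many, read-until-EOFError pattern in isolation, here is a small sketch that runs without torchtext. The helper names `dump_records` and `load_records` and the sample dictionaries are made up; the dictionaries only stand in for what `vars(example)` would return:

```python
import os
import pickle
import tempfile


def dump_records(records, path):
    """Write each record as its own pickle, back to back in one file."""
    with open(path, "wb") as out:
        for record in records:
            pickle.dump(record, out, pickle.HIGHEST_PROTOCOL)


def load_records(path):
    """Read pickles one at a time until the file is exhausted."""
    loaded = []
    with open(path, "rb") as src:
        while True:
            try:
                loaded.append(pickle.load(src))
            except EOFError:
                break
    return loaded


# Hypothetical stand-ins for vars(example): one dict of token lists per example.
records = [
    {"src": ["a", "good", "movie"], "trg": ["pos"]},
    {"src": ["too", "long"], "trg": ["neg"]},
]

path = os.path.join(tempfile.mkdtemp(), "examples.pkl")
dump_records(records, path)
restored = load_records(path)
print(restored == records)
```

Because every record is a separate pickle, neither side has to hold one huge serialized blob in memory.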

Btw any suggestion is welcome :).

Atul Sahay
0

The pickle/dill approach is fine if your dataset is small. But if you are working with large datasets, I wouldn't recommend it, as it will be too slow.

I simply save the examples (iteratively) as JSON strings. The reason is that saving the whole Dataset object takes a lot of time, plus you need serialization tricks such as dill, which make the serialization even slower.

Moreover, these serializers take a lot of memory (some of them even create copies of the dataset) and if they start making use of the swap memory, you're done. That process is gonna take so long that you will probably terminate it before it finishes.

Therefore, I ended up with the following approach:

  1. Iterate over the examples
  2. Convert each example (on the fly) to a JSON-string
  3. Write that JSON-string into a text file (one sample per line)
  4. When loading, add the examples to the Dataset object along with the fields

import json
import time

def save_examples(dataset, savepath):
    with open(savepath, 'w') as f:
        # Save the number of elements (not strictly needed)
        total = len(dataset.examples)
        f.write(json.dumps(total))
        f.write("\n")

        # Save the elements, one JSON string per line
        for pair in dataset.examples:
            data = [pair.src, pair.trg]
            f.write(json.dumps(data))
            f.write("\n")


def load_examples(filename):
    start = time.time()
    examples = []
    with open(filename, 'r') as f:
        # Read the number of elements (not strictly needed)
        total = json.loads(f.readline())

        # Load the elements
        for i in range(total):
            line = f.readline()
            example = json.loads(line)
            # example = data.Example.fromlist(example, fields)  # Create Example obj. (you can do it here or later)
            examples.append(example)

    # Report how long loading took
    end = time.time()
    print(end - start)
    return examples

Then, you can simply rebuild the dataset by:

from torchtext import data

# Define fields
SRC = data.Field(...)
TRG = data.Field(...)
fields = [('src', SRC), ('trg', TRG)]

# Load examples from JSON and convert them to Example objects
examples = load_examples(filename)
examples = [data.Example.fromlist(d, fields) for d in examples]

# Build dataset
mydataset = data.Dataset(examples, fields)
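
One detail worth checking with the one-JSON-string-per-line format used above is what comes back after a round trip. The following sketch runs without torchtext; the `pairs` list is invented sample data standing in for tokenized examples:

```python
import json
import os
import tempfile

# Hypothetical (src, trg) token pairs standing in for torchtext examples.
pairs = [
    (["ein", "haus"], ["a", "house"]),
    (["guten", "tag"], ["good", "day"]),
]

path = os.path.join(tempfile.mkdtemp(), "examples.jsonl")

# One JSON document per line, written while iterating over the data.
with open(path, "w", encoding="utf-8") as f:
    for src, trg in pairs:
        f.write(json.dumps([src, trg]) + "\n")

# Read back line by line; iterating the file needs no element count.
with open(path, "r", encoding="utf-8") as f:
    restored = [json.loads(line) for line in f]

print(restored[0])        # [['ein', 'haus'], ['a', 'house']]
print(type(restored[0]))  # <class 'list'>: the original pair was a tuple
```

Token lists of strings come back unchanged, but JSON has no tuple type, so tuples are returned as lists. If your examples hold anything other than strings, numbers, lists and dicts, you would need to convert it yourself.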

My choice of JSON over pickle, dill, msgpack, etc. is not arbitrary.

I did some tests and these are the results:

Dataset size: 2x (1,960,641)

Saving times:
- Pickle/Dill*: >30-45 min (...or froze my computer)

- MessagePack (iterative): 123.44 sec
  100%|██████████| 1960641/1960641 [02:03<00:00, 15906.52it/s]

- JSON (iterative): 16.33 sec
  100%|██████████| 1960641/1960641 [00:15<00:00, 125955.90it/s]

- JSON (bulk): 46.54 sec (memory problems)

Loading times:
 - Pickle/Dill*: -

 - MessagePack (iterative): 143.79 sec
   100%|██████████| 1960641/1960641 [02:23<00:00, 13635.20it/s]

 - JSON (iterative): 33.83 sec
   100%|██████████| 1960641/1960641 [00:33<00:00, 57956.28it/s] 

 - JSON (bulk): 27.43 sec

*Similar approach as the other answers
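
If you want to run a comparison like this on your own data, a timing harness could look roughly like the sketch below. It is not the benchmark that produced the numbers above: it uses small synthetic lists rather than torchtext examples, plain pickle rather than dill, and made-up helper names (`time_json_lines`, `time_pickle_stream`), so the absolute times and even the ranking may differ:

```python
import json
import os
import pickle
import tempfile
import time

# Synthetic data: hypothetical token pairs, far smaller than the dataset above.
samples = [[["tok"] * 20, ["tok"] * 20] for _ in range(50_000)]
workdir = tempfile.mkdtemp()


def time_json_lines(data, path):
    """Write one JSON document per line and return the elapsed seconds."""
    start = time.perf_counter()
    with open(path, "w", encoding="utf-8") as f:
        for item in data:
            f.write(json.dumps(item) + "\n")
    return time.perf_counter() - start


def time_pickle_stream(data, path):
    """Write one pickle per item into a single file and return the elapsed seconds."""
    start = time.perf_counter()
    with open(path, "wb") as f:
        for item in data:
            pickle.dump(item, f, pickle.HIGHEST_PROTOCOL)
    return time.perf_counter() - start


json_path = os.path.join(workdir, "samples.jsonl")
pickle_path = os.path.join(workdir, "samples.pkl")
json_seconds = time_json_lines(samples, json_path)
pickle_seconds = time_pickle_stream(samples, pickle_path)

print(f"JSON lines:    {json_seconds:.2f} s")
print(f"pickle stream: {pickle_seconds:.2f} s")
```

Swap `samples` for your real examples and add the loading side to get figures that are meaningful for your setup.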

Salva Carrión