When you call a Hugging Face tokenizer with return_overflowing_tokens=True, the result can contain multiple token sequences per input string. Therefore, when using Dataset.map to go from strings to token sequences, you need to remove the original columns, because the rows no longer correspond 1:1.

For my application, I need to continue to reference the original dataset's columns. How can I copy them over to the tokenized dataset?

For example:

# Pseudocode
ds['txt'] == ['The quick brown fox', 'jumped over the lazy hens']
ds['src'] == ['Nursery rhyme 1', 'Nursery rhyme 2']

tokenize(ds['txt'], return_overflowing_tokens=True) =>
[tokens for 'The quick brown'],
[tokens for 'fox'],
[tokens for 'jumped over'],
[tokens for 'the lazy hens'],

# I'd like a tokenized_ds to look like this:
tokenized_ds[0] = {txt: 'The quick brown fox', src: 'Nursery rhyme 1', tokens: [tokens for 'The quick brown']}
tokenized_ds[1] = {txt: 'The quick brown fox', src: 'Nursery rhyme 1', tokens: [tokens for 'fox']}


Some clarifications:

  1. Some of the columns that need to be preserved are strings. These are harder to preserve when you set the dataset format to tensors.

  2. The dataset will be batched via a DataLoader, and only the batches are handed off for processing. This means that the original, full dataset will not necessarily be available. That makes it hard to map back to the original dataset on demand, which is why I want to preserve the columns within the transformed dataset.

SRobertJames
1 Answer

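You can achieve that with the batched parameter of the map function. With batched=True, the mapped function receives a dict of lists and may return a different number of rows than it was given, as long as every returned column has the same length: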

from datasets import Dataset
from transformers import RobertaTokenizer

sample = {'txt': ['The quick brown fox', 'jumped over the lazy hens'], 'src': ['Nursery rhyme 1', 'Nursery rhyme 2']}

ds = Dataset.from_dict(sample)
t = RobertaTokenizer.from_pretrained('roberta-base')

print('ds before map')
for x in ds:
  print(x)

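def srobertjames_fn(samples):
  # Tokenize the whole batch; each text is cut to 5 tokens (special tokens included), the rest is kept as overflow
  encoded = t(samples['txt'], truncation=True, max_length=5, return_overflowing_tokens=True)
  # Repeat every original column twice: once for the overflow rows, once for the truncated rows
  out = {k: v * 2 for k, v in samples.items()}
  # Row order must match the duplicated columns: overflow rows first, then the truncated inputs
  out['tokens'] = encoded.overflowing_tokens + encoded.input_ids
  return out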

ds2 = ds.map(srobertjames_fn, batched=True)

print('ds after map')
for x in ds2:
  print(x)

Output:

ds before map
{'txt': 'The quick brown fox', 'src': 'Nursery rhyme 1'}
{'txt': 'jumped over the lazy hens', 'src': 'Nursery rhyme 2'}
ds after map
{'txt': 'The quick brown fox', 'src': 'Nursery rhyme 1', 'tokens': [23602]}
{'txt': 'jumped over the lazy hens', 'src': 'Nursery rhyme 2', 'tokens': [5, 22414, 37, 6852]}
{'txt': 'The quick brown fox', 'src': 'Nursery rhyme 1', 'tokens': [0, 133, 2119, 6219, 2]}
{'txt': 'jumped over the lazy hens', 'src': 'Nursery rhyme 2', 'tokens': [0, 267, 25844, 81, 2]}

Note that a short string in the txt column produces no overflowing tokens (i.e. an empty list). You can remove such rows with the following filter:

ds3 = ds2.filter(lambda x: len(x["tokens"]) != 0)
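
As a possible variation: the fast tokenizers (e.g. RobertaTokenizerFast) return each overflow chunk as its own row in input_ids, together with an overflow_to_sample_mapping list that records which input row each chunk came from. Under that assumption, the original columns can be repeated per chunk instead of being duplicated wholesale. The sketch below shows only the column-expansion step in plain Python; the helper name expand_columns and the example mapping are hypothetical, not part of either library.

```python
def expand_columns(batch, sample_mapping):
    """Repeat each original column value once per chunk.

    batch: dict of column name -> list of values (what a batched map function receives)
    sample_mapping: list where entry i is the index of the input row that produced chunk i
    """
    return {name: [values[i] for i in sample_mapping] for name, values in batch.items()}


batch = {
    "txt": ["The quick brown fox", "jumped over the lazy hens"],
    "src": ["Nursery rhyme 1", "Nursery rhyme 2"],
}

# Hypothetical mapping: two chunks came from row 0, then two chunks from row 1
sample_mapping = [0, 0, 1, 1]

expanded = expand_columns(batch, sample_mapping)
for txt, src in zip(expanded["txt"], expanded["src"]):
    print(src, "|", txt)
```

Inside a batched map function, sample_mapping would presumably come from encoded["overflow_to_sample_mapping"], and the function would return the expanded columns plus a tokens column set to encoded["input_ids"]. Because each chunk then carries its own copy of txt and src, no empty rows are produced and no filter step should be needed.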
cronoik