
I'm trying to set up a generalized Reinforcement Learning framework in PyTorch to take advantage of all the high-level utilities out there which leverage PyTorch Dataset and DataLoader, like Ignite or FastAI, but I've hit a blocker with the dynamic nature of Reinforcement Learning data:

  • Data items are generated from code, not read from a file, and they depend on previous actions and model results, so each nextItem call needs access to the model state.
  • Training episodes are not fixed length, so I need a dynamic batch size as well as a dynamic total dataset size. My preference would be to use a terminating-condition function instead of a fixed number. I could "possibly" do this with padding, as in NLP sentence processing, but that's a real hack.

My Google and StackOverflow searches so far have yielded zilch. Does anyone here know of existing solutions or workarounds for using DataLoader or Dataset with Reinforcement Learning? I hate to lose access to all the existing libraries out there that depend on them.

Ken Otwell

1 Answer


Here is one PyTorch-based framework and here is something from Facebook.

When it comes to your question (and noble quest, no doubt):

You could easily create a torch.utils.data.Dataset dependent on anything, including the model, something like this (pardon weak abstraction, it's just to prove a point):

import torch
from torch.utils.data import Dataset


class Environment(Dataset):
    def __init__(self, initial_state, actor: torch.nn.Module, max_interactions: int):
        self.current_state = initial_state
        self.actor: torch.nn.Module = actor
        self.max_interactions: int = max_interactions

    # Just ignore the index; each call advances the environment by one interaction
    def __getitem__(self, _):
        self.current_state = self.actor.update(self.current_state)
        return self.current_state.get_data()

    def __len__(self):
        return self.max_interactions

This assumes the torch.nn.Module-like network has some kind of update method that changes the state of the environment. All in all, it's just a Python structure, so you could model a lot of things with it.

You could specify max_interactions to be almost infinite, or change it on the fly with callbacks during training if needed (since __len__ will probably be called multiple times throughout the code). Environment could furthermore provide whole batches instead of single samples.
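To make the wiring concrete, here is a minimal usage sketch reusing the Environment class above. ToyState and ToyActor are hypothetical stand-ins for whatever your real state object and model look like; the only assumptions, as before, are an update method on the actor and a get_data method on the state:

import torch
from torch.utils.data import DataLoader


class ToyState:
    # Hypothetical state object exposing get_data(), as Environment expects
    def __init__(self, value: torch.Tensor):
        self.value = value

    def get_data(self) -> torch.Tensor:
        return self.value


class ToyActor(torch.nn.Module):
    # Hypothetical actor whose update() advances the environment state
    def __init__(self):
        super().__init__()
        self.policy = torch.nn.Linear(4, 4)

    def update(self, state: ToyState) -> ToyState:
        with torch.no_grad():
            return ToyState(self.policy(state.value))


env = Environment(ToyState(torch.zeros(4)), ToyActor(), max_interactions=100)
loader = DataLoader(env, batch_size=8)

for batch in loader:
    # each batch stacks consecutive environment states, shape (batch_size, 4)
    ...

Note that this only behaves as intended with num_workers=0 (the default): with worker processes the dataset gets copied, so mutations of self.current_state would not propagate back to the main process.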

torch.utils.data.DataLoader has a batch_sampler argument; there you could generate batches of varying length. Since the network does not depend on the first (batch) dimension, you can return any batch size you want from there as well.
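A minimal sketch of that idea: a batch_sampler only needs to be an iterable yielding lists of indices. The VariableBatchSampler name and the random sizing below are made up for illustration:

import random

from torch.utils.data import DataLoader


class VariableBatchSampler:
    # Yields lists of indices whose length varies from batch to batch
    def __init__(self, dataset_length: int, min_batch: int = 2, max_batch: int = 16):
        self.dataset_length = dataset_length
        self.min_batch = min_batch
        self.max_batch = max_batch

    def __iter__(self):
        index = 0
        while index < self.dataset_length:
            size = random.randint(self.min_batch, self.max_batch)
            yield list(range(index, min(index + size, self.dataset_length)))
            index += size


# e.g. wrapping the Environment defined above (or any map-style Dataset):
# loader = DataLoader(env, batch_sampler=VariableBatchSampler(len(env)))

In your case the sampler could instead consult the environment's terminating condition and close a batch when an episode ends, rather than drawing random sizes.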

BTW, padding should be used when individual samples have different lengths; a varying batch size has nothing to do with that.

Szymon Maszke
  • Thanks, Szymon - this is a decent approach. Kind of a hack, given that we don't really know how or where __len__ is called (is it in a for range?), but it's probably the best we can do. But especially thanks for the SLM link - that looks like really good work. I'm going to spend some time with it to make sure I'm not reinventing the wheel. – Ken Otwell Jul 30 '19 at 03:31