
I'm using MXNet for deep reinforcement learning. I have a simple generator that yields observations from a random walk through a game (from OpenAI Gym):

import mxnet as mx
from mxnet import nd  # explicit NDArray namespace (nd.array, ...) instead of wildcard imports

import gym

def random_walk(env_id):
    env, done = gym.make(env_id), True
    min_rew, max_rew = env.reward_range
    while True:
        if done:
            obs = env.reset()
        action = env.action_space.sample()
        obs, rew, done, info = env.step(action)

        # some preprocessing omitted...

        yield obs, rew, action # all converted to ndarrays now

I want to save this data to one big file containing rows of (observation, reward, action), so I can later easily load, shuffle, and batch them with MXNet.

Is it possible to do this using MXNet? If so, how?

matan tsuberi

1 Answer


Saving doesn't really sound like an MXNet task. If you can serialize the observation, reward and action to a string, you can use regular Python to create the file and write the data to it: https://docs.python.org/3.7/tutorial/inputoutput.html#reading-and-writing-files
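A minimal sketch of that idea in regular Python, assuming each row can be pickled (for example after converting NDArrays to NumPy with `.asnumpy()`): `pickle` writes every (observation, reward, action) row as its own binary record, so rows are appended one at a time and read back one at a time instead of being built into one huge string. `save_rows`, `load_rows` and `fake_walk` are hypothetical names; `fake_walk` only stands in for the question's `random_walk`.

```python
import os
import pickle
import tempfile
from itertools import islice


def save_rows(rows, path, n_rows):
    """Write the first n_rows (obs, reward, action) tuples, one pickle record each."""
    with open(path, "wb") as f:
        for row in islice(rows, n_rows):
            pickle.dump(row, f)


def load_rows(path):
    """Yield the saved rows back one at a time instead of reading them all at once."""
    with open(path, "rb") as f:
        while True:
            try:
                yield pickle.load(f)
            except EOFError:  # no more records in the file
                return


def fake_walk():  # hypothetical stand-in for the question's random_walk()
    step = 0
    while True:
        yield [step, step + 1], float(step), step % 2
        step += 1


path = os.path.join(tempfile.mkdtemp(), "walk.pkl")
save_rows(fake_walk(), path, n_rows=5)
rows = list(load_rows(path))
print(rows[2])
```

A pickle stream can only be read sequentially, so rows come back in the order they were written; shuffling still happens after loading, as described next.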

To shuffle and batch in MXNet, first load your file into a Python list using regular Python, then create a SimpleDataset or an ArrayDataset, depending on whether or not you have a separate Y list. Then pass the dataset object to a DataLoader, which can do the shuffling and batching for you.

Take a look here for a full example: https://mxnet.incubator.apache.org/tutorials/gluon/datasets.html

Sergei
  • The file can be really big and I don't ever want to hold it all in memory. Also, saving ndarrays as one big string doesn't seem like the best way. – matan tsuberi Jul 20 '18 at 22:09
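One way around both concerns, assuming every observation has the same fixed shape and the number of rows is chosen up front, is to stream each field into a binary `.npy` file with NumPy's `open_memmap` and later memory-map it back, so only the rows actually indexed are read from disk. `save_memmap`, `MemmapRows` and `fake_walk` are hypothetical names. `MemmapRows` implements `__getitem__` and `__len__`, which is the interface a `mxnet.gluon.data.Dataset` subclass must provide; making it inherit from that class should let a `DataLoader(..., shuffle=True)` batch it, but that step is an assumption and is not shown in the code.

```python
import os
import tempfile
from itertools import islice

import numpy as np


def save_memmap(rows, prefix, n_rows, obs_shape):
    """Stream n_rows (obs, reward, action) tuples into three on-disk .npy arrays."""
    obs = np.lib.format.open_memmap(prefix + "_obs.npy", mode="w+",
                                    dtype=np.float32, shape=(n_rows,) + tuple(obs_shape))
    rew = np.lib.format.open_memmap(prefix + "_rew.npy", mode="w+",
                                    dtype=np.float32, shape=(n_rows,))
    act = np.lib.format.open_memmap(prefix + "_act.npy", mode="w+",
                                    dtype=np.int64, shape=(n_rows,))
    for i, (o, r, a) in enumerate(islice(rows, n_rows)):
        obs[i], rew[i], act[i] = o, r, a  # each row is written into the mapped files
    for arr in (obs, rew, act):
        arr.flush()


class MemmapRows:
    """Random-access view over the saved rows; data is read from disk on demand."""

    def __init__(self, prefix):
        self.obs = np.load(prefix + "_obs.npy", mmap_mode="r")
        self.rew = np.load(prefix + "_rew.npy", mmap_mode="r")
        self.act = np.load(prefix + "_act.npy", mmap_mode="r")

    def __len__(self):
        return len(self.rew)

    def __getitem__(self, idx):
        return self.obs[idx], self.rew[idx], self.act[idx]


def fake_walk():  # hypothetical stand-in for the question's random_walk()
    step = 0
    while True:
        yield np.full(3, step, dtype=np.float32), float(step), step % 2
        step += 1


prefix = os.path.join(tempfile.mkdtemp(), "walk")
save_memmap(fake_walk(), prefix, n_rows=6, obs_shape=(3,))
data = MemmapRows(prefix)

# Shuffle indices, not data: only the requested rows are pulled from disk.
order = np.random.permutation(len(data))
batch = [data[i] for i in order[:4]]
print(len(data), data[5][0], data[5][2])
```

Shuffling a permutation of indices rather than the rows themselves keeps memory use proportional to the batch size instead of the file size.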