I'm using mxnet for deep reinforcement learning. I have a simple generator that yields observations from a random walk through a game (from OpenAI Gym):

    import mxnet as mx
    from mxnet import *
    from mxnet.ndarray import *
    import gym

    def random_walk(env_id):
        env, done = gym.make(env_id), True
        min_rew, max_rew = env.reward_range
        while True:
            # Start a new episode whenever the previous one has ended.
            if done:
                obs = env.reset()
            action = env.action_space.sample()
            obs, rew, done, info = env.step(action)
            # some preprocessing omitted...
            yield obs, rew, action  # all converted to ndarrays now

I want to save this data to one big file containing rows of (observation, reward, action), so that I can later easily load, shuffle, and batch them with mxnet.

Is this possible using mxnet? If so, how?
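
To make the goal concrete, the save, load, shuffle, and batch workflow described above could be sketched roughly as follows. This is only a sketch under stated assumptions: it uses NumPy's `.npz` format as a stand-in rather than any mxnet-specific storage, and `fake_walk`, `save_transitions`, `load_batches`, and the file name `walk.npz` are hypothetical names invented for illustration. With the real generator, each mxnet NDArray would first need converting with `.asnumpy()`.

```python
import os
import tempfile

import numpy as np


def fake_walk(n_steps, seed=0):
    """Hypothetical stand-in for random_walk: yields (obs, rew, action)."""
    rng = np.random.default_rng(seed)
    for _ in range(n_steps):
        obs = rng.random(4, dtype=np.float32)   # made-up 4-dim observation
        rew = np.float32(rng.random())          # made-up scalar reward
        action = np.int64(rng.integers(0, 2))   # made-up discrete action
        yield obs, rew, action


def save_transitions(path, transitions):
    """Stack the rows into three aligned arrays and write one .npz file."""
    obs, rew, act = zip(*transitions)
    np.savez(path, obs=np.stack(obs), rew=np.asarray(rew), action=np.asarray(act))


def load_batches(path, batch_size, seed=0):
    """Load the file, shuffle all rows with one permutation, yield batches."""
    with np.load(path) as data:
        obs, rew, act = data["obs"], data["rew"], data["action"]
    # One shared permutation keeps each (obs, rew, action) row together.
    order = np.random.default_rng(seed).permutation(len(obs))
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        yield obs[idx], rew[idx], act[idx]


path = os.path.join(tempfile.mkdtemp(), "walk.npz")
save_transitions(path, fake_walk(10))
batches = list(load_batches(path, batch_size=4))
```

Each yielded batch is a tuple of NumPy arrays, which could then be converted back with `mx.nd.array`. Whether mxnet's own tools (for example `mx.nd.save` or `mx.io.NDArrayIter`) would be a better fit for large files is exactly what the question is asking.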