I have a problem using DQN with a diagonal line and a sine wave as the price fluctuation. When the price goes up, there is a reward and it is colored green in the chart. When the price goes down and it is tagged red, the reward goes up. Please see this link. The DQN at that link learns much better than stable-baselines' DQN.
I am having difficulty with DQN even when using the diagonal line.
Sine wave: this would be good if the result were the opposite, i.e. green for rising and red for descending.
What I tried was changing the learning rate (values from 0.01 to 10) and setting epsilon to 1.
With PPO2, I can get a good result. For the sine wave:
model = PPO2(MlpPolicy, env, verbose=1,learning_rate=.01)
model.learn(total_timesteps=500000)
It worked for the diagonal line too!
This is my code. Just comment and uncomment the lines you need in order to test PPO2 vs. DQN.
from copy import deepcopy
import numpy as np
import pandas as pd
import gym
import gym_anytrading
from stable_baselines import A2C , DQN ,ACKTR
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.deepq.policies import MlpPolicy
import matplotlib.pyplot as plt
import math as m
from stable_baselines.deepq.policies import FeedForwardPolicy
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common import make_vec_env
from stable_baselines import PPO2
class CustomDQNPolicy(FeedForwardPolicy):
    """Feed-forward DQN policy with a fixed network configuration.

    Forwards all positional and keyword arguments to ``FeedForwardPolicy``
    and additionally pins ``layers=[64, 64, 64]``, ``layer_norm=True`` and
    ``feature_extraction="mlp"``.
    """

    def __init__(self, *args, **kwargs):
        # The three settings below are fixed by this class; passing any of
        # them again through ``kwargs`` raises a duplicate-keyword TypeError.
        super().__init__(
            *args,
            layers=[64, 64, 64],
            layer_norm=True,
            feature_extraction="mlp",
            **kwargs,
        )
def main():
    """Train PPO2 on a synthetic price series and plot the resulting trades.

    Builds a one-column ``Close`` DataFrame, trains on ``stocks-v0``
    environments wrapped in a ``DummyVecEnv``, then replays the trained
    model on one fresh environment until the episode ends and renders it.
    Swap in the commented alternatives to compare PPO2 against DQN or the
    other algorithms.
    """
    num_envs = 16

    # Alternative price inputs, kept for experimentation:
    # df = gym_anytrading.datasets.STOCKS_GOOGL.copy()
    # print(df)
    # prices = np.array([m.sin(x / 10) + 1 for x in range(0, 200, 1)])  # sine wave
    # print(prices)

    # Linearly increasing prices: 200 / 10.0 == 20.0, so this is 0.0 .. 19.0.
    prices = np.arange(200 / 10.0)
    df = pd.DataFrame(prices)
    print(df)
    df.columns = ['Close']
    print(df)

    window_size = 1
    # The frame starts at ``window_size`` and runs to the end of the data.
    frame_bound = (window_size, len(df))

    def env_maker():
        # Factory so the vectorized wrapper and the evaluation run each get
        # their own environment instance over the same DataFrame.
        return gym.make(
            'stocks-v0',
            df=df,
            window_size=window_size,
            frame_bound=frame_bound,
        )

    print(df)
    train_env = DummyVecEnv([env_maker] * num_envs)

    # Alternative models, kept for comparison:
    # policy_kwargs = dict(net_arch=[64, 'lstm', dict(vf=[128, 128, 128], pi=[64, 64])])
    # model = A2C('MlpLstmPolicy', train_env, verbose=1, policy_kwargs=policy_kwargs)
    # model = A2C(MlpPolicy, train_env, verbose=1, learning_rate=.01)
    # model = ACKTR(MlpPolicy, train_env, verbose=1, learning_rate=1)
    #
    # NOTE(review): stable-baselines 2 DQN appears to accept only a single
    # environment -- confirm, and set ``num_envs`` to 1 when testing it.
    # model = DQN(policy=CustomDQNPolicy, env=train_env, verbose=1,
    #             learning_rate=.01,
    #             buffer_size=10000,
    #             double_q=False,
    #             exploration_final_eps=1,
    #             prioritized_replay=True)

    # ``MlpPolicy`` resolves to the ``stable_baselines.common.policies`` one
    # here, because that import comes after (and shadows) the deepq import.
    model = PPO2(MlpPolicy, train_env, verbose=1, learning_rate=.01)
    model.learn(total_timesteps=100000)
    # model.save('nzdusdDQN')

    # Evaluate on a single, non-vectorized environment.
    eval_env = env_maker()
    observation = eval_env.reset()
    done = False
    while not done:
        # action = eval_env.action_space.sample()  # random-agent baseline
        action, _ = model.predict(observation)
        observation, _, done, info = eval_env.step(action)
        # eval_env.render()
    print("info:", info)

    plt.figure(figsize=(16, 6))
    eval_env.render_all()
    plt.show()


if __name__ == '__main__':
    main()
System Info: Describe the characteristics of your environment:
- Windows 10
- tensorflow 1.15.0
- stable-baselines 2.10.2a0 dev_0
- gym-anytrading 1.2.0
Conda list:
PS E:\ML\reinforcementlearning\tradeorig> conda list
# packages in environment at C:\anaconda\envs\gymorig:
#
# Name Version Build Channel
_tflow_select 2.2.0 eigen
absl-py 0.11.0 py37haa95532_0
alabaster 0.7.12 py37_0
apipkg 1.5 pypi_0 pypi
argh 0.26.2 py37_0
asn1crypto 1.4.0 py_0
astor 0.8.1 py37_0
astroid 2.4.2 py37_0
async_generator 1.10 py37h28b3542_0
atari-py 0.2.6 pypi_0 pypi
atomicwrites 1.4.0 py_0
attrs 20.2.0 py_0
autopep8 1.5.4 py_0
babel 2.8.0 py_0
backcall 0.2.0 py_0
bcrypt 3.2.0 py37he774522_0
blas 1.0 mkl
bleach 3.2.1 py_0
brotlipy 0.7.0 py37he774522_1000
ca-certificates 2020.10.14 0
certifi 2020.6.20 py37haa95532_2
cffi 1.14.3 py37h7a1dbc1_0
chardet 3.0.4 py37_1003
cloudpickle 1.6.0 py_0
colorama 0.4.4 py_0
coverage 5.3 pypi_0 pypi
cryptography 2.3.1 py37h74b6da3_0
cycler 0.10.0 pypi_0 pypi
decorator 4.4.2 py_0
defusedxml 0.6.0 py_0
diff-match-patch 20200713 py_0
docutils 0.16 py37_1
entrypoints 0.3 py37_0
execnet 1.7.1 pypi_0 pypi
flake8 3.8.4 py_0
future 0.18.2 py37_1
gast 0.2.2 py37_0
google-pasta 0.2.0 py_0
grpcio 1.14.1 py37h5c4b210_0
gym 0.17.3 pypi_0 pypi
gym-anytrading 1.2.0 pypi_0 pypi
h5py 2.10.0 py37h5e291fa_0
hdf5 1.10.4 h7ebc959_0
icc_rt 2019.0.0 h0cc432a_1
icu 58.2 ha925a31_3
idna 2.10 py_0
imagesize 1.2.0 py_0
importlab 0.5.1 pypi_0 pypi
importlib-metadata 2.0.0 py_1
importlib_metadata 2.0.0 1
iniconfig 1.0.1 pypi_0 pypi
intel-openmp 2020.2 254
intervaltree 3.1.0 py_0
ipykernel 5.3.4 py37h5ca1d4c_0
ipython 7.18.1 py37h5ca1d4c_0
ipython_genutils 0.2.0 py37_0
isort 5.6.4 py_0
jedi 0.17.1 py37_0
jinja2 2.11.2 py_0
joblib 0.17.0 pypi_0 pypi
jpeg 9b hb83a4c4_2
jsonschema 3.2.0 py_2
jupyter_client 6.1.7 py_0
jupyter_core 4.6.3 py37_0
jupyterlab_pygments 0.1.2 py_0
keras-applications 1.0.8 py_1
keras-base 2.3.1 py37_0
keras-preprocessing 1.1.0 py_1
keyring 21.4.0 py37_1
kiwisolver 1.2.0 pypi_0 pypi
lazy-object-proxy 1.4.3 py37he774522_0
libpng 1.6.37 h2a8f88b_0
libprotobuf 3.13.0.1 h200bbdf_0
libsodium 1.0.18 h62dcd97_0
libspatialindex 1.9.3 h33f27b4_0
livereload 2.6.3 pypi_0 pypi
lxml 4.5.2 pypi_0 pypi
markdown 3.3.2 py37_0
markupsafe 1.1.1 py37hfa6e2cd_1
matplotlib 3.3.2 pypi_0 pypi
mccabe 0.6.1 py37_1
mistune 0.8.4 py37hfa6e2cd_1001
mkl 2020.2 256
mkl-service 2.3.0 py37hb782905_0
mkl_fft 1.2.0 py37h45dec08_0
mkl_random 1.1.1 py37h47e9c7a_0
mpi4py 3.0.3 pypi_0 pypi
msgpack 1.0.0 pypi_0 pypi
multitasking 0.0.9 pypi_0 pypi
nbclient 0.5.1 py_0
nbconvert 6.0.7 py37_0
nbformat 5.0.8 py_0
nest-asyncio 1.4.1 py_0
networkx 2.5 pypi_0 pypi
ninja 1.10.0.post2 pypi_0 pypi
numpy 1.19.2 py37hadc3359_0
numpy-base 1.19.2 py37ha3acd2a_0
numpydoc 1.1.0 py_0
opencv-python 4.4.0.44 pypi_0 pypi
openssl 1.0.2u he774522_0
opt_einsum 3.1.0 py_0
packaging 20.4 py_0
pandas 1.1.3 py37ha925a31_0
pandoc 2.11 h9490d1a_0
pandocfilters 1.4.2 py37_1
paramiko 2.4.2 py37_0
parso 0.7.0 py_0
pathtools 0.1.2 py_1
pexpect 4.8.0 py37_1
pickleshare 0.7.5 py37_1001
pillow 7.2.0 pypi_0 pypi
pip 20.2.4 py37_0
pluggy 0.13.1 py37_0
prompt-toolkit 3.0.8 py_0
protobuf 3.13.0.1 py37ha925a31_1
psutil 5.7.2 py37he774522_0
py 1.9.0 pypi_0 pypi
pyasn1 0.4.8 py_0
pycodestyle 2.6.0 py_0
pycparser 2.20 py_2
pydocstyle 5.1.1 py_0
pyflakes 2.2.0 py_0
pyglet 1.5.0 pypi_0 pypi
pygments 2.7.1 py_0
pylint 2.6.0 py37_0
pynacl 1.4.0 py37h62dcd97_1
pyopenssl 19.0.0 py37_0
pyparsing 2.4.7 py_0
pyqt 5.6.0 py37ha878b3d_6
pyreadline 2.1 py37_1
pyrsistent 0.17.3 py37he774522_0
pysocks 1.7.1 py37_1
pytest 6.1.1 pypi_0 pypi
pytest-cov 2.10.1 pypi_0 pypi
pytest-env 0.6.2 pypi_0 pypi
pytest-forked 1.3.0 pypi_0 pypi
pytest-xdist 2.1.0 pypi_0 pypi
python 3.7.1 h33f27b4_4
python-dateutil 2.8.1 py_0
python-jsonrpc-server 0.4.0 py_0
python-language-server 0.35.1 py_0
pytype 2020.9.29 pypi_0 pypi
pytz 2020.1 py_0
pywin32 227 py37he774522_1
pywin32-ctypes 0.2.0 py37_1001
pyyaml 5.3.1 pypi_0 pypi
pyzmq 19.0.2 py37ha925a31_1
qdarkstyle 2.8.1 py_0
qt 5.6.2 vc14h6f8c307_12
qtawesome 1.0.1 py_0
qtconsole 4.7.7 py_0
qtpy 1.9.0 py_0
quantstats 0.0.25 pypi_0 pypi
requests 2.24.0 py_0
rope 0.18.0 py_0
rtree 0.9.4 py37h21ff451_1
ruamel-yaml 0.16.12 pypi_0 pypi
ruamel-yaml-clib 0.2.2 pypi_0 pypi
scipy 1.5.2 py37h9439919_0
seaborn 0.11.0 pypi_0 pypi
setuptools 50.3.0 py37h9490d1a_1
sip 4.18.1 py37h6538335_2
six 1.15.0 py_0
snowballstemmer 2.0.0 py_0
sortedcontainers 2.2.2 py_0
sphinx 3.2.1 py_0
sphinx-autobuild 2020.9.1 pypi_0 pypi
sphinx-rtd-theme 0.5.0 pypi_0 pypi
sphinxcontrib-applehelp 1.0.2 py_0
sphinxcontrib-devhelp 1.0.2 py_0
sphinxcontrib-htmlhelp 1.0.3 py_0
sphinxcontrib-jsmath 1.0.1 py_0
sphinxcontrib-qthelp 1.0.3 py_0
sphinxcontrib-serializinghtml 1.1.4 py_0
spyder 4.1.5 py37_0
spyder-kernels 1.9.4 py37_0
sqlite 3.33.0 h2a8f88b_0
stable-baselines 2.10.2a0 dev_0 <develop>
tabulate 0.8.7 pypi_0 pypi
tensorboard 2.0.0 pyhb38c66f_1
tensorflow 1.15.0 eigen_py37h9f89a44_0
tensorflow-base 1.15.0 eigen_py37h07d2309_0
tensorflow-estimator 1.15.1 pyh2649769_0
termcolor 1.1.0 py37_1
testpath 0.4.4 py_0
toml 0.10.1 py_0
tornado 6.0.4 py37he774522_1
traitlets 5.0.5 py_0
typed-ast 1.4.1 py37he774522_0
ujson 4.0.1 py37ha925a31_0
urllib3 1.25.11 py_0
vc 14.1 h0510ff6_4
vs2015_runtime 14.16.27012 hf0eaf9b_3
watchdog 0.10.3 py37_0
wcwidth 0.2.5 py_0
webencodings 0.5.1 py37_1
werkzeug 0.16.1 py_0
wheel 0.35.1 py_0
win_inet_pton 1.1.0 py37_0
wincertstore 0.2 py37_0
wrapt 1.11.2 py37he774522_0
yaml 0.2.5 he774522_0
yapf 0.30.0 py_0
yfinance 0.1.55 pypi_0 pypi
zeromq 4.3.2 ha925a31_3
zipp 3.3.1 py_0
zlib 1.2.11 h62dcd97_4