I have a problem using DQN with a diagonal line and a sine wave as the price series. When the price goes up there is a reward, and the step is colored green in the chart. When the price goes down and the step is tagged red, the reward also goes up. Please see this link. The DQN in that link learns much better than the stable-baselines DQN.

I am having difficulty with DQN even on the diagonal line.

DQN diagonal line

Sine wave: this would be a good result if the colors were reversed, with green for rising prices and red for falling prices.

Sin Wave line

What I tried: changing the learning rate from 0.01 to 10, and setting epsilon to 1.
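
For context on the epsilon setting, here is a minimal sketch of what a final epsilon of 1 implies. It assumes a linear epsilon schedule of the kind the stable-baselines DQN exposes through `exploration_fraction`, `exploration_initial_eps` and `exploration_final_eps`. The helpers `linear_epsilon` and `epsilon_greedy` are hypothetical names written for this illustration, not library functions.

```python
import random


def linear_epsilon(step, total_steps, fraction=0.1, initial_eps=1.0, final_eps=0.02):
    """Anneal epsilon linearly over the first `fraction` of training, then hold it."""
    anneal_steps = max(1, int(fraction * total_steps))
    progress = min(step / anneal_steps, 1.0)
    return initial_eps + progress * (final_eps - initial_eps)


def epsilon_greedy(q_values, epsilon, rng):
    """Take a random action with probability epsilon, otherwise the greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


rng = random.Random(0)
total = 100_000
checkpoints = (0, 5_000, 50_000)

# final_eps=1.0: epsilon never leaves 1, so every training action is random.
eps_always_random = [linear_epsilon(s, total, final_eps=1.0) for s in checkpoints]

# A small final_eps: epsilon decays, so the agent starts exploiting its Q-values.
eps_decaying = [linear_epsilon(s, total, final_eps=0.02) for s in checkpoints]

print(eps_always_random)
print(eps_decaying)
```

If the schedule really behaves like this sketch, a final epsilon of 1 would keep the agent acting randomly for the whole training run, so it may be worth retrying DQN with a small final epsilon before comparing it with PPO2.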

With PPO2 I get a good result. For the sine wave:

model = PPO2(MlpPolicy, env, verbose=1,learning_rate=.01)
model.learn(total_timesteps=500000)

Sin Wave line using PPO2

It worked for the diagonal line too!

Diagonal Line

This is my code. Comment and uncomment lines as needed to test PPO2 vs DQN.

from copy import deepcopy
import numpy as np
import pandas as pd

import gym
import gym_anytrading


from stable_baselines import A2C, ACKTR, DQN, PPO2
from stable_baselines.common import make_vec_env
from stable_baselines.common.vec_env import DummyVecEnv
# Actor-critic MlpPolicy, used by PPO2 / A2C / ACKTR below.
# (The deepq MlpPolicy import was dropped: this import shadowed it anyway.)
from stable_baselines.common.policies import MlpPolicy
# DQN has its own policy classes; FeedForwardPolicy is the base of CustomDQNPolicy.
from stable_baselines.deepq.policies import FeedForwardPolicy
import matplotlib.pyplot as plt
import math as m

class CustomDQNPolicy(FeedForwardPolicy):
    def __init__(self, *args, **kwargs):
        super(CustomDQNPolicy, self).__init__(*args, **kwargs,
                                              layers=[64,64,64],
                                              layer_norm=True,
                                              feature_extraction="mlp")

def main():
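    # Number of parallel environments for PPO2 / A2C / ACKTR.
    # Note: as far as I know, DQN in stable-baselines 2.x does not support
    # multiple environments, so set this to 1 when testing DQN.
    n_cpu = 16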
    # df = gym_anytrading.datasets.STOCKS_GOOGL.copy()
    # print(df)
    
    
    # arraysin =[]
    # for x in range(0,200,1):
    #     arraysin = np.append(arraysin,(m.sin(x/10)+1))
        
    
    # print(arraysin)
    
    arraysin = np.arange(200/10.0) #linearly increasing prices
    
    df = pd.DataFrame(arraysin)
    
    
    
# # convert the column (it's a string) to datetime type
#     datetime_series = pd.to_datetime(df['date_of_birth'])

# # create datetime index passing the datetime series
#     datetime_index = pd.DatetimeIndex(datetime_series.values) 
    df = pd.DataFrame(arraysin)
    print(df)
    df.columns=['Close']
    # df=df.set_index(datetime_index)
    window_size = 1
    print(df)
    start_index = window_size
    end_index = len(df)

    env_maker = lambda: gym.make(
        'stocks-v0',
        df = df,
        window_size = window_size,
        frame_bound = (start_index, end_index)
    )
    print(df) 
    env = DummyVecEnv([env_maker for _ in range(n_cpu)])

    # policy_kwargs = dict(net_arch=[64, 'lstm', dict(vf=[128, 128, 128], pi=[64, 64])])
    # model = A2C('MlpLstmPolicy', env, verbose=1, policy_kwargs=policy_kwargs)
    
    
    # model = A2C(MlpPolicy, env, verbose=1,learning_rate=.01)
    # model = ACKTR(MlpPolicy, env, verbose=1,learning_rate=1)
    model = PPO2(MlpPolicy, env, verbose=1,learning_rate=.01)
    
    # model = DQN(policy=CustomDQNPolicy,env=env, verbose=1,
    #         learning_rate= .01,
    #         buffer_size= 10000,
    #         double_q = False,
    #         exploration_final_eps= 1,
    #         prioritized_replay= True)

    model.learn(total_timesteps=100000)
    # model.save('nzdusdDQN') 
    env = env_maker()
    observation = env.reset()

    while True:
        # observation = observation[np.newaxis, ...]

        # action = env.action_space.sample()
        action, _states = model.predict(observation)
        observation, reward, done, info = env.step(action)

        # env.render()
        if done:
            print("info:", info)
            break

    # for e in env.envs:
    #     plt.figure(figsize=(16, 6))
    #     e.render_all()
    #     plt.show()
    plt.figure(figsize=(16, 6))
    env.render_all()
    plt.show()    


if __name__ == '__main__':
    main()

System info:

  1. Windows 10
  2. tensorflow 1.15.0
  3. stable-baselines 2.10.2a0 dev_0
  4. gym-anytrading 1.2.0

Conda list:

PS E:\ML\reinforcementlearning\tradeorig> conda list
# packages in environment at C:\anaconda\envs\gymorig:
#
# Name                    Version                   Build  Channel
_tflow_select             2.2.0                     eigen         
absl-py                   0.11.0           py37haa95532_0         
alabaster                 0.7.12                   py37_0         
apipkg                    1.5                      pypi_0    pypi 
argh                      0.26.2                   py37_0         
asn1crypto                1.4.0                      py_0         
astor                     0.8.1                    py37_0         
astroid                   2.4.2                    py37_0         
async_generator           1.10             py37h28b3542_0         
atari-py                  0.2.6                    pypi_0    pypi
atomicwrites              1.4.0                      py_0
attrs                     20.2.0                     py_0
autopep8                  1.5.4                      py_0
babel                     2.8.0                      py_0
backcall                  0.2.0                      py_0
bcrypt                    3.2.0            py37he774522_0
blas                      1.0                         mkl
bleach                    3.2.1                      py_0
brotlipy                  0.7.0           py37he774522_1000
ca-certificates           2020.10.14                    0
certifi                   2020.6.20        py37haa95532_2
cffi                      1.14.3           py37h7a1dbc1_0
chardet                   3.0.4                 py37_1003
cloudpickle               1.6.0                      py_0
colorama                  0.4.4                      py_0
coverage                  5.3                      pypi_0    pypi
cryptography              2.3.1            py37h74b6da3_0
cycler                    0.10.0                   pypi_0    pypi
decorator                 4.4.2                      py_0
defusedxml                0.6.0                      py_0
diff-match-patch          20200713                   py_0
docutils                  0.16                     py37_1
entrypoints               0.3                      py37_0
execnet                   1.7.1                    pypi_0    pypi
flake8                    3.8.4                      py_0
future                    0.18.2                   py37_1
gast                      0.2.2                    py37_0
google-pasta              0.2.0                      py_0
grpcio                    1.14.1           py37h5c4b210_0
gym                       0.17.3                   pypi_0    pypi
gym-anytrading            1.2.0                    pypi_0    pypi
h5py                      2.10.0           py37h5e291fa_0
hdf5                      1.10.4               h7ebc959_0
icc_rt                    2019.0.0             h0cc432a_1
icu                       58.2                 ha925a31_3
idna                      2.10                       py_0
imagesize                 1.2.0                      py_0
importlab                 0.5.1                    pypi_0    pypi
importlib-metadata        2.0.0                      py_1
importlib_metadata        2.0.0                         1
iniconfig                 1.0.1                    pypi_0    pypi
intel-openmp              2020.2                      254
intervaltree              3.1.0                      py_0
ipykernel                 5.3.4            py37h5ca1d4c_0
ipython                   7.18.1           py37h5ca1d4c_0
ipython_genutils          0.2.0                    py37_0
isort                     5.6.4                      py_0
jedi                      0.17.1                   py37_0
jinja2                    2.11.2                     py_0
joblib                    0.17.0                   pypi_0    pypi
jpeg                      9b                   hb83a4c4_2
jsonschema                3.2.0                      py_2
jupyter_client            6.1.7                      py_0
jupyter_core              4.6.3                    py37_0
jupyterlab_pygments       0.1.2                      py_0
keras-applications        1.0.8                      py_1
keras-base                2.3.1                    py37_0
keras-preprocessing       1.1.0                      py_1
keyring                   21.4.0                   py37_1
kiwisolver                1.2.0                    pypi_0    pypi
lazy-object-proxy         1.4.3            py37he774522_0
libpng                    1.6.37               h2a8f88b_0
libprotobuf               3.13.0.1             h200bbdf_0
libsodium                 1.0.18               h62dcd97_0
libspatialindex           1.9.3                h33f27b4_0
livereload                2.6.3                    pypi_0    pypi
lxml                      4.5.2                    pypi_0    pypi
markdown                  3.3.2                    py37_0
markupsafe                1.1.1            py37hfa6e2cd_1
matplotlib                3.3.2                    pypi_0    pypi
mccabe                    0.6.1                    py37_1
mistune                   0.8.4           py37hfa6e2cd_1001
mkl                       2020.2                      256
mkl-service               2.3.0            py37hb782905_0
mkl_fft                   1.2.0            py37h45dec08_0
mkl_random                1.1.1            py37h47e9c7a_0
mpi4py                    3.0.3                    pypi_0    pypi
msgpack                   1.0.0                    pypi_0    pypi
multitasking              0.0.9                    pypi_0    pypi
nbclient                  0.5.1                      py_0
nbconvert                 6.0.7                    py37_0
nbformat                  5.0.8                      py_0
nest-asyncio              1.4.1                      py_0
networkx                  2.5                      pypi_0    pypi
ninja                     1.10.0.post2             pypi_0    pypi
numpy                     1.19.2           py37hadc3359_0
numpy-base                1.19.2           py37ha3acd2a_0
numpydoc                  1.1.0                      py_0
opencv-python             4.4.0.44                 pypi_0    pypi
openssl                   1.0.2u               he774522_0
opt_einsum                3.1.0                      py_0
packaging                 20.4                       py_0
pandas                    1.1.3            py37ha925a31_0
pandoc                    2.11                 h9490d1a_0
pandocfilters             1.4.2                    py37_1
paramiko                  2.4.2                    py37_0
parso                     0.7.0                      py_0
pathtools                 0.1.2                      py_1
pexpect                   4.8.0                    py37_1
pickleshare               0.7.5                 py37_1001
pillow                    7.2.0                    pypi_0    pypi
pip                       20.2.4                   py37_0
pluggy                    0.13.1                   py37_0
prompt-toolkit            3.0.8                      py_0
protobuf                  3.13.0.1         py37ha925a31_1
psutil                    5.7.2            py37he774522_0
py                        1.9.0                    pypi_0    pypi
pyasn1                    0.4.8                      py_0
pycodestyle               2.6.0                      py_0
pycparser                 2.20                       py_2
pydocstyle                5.1.1                      py_0
pyflakes                  2.2.0                      py_0
pyglet                    1.5.0                    pypi_0    pypi
pygments                  2.7.1                      py_0
pylint                    2.6.0                    py37_0
pynacl                    1.4.0            py37h62dcd97_1
pyopenssl                 19.0.0                   py37_0
pyparsing                 2.4.7                      py_0
pyqt                      5.6.0            py37ha878b3d_6
pyreadline                2.1                      py37_1
pyrsistent                0.17.3           py37he774522_0
pysocks                   1.7.1                    py37_1
pytest                    6.1.1                    pypi_0    pypi
pytest-cov                2.10.1                   pypi_0    pypi
pytest-env                0.6.2                    pypi_0    pypi
pytest-forked             1.3.0                    pypi_0    pypi
pytest-xdist              2.1.0                    pypi_0    pypi
python                    3.7.1                h33f27b4_4
python-dateutil           2.8.1                      py_0
python-jsonrpc-server     0.4.0                      py_0
python-language-server    0.35.1                     py_0
pytype                    2020.9.29                pypi_0    pypi
pytz                      2020.1                     py_0
pywin32                   227              py37he774522_1
pywin32-ctypes            0.2.0                 py37_1001
pyyaml                    5.3.1                    pypi_0    pypi
pyzmq                     19.0.2           py37ha925a31_1
qdarkstyle                2.8.1                      py_0
qt                        5.6.2           vc14h6f8c307_12
qtawesome                 1.0.1                      py_0
qtconsole                 4.7.7                      py_0
qtpy                      1.9.0                      py_0
quantstats                0.0.25                   pypi_0    pypi
requests                  2.24.0                     py_0
rope                      0.18.0                     py_0
rtree                     0.9.4            py37h21ff451_1
ruamel-yaml               0.16.12                  pypi_0    pypi
ruamel-yaml-clib          0.2.2                    pypi_0    pypi
scipy                     1.5.2            py37h9439919_0
seaborn                   0.11.0                   pypi_0    pypi
setuptools                50.3.0           py37h9490d1a_1
sip                       4.18.1           py37h6538335_2
six                       1.15.0                     py_0
snowballstemmer           2.0.0                      py_0
sortedcontainers          2.2.2                      py_0
sphinx                    3.2.1                      py_0
sphinx-autobuild          2020.9.1                 pypi_0    pypi
sphinx-rtd-theme          0.5.0                    pypi_0    pypi
sphinxcontrib-applehelp   1.0.2                      py_0
sphinxcontrib-devhelp     1.0.2                      py_0
sphinxcontrib-htmlhelp    1.0.3                      py_0
sphinxcontrib-jsmath      1.0.1                      py_0
sphinxcontrib-qthelp      1.0.3                      py_0
sphinxcontrib-serializinghtml 1.1.4                      py_0
spyder                    4.1.5                    py37_0
spyder-kernels            1.9.4                    py37_0
sqlite                    3.33.0               h2a8f88b_0
stable-baselines          2.10.2a0                  dev_0    <develop>
tabulate                  0.8.7                    pypi_0    pypi
tensorboard               2.0.0              pyhb38c66f_1
tensorflow                1.15.0          eigen_py37h9f89a44_0
tensorflow-base           1.15.0          eigen_py37h07d2309_0
tensorflow-estimator      1.15.1             pyh2649769_0
termcolor                 1.1.0                    py37_1
testpath                  0.4.4                      py_0
toml                      0.10.1                     py_0
tornado                   6.0.4            py37he774522_1
traitlets                 5.0.5                      py_0
typed-ast                 1.4.1            py37he774522_0
ujson                     4.0.1            py37ha925a31_0
urllib3                   1.25.11                    py_0
vc                        14.1                 h0510ff6_4
vs2015_runtime            14.16.27012          hf0eaf9b_3
watchdog                  0.10.3                   py37_0
wcwidth                   0.2.5                      py_0
webencodings              0.5.1                    py37_1
werkzeug                  0.16.1                     py_0
wheel                     0.35.1                     py_0
win_inet_pton             1.1.0                    py37_0
wincertstore              0.2                      py37_0
wrapt                     1.11.2           py37he774522_0
yaml                      0.2.5                he774522_0
yapf                      0.30.0                     py_0
yfinance                  0.1.55                   pypi_0    pypi
zeromq                    4.3.2                ha925a31_3
zipp                      3.3.1                      py_0
zlib                      1.2.11               h62dcd97_4
toksis

1 Answer

I think the problem is that you used the default network structure in stable-baselines. Compare it with the network in your example:

model = Sequential()
model.add(Dense(4, init='lecun_uniform', input_shape=(2,)))
model.add(Activation('relu'))    
model.add(Dense(4, init='lecun_uniform'))
model.add(Activation('relu'))    
model.add(Dense(4, init='lecun_uniform'))
model.add(Activation('linear'))     
rms = RMSprop()
model.compile(loss='mse', optimizer=rms)

So it is a pretty simple network: three layers of 4 neurons each. In stable-baselines you used the default MlpPolicy, which has two layers of 64 neurons. You can easily specify the network structure by passing a policy_kwargs parameter to the model, which can look like this:

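# The DQN policies in stable-baselines 2.x take `layers`, as CustomDQNPolicy
# in the question does; `net_arch` applies to the actor-critic policies
# and to Stable-Baselines3.
policy_kwargs = dict(
        layers=[4, 4, 4]
    )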

and your DQN model can be initialized in the following way:

model = DQN('MlpPolicy', env, policy_kwargs=policy_kwargs, verbose=1)
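
To put the size difference in numbers, here is a rough parameter count for plain fully connected networks. It is only a sketch: `mlp_param_count` is a hypothetical helper, the trading environment is assumed to give 2 input features and 2 actions, and extras such as layer normalization or a dueling head are ignored.

```python
def mlp_param_count(n_inputs, hidden_layers, n_outputs):
    """Count weights and biases of a plain fully connected network."""
    sizes = [n_inputs] + list(hidden_layers) + [n_outputs]
    # Each dense layer has fan_in * fan_out weights plus fan_out biases.
    return sum(fan_in * fan_out + fan_out
               for fan_in, fan_out in zip(sizes, sizes[1:]))


# Keras example above: 2 inputs -> Dense(4) -> Dense(4) -> Dense(4) outputs.
keras_example = mlp_param_count(2, [4, 4], 4)

# Two hidden layers of 64 units (assumed 2 inputs, 2 actions).
two_by_64 = mlp_param_count(2, [64, 64], 2)

# The question's CustomDQNPolicy: three hidden layers of 64 units.
three_by_64 = mlp_param_count(2, [64, 64, 64], 2)

print(keras_example, two_by_64, three_by_64)
```

Under these assumptions the 64-unit networks have on the order of a hundred times more parameters than the Keras example, which is a lot of capacity for a 20-point diagonal line.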

In addition, the author of your first example creates a simple DQN model with a single network. In frameworks such as stable-baselines, however, the DQN algorithm uses two networks with the same structure, one for training and one for evaluation (commonly called the online and target networks). This is useful for more complex problems, but for problems as simple as yours it can work badly.
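
As a rough illustration of the two-network idea, here is a sketch that uses a small Q-table as a stand-in for each network. Everything in it (the constants, `td_update`, the toy transitions) is hypothetical and chosen for readability; it is not how stable-baselines implements DQN internally.

```python
import copy

GAMMA = 0.9          # discount factor
LEARNING_RATE = 0.5  # step size for the tabular update
SYNC_EVERY = 2       # hypothetical: copy online -> target every 2 updates

# Tabular stand-ins for the two networks: q[state][action].
online_q = [[0.0, 0.0], [0.0, 0.0]]
target_q = copy.deepcopy(online_q)


def td_update(state, action, reward, next_state, done):
    """Move the online estimate toward a target built from the frozen copy."""
    bootstrap = 0.0 if done else GAMMA * max(target_q[next_state])
    td_target = reward + bootstrap
    online_q[state][action] += LEARNING_RATE * (td_target - online_q[state][action])


transitions = [
    (1, 1, 1.0, 1, True),   # terminal step with reward 1
    (0, 1, 0.0, 1, False),  # bootstraps from the stale target: learns nothing yet
    (0, 1, 0.0, 1, False),  # same transition after the sync: value propagates
]

for step, transition in enumerate(transitions, start=1):
    td_update(*transition)
    if step % SYNC_EVERY == 0:
        target_q = copy.deepcopy(online_q)  # periodic hard update

print(online_q)
```

The second transition gains nothing because the target copy has not yet seen the reward learned in the first step. That lag stabilizes training on hard problems, but on a tiny problem it mostly slows learning down, which fits the point above. In stable-baselines the sync interval is controlled by the `target_network_update_freq` argument of `DQN`.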

Mikhail