I've solved the issue.
You need a slightly unusual setup in which one function returns another function. At this link, they give the following example:
    def linear_schedule(initial_value):
        """
        Linear learning rate schedule.
        :param initial_value: (float or str)
        :return: (function)
        """
        if isinstance(initial_value, str):
            initial_value = float(initial_value)

        def func(progress):
            """
            Progress will decrease from 1 (beginning) to 0
            :param progress: (float)
            :return: (float)
            """
            return progress * initial_value

        return func
So essentially, you write a function, myscheduler(), which doesn't necessarily need any inputs, and which returns another function whose only input is "progress" (a value that goes from 1 to 0 as training goes on). PPO itself passes that "progress" value to the returned function. So, I suppose the "under the hood" order of events is something like:
- Your learning_rate scheduling function is called
- It returns a function which takes progress as input
- SB3's PPO (or other algorithm) passes its current progress into that function
- That function returns the learning_rate for that point in training, and the model uses it.
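To make that order of events concrete, here is a rough sketch in plain Python. It does not use SB3 at all: simulate_training is a made-up stand-in for the training loop, and the way it computes progress (1 minus the fraction of steps done) is my assumption about roughly what happens, not SB3's actual code.

```python
def linear_schedule(initial_value):
    """Return a function mapping remaining progress (1 -> 0) to a learning rate."""
    def func(progress):
        return progress * initial_value
    return func


def simulate_training(schedule, total_steps):
    """Hypothetical stand-in for a training loop; NOT SB3's real implementation."""
    rates = []
    for step in range(total_steps):
        # Assumed definition of progress: fraction of training still remaining.
        progress = 1.0 - step / total_steps
        # The loop, not the user, feeds progress into the schedule function.
        rates.append(schedule(progress))
    return rates


# Steps 1-2: call the outer function once and get a function back.
schedule = linear_schedule(0.001)
# Steps 3-4: the (simulated) loop passes progress in and uses the result.
rates = simulate_training(schedule, 4)
```

Here rates starts at the initial value and shrinks with every step, which is the same thing the linear schedule should do inside the real algorithm.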
In my case, I wrote something like this:
    def lrsched():
        def reallr(progress):
            lr = 0.003
            if progress < 0.85:
                lr = 0.0005
            if progress < 0.66:
                lr = 0.00025
            if progress < 0.33:
                lr = 0.0001
            return lr
        return reallr
Then, you use that function in the following way:
    model = PPO(..., learning_rate=lrsched())
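If you want to change the breakpoints without editing a chain of if statements, the same idea can be written with the thresholds passed in as data. This is only a sketch: piecewise_schedule and its arguments are names I made up for illustration, not part of SB3.

```python
def piecewise_schedule(default_lr, thresholds):
    """Hypothetical helper: thresholds is a list of (progress_below, lr) pairs."""
    # Sort ascending so the first threshold that matches is the tightest one.
    ordered = sorted(thresholds)

    def schedule(progress):
        for limit, lr in ordered:
            if progress < limit:
                return lr
        # Progress is at or above every threshold: early in training.
        return default_lr

    return schedule


# Same breakpoints and rates as lrsched() above.
lr_fn = piecewise_schedule(0.003, [(0.85, 0.0005), (0.66, 0.00025), (0.33, 0.0001)])
# Sample the schedule at a few progress values to check it by hand.
samples = {p: lr_fn(p) for p in (1.0, 0.85, 0.8, 0.5, 0.1)}
```

The returned lr_fn is passed the same way as before, i.e. learning_rate=lr_fn, since it is already the inner function.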