
I am trying to calculate monthly rolling window regressions and return the predicted values as a new column in the data frame. I know that pandas has rolling regression capabilities (`pandas.ols`) that are in the process of being deprecated, so I'm interested in a solution that uses statsmodels or something similar.

I'd like to calculate monthly rolling regressions (12 month window, 6 month minimum) and save each month's prediction back to a new column in the original data frame. While my question is different, the closest solution I've found is in the answer to this question. Based on that answer I've tried this (the data is below):

```python
import pandas as pd
import statsmodels.api as sm

def grp_ols_predict(df, xcols, ycol):
    return sm.OLS(df[ycol], df[xcols]).fit().predict()

retdata['predicted_y'] = retdata.groupby('id').apply(grp_ols_predict, xcols=['constant','x1', 'x2', 'x3'], ycol='y')
```

Two issues remain unresolved:

1. This code runs without errors but returns all NaN values for `predicted_y`.
2. The regression above is not a rolling-window regression. The syntax for this is straightforward in `pandas.ols` but not in statsmodels, although it seems the intention is for the `pandas.ols` syntax to work in statsmodels at some point. The following code is rolling-window, but it will be deprecated in a future version and is not grouped by `id`:

```python
model = pd.ols(y=retdata['y'], x=retdata[['x1','x2','x3']], window_type='rolling', window=12, min_periods=6, intercept=True)
retdata['predicted_y'] = model.y_predict
```

My question is essentially "appending predicted values and residuals to pandas dataframe" with two additional complications (1) rolling window and (2) grouping by id.

Finally, a sample of the data I'm using:

{'constant': {0: 1,  1: 1,  2: 1,  3: 1,  4: 1,  5: 1,  6: 1,  7: 1,  8: 1,  9: 1,  10: 1,  11: 1,  12: 1,  13: 1,  14: 1,  15: 1,  16: 1,  17: 1,  18: 1,  19: 1,  20: 1,  21: 1,  22: 1,  23: 1,  24: 1,  25: 1,  26: 1,  27: 1,  28: 1,  29: 1,  30: 1,  31: 1,  32: 1,  33: 1,  34: 1,  35: 1,  36: 1,  37: 1,  38: 1,  39: 1,  40: 1,  41: 1,  42: 1,  43: 1,  44: 1,  45: 1,  46: 1,  47: 1,  48: 1,  49: 1,  50: 1,  51: 1,  52: 1,  53: 1,  54: 1,  55: 1,  56: 1,  57: 1,  58: 1,  59: 1,  60: 1,  61: 1,  62: 1,  63: 1,  64: 1,  65: 1,  66: 1,  67: 1,  68: 1,  69: 1,  70: 1,  71: 1,  72: 1,  73: 1,  74: 1,  75: 1,  76: 1,  77: 1,  78: 1,  79: 1,  80: 1,  81: 1,  82: 1,  83: 1},
'id': {0: 11111,  1: 11111,  2: 11111,  3: 11111,  4: 11111,  5: 11111,  6: 11111,  7: 11111,  8: 11111,  9: 11111,  10: 11111,  11: 11111,  12: 11111,  13: 11111,  14: 11111,  15: 11111,  16: 11111,  17: 11111,  18: 11111,  19: 11111,  20: 11111,  21: 11111,  22: 11111,  23: 11111,  24: 22222,  25: 22222,  26: 22222,  27: 22222,  28: 22222,  29: 22222,  30: 22222,  31: 22222,  32: 22222,  33: 22222,  34: 22222,  35: 22222,  36: 22222,  37: 22222,  38: 22222,  39: 22222,  40: 22222,  41: 22222,  42: 22222,  43: 22222,  44: 22222,  45: 22222,  46: 22222,  47: 22222,  48: 22222,  49: 22222,  50: 22222,  51: 22222,  52: 22222,  53: 22222,  54: 22222,  55: 22222,  56: 22222,  57: 22222,  58: 22222,  59: 22222,  60: 33333,  61: 33333,  62: 33333,  63: 33333,  64: 33333,  65: 33333,  66: 33333,  67: 33333,  68: 33333,  69: 33333,  70: 33333,  71: 33333,  72: 33333,  73: 33333,  74: 33333,  75: 33333,  76: 33333,  77: 33333,  78: 33333,  79: 33333,  80: 33333,  81: 33333,  82: 33333,  83: 33333},
'month': {0: 1,  1: 2,  2: 3,  3: 4,  4: 5,  5: 6,  6: 7,  7: 8,  8: 9,  9: 10,  10: 11,  11: 12,  12: 1,  13: 2,  14: 3,  15: 4,  16: 5,  17: 6,  18: 7,  19: 8,  20: 9,  21: 10,  22: 11,  23: 12,  24: 1,  25: 2,  26: 3,  27: 4,  28: 5,  29: 6,  30: 7,  31: 8,  32: 9,  33: 10,  34: 11,  35: 12,  36: 1,  37: 2,  38: 3,  39: 4,  40: 5,  41: 6,  42: 7,  43: 8,  44: 9,  45: 10,  46: 11,  47: 12,  48: 1,  49: 2,  50: 3,  51: 4,  52: 5,  53: 6,  54: 7,  55: 8,  56: 9,  57: 10,  58: 11,  59: 12,  60: 1,  61: 2,  62: 3,  63: 4,  64: 5,  65: 6,  66: 7,  67: 8,  68: 9,  69: 10,  70: 11,  71: 12,  72: 1,  73: 2,  74: 3,  75: 4,  76: 5,  77: 6,  78: 7,  79: 8,  80: 9,  81: 10,  82: 11,  83: 12},
'x1': {0: 4.8399999999999999,  1: 1.4099999999999999,  2: 4.1299999999999999,  3: 3.1499999999999999,  4: -3.98,  5: -0.10000000000000001,  6: -4.5,  7: 3.79,  8: -0.84999999999999998,  9: -4.4199999999999999,  10: -0.46000000000000002,  11: 8.7100000000000009,  12: 2.4900000000000002,  13: 2.8700000000000001,  14: 0.63,  15: 0.28999999999999998,  16: 1.25,  17: -2.4300000000000002,  18: -0.80000000000000004,  19: 3.2599999999999998,  20: -1.1399999999999999,  21: 0.52000000000000002,  22: 4.5999999999999996,  23: 0.62,  24: 4.8399999999999999,  25: 1.4099999999999999,  26: 4.1299999999999999,  27: 3.1499999999999999,  28: -3.98,  29: -0.10000000000000001,  30: -4.5,  31: 3.79,  32: -0.84999999999999998,  33: -4.4199999999999999,  34: -0.46000000000000002,  35: 8.7100000000000009,  36: 2.4900000000000002,  37: 2.8700000000000001,  38: 0.63,  39: 0.28999999999999998,  40: 1.25,  41: -2.4300000000000002,  42: -0.80000000000000004,  43: 3.2599999999999998,  44: -1.1399999999999999,  45: 0.52000000000000002,  46: 4.5999999999999996,  47: 0.62,  48: -3.29,  49: -4.8499999999999996,  50: -1.29,  51: -5.6799999999999997,  52: -2.9399999999999999,  53: -1.5600000000000001,  54: 5.04,  55: -3.8399999999999999,  56: 4.75,  57: -0.85999999999999999,  58: -12.74,  59: 0.57999999999999996,  60: 5.5700000000000003,  61: 1.29,  62: 4.0300000000000002,  63: 1.55,  64: 2.7999999999999998,  65: -1.2,  66: 5.6500000000000004,  67: -2.71,  68: 3.77,  69: 4.1799999999999997,  70: 3.1200000000000001,  71: 2.8100000000000001,  72: -3.3199999999999998,  73: 4.6500000000000004,  74: 0.42999999999999999,  75: -0.19,  76: 2.0600000000000001,  77: 2.6099999999999999,  78: -2.04,  79: 4.2400000000000002,  80: -1.97,  81: 2.52,  82: 2.5499999999999998,  83: -0.059999999999999998},
'x2': {0: 7.4400000000000004,  1: 1.8999999999999999,  2: 2.5699999999999998,  3: -0.47999999999999998,  4: -1.1000000000000001,  5: -1.4299999999999999,  6: -1.5,  7: -0.19,  8: 0.40999999999999998,  9: -1.78,  10: -2.8300000000000001,  11: 3.2799999999999998,  12: 6.1100000000000003,  13: 1.3899999999999999,  14: -0.27000000000000002,  15: -0.02,  16: -2.79,  17: 0.32000000000000001,  18: -2.8900000000000001,  19: -4.0700000000000003,  20: -2.6899999999999999,  21: -2.71,  22: -1.1200000000000001,  23: -1.8600000000000001,  24: 7.4400000000000004,  25: 1.8999999999999999,  26: 2.5699999999999998,  27: -0.47999999999999998,  28: -1.1000000000000001,  29: -1.4299999999999999,  30: -1.5,  31: -0.19,  32: 0.40999999999999998,  33: -1.78,  34: -2.8300000000000001,  35: 3.2799999999999998,  36: 6.1100000000000003,  37: 1.3899999999999999,  38: -0.27000000000000002,  39: -0.02,  40: -2.79,  41: 0.32000000000000001,  42: -2.8900000000000001,  43: -4.0700000000000003,  44: -2.6899999999999999,  45: -2.71,  46: -1.1200000000000001,  47: -1.8600000000000001,  48: -3.5,  49: -3.9900000000000002,  50: -2.8700000000000001,  51: -3.9900000000000002,  52: -6.1200000000000001,  53: -2.9399999999999999,  54: 7.8600000000000003,  55: -2.04,  56: 2.9100000000000001,  57: -0.17000000000000001,  58: -7.7000000000000002,  59: -5.3300000000000001,  60: 0.44,  61: -0.42999999999999999,  62: 0.83999999999999997,  63: -2.4300000000000002,  64: 1.6899999999999999,  65: 1.1699999999999999,  66: 1.8799999999999999,  67: 0.25,  68: 2.9399999999999999,  69: -1.52,  70: 1.25,  71: -0.47999999999999998,  72: 0.87,  73: 0.34000000000000002,  74: -1.8500000000000001,  75: -4.1900000000000004,  76: -1.8500000000000001,  77: 3.0099999999999998,  78: -4.2199999999999998,  79: 0.40000000000000002,  80: -3.7999999999999998,  81: 4.2800000000000002,  82: -2.0499999999999998,  83: 2.5899999999999999},
'x3': {0: 1.3500000000000001,  1: -1.3400000000000001,  2: -4.0,  3: 0.73999999999999999,  4: -1.3799999999999999,  5: -2.0,  6: 0.14000000000000001,  7: 2.7200000000000002,  8: -2.9500000000000002,  9: -0.47999999999999998,  10: -1.75,  11: -0.23999999999999999,  12: 2.0600000000000001,  13: -2.75,  14: -1.6599999999999999,  15: 0.39000000000000001,  16: -2.73,  17: -2.4199999999999999,  18: 0.77000000000000002,  19: 4.6399999999999997,  20: 0.5,  21: 1.3200000000000001,  22: 4.7599999999999998,  23: -2.2599999999999998,  24: 1.3500000000000001,  25: -1.3400000000000001,  26: -4.0,  27: 0.73999999999999999,  28: -1.3799999999999999,  29: -2.0,  30: 0.14000000000000001,  31: 2.7200000000000002,  32: -2.9500000000000002,  33: -0.47999999999999998,  34: -1.75,  35: -0.23999999999999999,  36: 2.0600000000000001,  37: -2.75,  38: -1.6599999999999999,  39: 0.39000000000000001,  40: -2.73,  41: -2.4199999999999999,  42: 0.77000000000000002,  43: 4.6399999999999997,  44: 0.5,  45: 1.3200000000000001,  46: 4.7599999999999998,  47: -2.2599999999999998,  48: 2.7000000000000002,  49: 1.7,  50: 2.8300000000000001,  51: 5.6900000000000004,  52: 0.20999999999999999,  53: 1.4199999999999999,  54: -5.1799999999999997,  55: 1.1899999999999999,  56: 2.1099999999999999,  57: 1.74,  58: 4.0099999999999998,  59: 4.2400000000000002,  60: 0.94999999999999996,  61: 0.11,  62: -0.26000000000000001,  63: 0.56999999999999995,  64: 2.4900000000000002,  65: -0.13,  66: 0.60999999999999999,  67: -2.77,  68: -1.2,  69: 1.1000000000000001,  70: 0.26000000000000001,  71: -0.31,  72: -2.1299999999999999,  73: -0.37,  74: 5.0300000000000002,  75: 1.1000000000000001,  76: -0.35999999999999999,  77: -0.66000000000000003,  78: -0.02,  79: -0.55000000000000004,  80: -1.1899999999999999,  81: -1.6799999999999999,  82: -2.98,  83: 2.1200000000000001},
'y': {0: 37.543945819999998,  1: 8.9742475529999997,  2: -2.3528754309999997,  3: 13.13251636,  4: -1.60429428,  5: -11.956497779999999,  6: -19.876604879999999,  7: -2.325516618,  8: -4.7618724569999999,  9: 3.1666054689999998,  10: -1.625982086,  11: 23.14051619,  12: 36.241578869999998,  13: -4.0393970439999993,  14: -1.5464071159999999,  15: -5.8638777849999997,  16: 1.1173513309999998,  17: -7.7348398829999994,  18: 1.1975707259999999,  19: 8.1657380679999996,  20: 1.0988696200000001,  21: -4.8912916910000002,  22: 15.31432558,  23: -0.49755575099999999,  24: 2.439007991,  25: 3.7788248100000001,  26: 6.2406021170000008,  27: 0.070041193000000002,  28: -8.2320061649999996,  29: -3.0580604539999996,  30: -8.1230234560000003,  31: 4.824015073,  32: -0.082216824000000008,  33: -1.0699493369999999,  34: 2.0965058669999999,  35: 10.147223650000001,  36: 9.3610165409999997,  37: 0.50276726500000002,  38: 3.731305892,  39: 0.98107468400000009,  40: 3.3937931360000002,  41: -1.445663699,  42: 2.2321845640000002,  43: 2.2707284099999998,  44: -0.48955173399999996,  45: -5.1661444639999994,  46: 1.776962626,  47: 2.8132786730000001,  48: 8.3333586369999999,  49: -0.59700207599999999,  50: 0.0,  51: -5.4461723210000006,  52: -3.2260780789999997,  53: 0.71489267299999992,  54: -0.78864414099999991,  55: -3.936371727,  56: -14.285801190000001,  57: 8.6241378770000008,  58: -5.0419731539999999,  59: -6.8867527329999998,  60: 2.7716522460000004,  61: 2.1129326050000001,  62: 2.8956834530000002,  63: 15.714036009999999,  64: 6.1329305139999999,  65: -1.017191977,  66: -7.8303661889999994,  67: 5.6218592960000002,  68: -0.35928143700000004,  69: 6.385216346,  70: 8.4875017649999993,  71: -1.8882769469999998,  72: 1.1494252870000001,  73: 1.9820295980000002,  74: 6.9955625160000006,  75: -1.4393754569999999,  76: 2.0297029700000002,  77: 1.8563751830000002,  78: 3.5011990410000005,  79: 5.9082483779999997,  80: 2.0471054369999999,  81: 1.272648835,  82: 2.49201278,  83: 
-2.844593181},
'year': {0: 1971,  1: 1971,  2: 1971,  3: 1971,  4: 1971,  5: 1971,  6: 1971,  7: 1971,  8: 1971,  9: 1971,  10: 1971,  11: 1971,  12: 1972,  13: 1972,  14: 1972,  15: 1972,  16: 1972,  17: 1972,  18: 1972,  19: 1972,  20: 1972,  21: 1972,  22: 1972,  23: 1972,  24: 1971,  25: 1971,  26: 1971,  27: 1971,  28: 1971,  29: 1971,  30: 1971,  31: 1971,  32: 1971,  33: 1971,  34: 1971,  35: 1971,  36: 1972,  37: 1972,  38: 1972,  39: 1972,  40: 1972,  41: 1972,  42: 1972,  43: 1972,  44: 1972,  45: 1972,  46: 1972,  47: 1972,  48: 1973,  49: 1973,  50: 1973,  51: 1973,  52: 1973,  53: 1973,  54: 1973,  55: 1973,  56: 1973,  57: 1973,  58: 1973,  59: 1973,  60: 2013,  61: 2013,  62: 2013,  63: 2013,  64: 2013,  65: 2013,  66: 2013,  67: 2013,  68: 2013,  69: 2013,  70: 2013,  71: 2013,  72: 2014,  73: 2014,  74: 2014,  75: 2014,  76: 2014,  77: 2014,  78: 2014,  79: 2014,  80: 2014,  81: 2014,   82: 2014,    83: 2014}}
Arthur Morris

1 Answer


pandas' `rolling` has some limitations here. First, `Rolling.apply` cannot pass an entire frame of data to the function; it only passes the values of a single column. To get around this, we roll over the index and pass it through `apply`, which lets us fetch the relevant data frame subset inside the applied function itself.
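A toy sketch of that restriction, assuming pandas 1.0 or later (where `raw=False` hands the function a Series): only one column's window arrives, but that Series still carries the row labels, so `.loc` can recover the full rows. This variant reads the labels from the Series index rather than from the `identifier` column used in this answer; `frame` and `window_sum_xy` are illustrative names.

```python
import pandas as pd

frame = pd.DataFrame({"x": [1.0, 2.0, 3.0, 4.0], "y": [2.0, 4.0, 5.0, 9.0]},
                     index=[100, 101, 102, 103])
passed = []


def window_sum_xy(window):
    # With raw=False, `window` is a Series of one column's values
    # that keeps the original row labels of the current window.
    passed.append(window.index.tolist())
    rows = frame.loc[window.index]  # recover the full rows, all columns
    return float((rows["x"] * rows["y"]).sum())  # apply must return a scalar


out = frame["x"].rolling(2).apply(window_sum_xy, raw=False)
```

Here `out` is NaN for the first row, then 10.0, 23.0 and 51.0.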

Second, the applied function must return a single float. That is of no use here, because `predict()` returns an array of fitted values, one per observation in the window. To work around this, we save the results in an extra container as a side effect and extract them afterwards.
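The side-effect container on its own, with the regression stripped out: a toy example in which the per-window "result" is just the demeaned window (an illustrative stand-in for `predict()`), stored in a dict keyed by the window's last row label while `apply` returns a dummy float. `stash` and `demean_window` are hypothetical names.

```python
import numpy as np
import pandas as pd

s = pd.Series([1.0, 3.0, 2.0, 6.0], index=[10, 11, 12, 13])
stash = {}  # side-effect container: last row label -> per-window array


def demean_window(window):
    # The real result is an array, which rolling.apply cannot return,
    # so it is stored under the label of the window's last row instead.
    stash[window.index[-1]] = np.asarray(window - window.mean())
    return 0.0  # placeholder float; the rolled output itself is ignored


s.rolling(3, min_periods=2).apply(demean_window, raw=False)
per_row = pd.Series(stash)  # object Series of arrays, aligned on row labels
```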

```python
import pandas as pd
import statsmodels.api as sm

# assumes df holds the question's sample data, e.g. df = pd.DataFrame({...})

def ols_predict(indices, result, ycol, xcols):
    indices = indices.astype(int)  # raw=True passes the window as a float ndarray
    roll_df = df.loc[indices]  # get relevant data frame subset
    result[indices[-1]] = sm.OLS(roll_df[ycol], roll_df[xcols]).fit().predict()
    return 0.0  # rolling.apply needs a float; the value itself is irrelevant

# define kwargs to be fed to ols_predict
kwargs = {"xcols": ['constant', 'x1', 'x2', 'x3'],
          "ycol": 'y', "result": {}}

# iterate over each id's sub data frame and call OLS for the rolling windows
df["identifier"] = df.index
for idx, sub_df in df.groupby("id"):
    # raw=True hands ols_predict a NumPy array, so indices[-1] is positional
    sub_df["identifier"].rolling(12, min_periods=6).apply(ols_predict, raw=True, kwargs=kwargs)

# write results back to original df
df["parameters"] = pd.Series(kwargs["result"])

# showing the last 5 computed values
print(df["parameters"].tail())
```

```
79    [2.71069564365, 3.86510820198, 3.65972798601, ...
80    [4.05363775104, 4.22653362401, 3.03918230523, ...
81    [3.55589161647, 2.49348201521, 1.20113347853, ...
82    [2.28561308212, 1.0537258681, 2.40806914305, 4...
83    [-0.428928897229, 3.22009689097, 3.30943586961...
Name: parameters, dtype: object
```

Overall, relying on side effects makes these workarounds rather ugly, but they accomplish what you need. You can adapt `ols_predict` to compute whatever you require.
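If the side effects are a concern, a plain loop avoids `rolling.apply` altogether and keeps only the prediction for each window's final row, which is what the comments below converge on. A sketch under these assumptions: rows within each `id` are sorted by time, `np.linalg.lstsq` stands in for `sm.OLS` (both give the ordinary least-squares coefficients), and `rolling_last_fit` is a hypothetical name.

```python
import numpy as np
import pandas as pd


def rolling_last_fit(df, xcols, ycol, window=12, min_periods=6, by="id"):
    """Fitted value and residual for the last row of each rolling window, per group.

    Uses np.linalg.lstsq in place of sm.OLS and assumes each group's rows
    are already sorted by time with one row per month.
    """
    yhat = pd.Series(np.nan, index=df.index)
    for _, g in df.groupby(by):
        X = g[xcols].to_numpy(dtype=float)
        y = g[ycol].to_numpy(dtype=float)
        for end in range(min_periods, len(g) + 1):  # `end` is exclusive
            start = max(0, end - window)
            beta, *_ = np.linalg.lstsq(X[start:end], y[start:end], rcond=None)
            # keep only the prediction for the window's final row
            yhat.loc[g.index[end - 1]] = X[end - 1] @ beta
    return pd.DataFrame({"predicted_y": yhat, "resid": df[ycol] - yhat})


# Usage with the question's column names:
# df = df.join(rolling_last_fit(df, ['constant', 'x1', 'x2', 'x3'], 'y'))
```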

pansen
  • First, thank you. Second, is there a way to get the predicted value for `y` in each period without writing a function that does the calculation (e.g. for one observation: y_hat is the sum of the products of the parameters and their corresponding data). I'm thinking of something like the post-estimation command `model.y_predict` that I referenced above. – Arthur Morris Mar 02 '17 at 21:14
  • I'm a bit unsure about what you want to achieve here. You want to build a model containing 12 observations and then you want to predict the latest observation with the fitted model containing the last observation itself? – pansen Mar 02 '17 at 21:40
  • Yes. I suppose the term 'predict' isn't perfect here. I'm trying to separate `'y'` into a portion explained by the model (what I'm calling the prediction) and a portion unexplained by the model (the residual). – Arthur Morris Mar 02 '17 at 21:51
  • So you mean something like [R squared](https://en.wikipedia.org/wiki/Coefficient_of_determination)? – pansen Mar 02 '17 at 22:05
  • I'm not after a fit statistic, rather y* from equation (2) in [this answer](http://stats.stackexchange.com/questions/18233/what-are-the-predicted-values-returned-by-the-predict-function-in-r-when-using). I'm trying to reproduce the [Stata post-estimation command `predict`](http://www.stata.com/manuals13/rpredict.pdf). Thanks for working with me on this! – Arthur Morris Mar 02 '17 at 22:19
  • I think I see what is going on (now I understand my confusion). `df["parameters"]` is returning _all_ 12 (6 to 12) predictions for each observation that was used to calculate the model for that month. I think I just need to keep the last element of each cell in `df["parameters"]`. Thanks again! – Arthur Morris Mar 02 '17 at 23:22
  • Glad to have helped you. You can modify your `ols_predict` function to whatever statistic you need. – pansen Mar 03 '17 at 07:44
  • I get KeyError: -1 when I try to run this now, relating to this line: `result[indices[-1]] = sm.OLS(roll_df[ycol], roll_df[xcols]).fit().predict()`. Any ideas? – Vash Jul 14 '20 at 11:16
  • @Vash What pandas version are you using? – pansen Jul 27 '20 at 08:15
  • @pansen I'm on Pandas 1.0.1.; Python 3.8.1. – Vash Jul 27 '20 at 16:42
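A likely explanation for the `KeyError: -1` (an inference, not confirmed in the thread): as far as I can tell, since pandas 1.0 `Rolling.apply` defaults to `raw=False`, so `indices` arrives as a Series whose index holds integer row labels, and `indices[-1]` then becomes a label lookup for `-1`. Passing `raw=True` (and casting the float array back to `int`) or using `indices.iloc[-1]` should avoid it. A toy reproduction, where `ident` is a hypothetical stand-in for `sub_df["identifier"]`:

```python
import pandas as pd

# Stand-in for sub_df["identifier"]: values equal the integer row labels
ident = pd.Series([0.0, 1.0, 2.0, 3.0], index=[0, 1, 2, 3])


def last_label_raw(indices):
    # raw=True: a float ndarray, so [-1] really is the last element
    return float(indices.astype(int)[-1])


def last_label_series(indices):
    # raw=False: a Series; use .iloc for a positional lookup
    return float(indices.iloc[-1])


fixed_raw = ident.rolling(2).apply(last_label_raw, raw=True)
fixed_series = ident.rolling(2).apply(last_label_series, raw=False)

try:
    # indices[-1] on a Series with integer labels looks up the *label* -1
    ident.rolling(2).apply(lambda indices: indices[-1], raw=False)
    raised_key_error = False
except KeyError:
    raised_key_error = True
```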