I am building my first dash app and am trying to have a few user inputs update a plot. The first user-input is a DateRangePicker with a button to submit the date range chosen. This works properly using a callback. In the callback, I call a function that takes this date range, filters the dataframe, makes some calculations to determine where to place horizontal lines, and updates the plot.
I'd like a second component to be two radio buttons that let the user determine how the grid will look on the plot - and ideally, this wouldn't need a button to submit the choice (i.e. when the user clicks one of the radio buttons itself, it'll update the plot automatically). To update the plot this way, I need to have the information calculated in the prior callback, or just recalculate everything again. When I tried the later, it threw a TypeError of TypeError: callback_linevalue() takes 3 positional arguments but 4 were given
. I'm not sure (1) what's the best way to do this and (2) why I'm getting that TypeError.
This is what the app looks like currently:
Below is the code in my app with some notes to make it easier to follow.
# do date filter and recalculate stds
def calc_stds(d, start_date, end_date):
df_period = d.query(f"date>=@start_date & date<=@end_date").reset_index(drop=True)
mean = df_period.spread.mean()
std = df_period.spread.std()
std_1up = mean + std
std_1down = mean - std
std_2up = mean + 2*std
std_2down = mean - 2*std
return df_period, mean, std_1up, std_1down, std_2up, std_2down
# function that creates the graph
def create_graph(plot_df, mean=0, up1=0, down1=0, up2=0, down2=0, linevalue='meanstd'):
fig = go.Figure()
fig.add_trace(go.Scatter(x=plot_df['date'],
y=plot_df['spread']))
fig.update_traces(marker=dict(size=3))
if linevalue=='meanstd':
hline_color = "black" # "#848484"
fig.add_shape(type='line',
x0=plot_df['date'].min(),
x1=plot_df['date'].max(),
y0=mean,
y1=mean,
xref='x',
yref='y',
line=dict(color=hline_color,
width=1),
)
fig.add_shape(type='line',
x0=plot_df['date'].min(),
x1=plot_df['date'].max(),
y0=up1,
y1=up1,
xref='x',
yref='y',
line=dict(color=hline_color,
width=1,
dash='dash'),
)
fig.add_shape(type='line',
x0=plot_df['date'].min(),
x1=plot_df['date'].max(),
y0=up2,
y1=up2,
xref='x',
yref='y',
line=dict(color=hline_color,
width=1,
dash='dot'),
)
fig.add_shape(type='line',
x0=plot_df['date'].min(),
x1=plot_df['date'].max(),
y0=down1,
y1=down1,
xref='x',
yref='y',
line=dict(color=hline_color,
width=1,
dash='dash'),
)
fig.add_shape(type='line',
x0=plot_df['date'].min(),
x1=plot_df['date'].max(),
y0=down2,
y1=down2,
xref='x',
yref='y',
line=dict(color=hline_color,
width=1,
dash='dot'),
)
# annotations
annotations = []
points = [up2, up1, mean, down1, down2]
labels = ["+2\u03C3", "+1\u03C3", "mean", "-1\u03C3", "-2\u03C3"]
for p, l in zip(points, labels):
annotations.append(dict(xref='paper',
x=1.005,
y=p,
xanchor='left',
yanchor='middle',
align='left',
text=f"{p:.2f} ({l})",
showarrow=False,
font=dict(size=12, color=hline_color)
))
fig.update_layout(annotations=annotations)
fig.add_annotation(x=0.03,
y=mean + .2,
xref='paper',
yref='y',
xanchor='left',
align='left',
borderpad=5,
text="Cheaper",
axref='pixel',
ayref='y',
ax=0.25,
ay=mean + 3.8,
arrowhead=1,
arrowsize=1.,
arrowside='start',
arrowwidth=1.5,
arrowcolor="#767676",
showarrow=True,
font=dict(size=15,
color="#767676"
))
fig.add_annotation(x=0.03,
y=mean - .2,
xref='paper',
yref='y',
xanchor='left',
align='left',
borderpad=5,
text="More Expensive",
axref='pixel',
ayref='y',
ax=0.25,
ay=mean - 3.8,
arrowhead=1,
arrowsize=1.,
arrowside='start',
arrowwidth=1.5,
arrowcolor="#767676",
showarrow=True,
font=dict(size=15,
color="#767676"
))
fig.update_xaxes(showgrid=False if linevalue=='meanstd' else True,
range=[plot_df.date.min(), plot_df.date.max()])
fig.update_yaxes(showgrid=False if linevalue=='meanstd' else True,
zeroline=False,
title='Equity Risk Premium',
ticksuffix=" ",
range=[-.5, np.ceil(plot_df.spread.max())])
fig.update_layout(font_family="Avenir",
font_color="#4c4c4c",
font_size=14,
showlegend=False,
template=plot_settings.dockstreet_template,
margin=dict(t=10, b=10)
)
fig.update_layout(
xaxis=dict(
rangeselector=dict(
buttons=list([
dict(count=1,
label="1m",
step="month",
stepmode="backward"),
dict(count=6,
label="6m",
step="month",
stepmode="backward"),
dict(count=1,
label="YTD",
step="year",
stepmode="todate"),
dict(count=1,
label="1y",
step="year",
stepmode="backward"),
dict(count=2,
label="2y",
step="year",
stepmode="backward"),
dict(step="all")
])
),
rangeslider=dict(
visible=True,
range=[plot_df.date.min(), plot_df.date.max()]
),
type="date"
)
)
for ser in fig['data']:
ser['hovertemplate'] = "%{x|%b %-d, %Y}, %{y:.2f}<extra></extra>"
return fig
# DO INITIAL CALCULATION WITH DEFAULT DATE RANGE
df_date_filter, m, u1, d1, u2, d2 = calc_stds(df, df.date.min(), df.date.max())
# CREATE INITIAL GRAPH WITH DEFAULT DATE RANGE
updated_figure = create_graph(df_date_filter, mean=m, up1=u1, down1=d1, up2=u2, down2=d2, linevalue='meanstd')
# APP.LAYOUT CODE IN HERE THAT I'VE REMOVED FOR BREVITY:
# daterangepicker id='my_date_range'
# submit button for daterangepicker id='submit_button'
# graph id='my_graph'
# radio buttons id='plot_lines'
@app.callback(Output('my_graph','figure'),
[Input('submit_button','n_clicks')],
[
State('my_date_range','start_date'),
State('my_date_range','end_date'),
State('plot_lines', 'value'),
])
def callback_dates(n_clicks, start_date, end_date, linevalue):
# when it gets passed into the input, converts it to a string
start = datetime.strptime(start_date[:10], '%Y-%m-%d')
end = datetime.strptime(end_date[:10], '%Y-%m-%d')
df_date_filter, m, u1, d1, u2, d2 = calc_stds(df, start, end)
updated_figure = create_graph(df_date_filter, mean=m, up1=u1, down1=d1, up2=u2, down2=d2, linevalue=linevalue)
return updated_figure
@app.callback(Output('my_graph','figure'),
[Input('plot_lines','value'),
Input('my_date_range','start_date'),
Input('my_date_range','end_date')]
)
def callback_linevalue(linevalue, start_date, end_date):
start = datetime.strptime(start_date[:10], '%Y-%m-%d')
end = datetime.strptime(end_date[:10], '%Y-%m-%d')
df_date_filter, m, u1, d1, u2, d2 = calc_stds(df, start, end)
updated_figure = create_graph(df_date_filter, mean=m, up1=u1, down1=d1, up2=u2, down2=d2, linevalue=linevalue)
return updated_figure