1

I am building my first dash app and am trying to have a few user inputs update a plot. The first user-input is a DateRangePicker with a button to submit the date range chosen. This works properly using a callback. In the callback, I call a function that takes this date range, filters the dataframe, makes some calculations to determine where to place horizontal lines, and updates the plot.

I'd like a second component to be two radio buttons that let the user determine how the grid will look on the plot - and ideally, this wouldn't need a button to submit the choice (i.e. when the user clicks one of the radio buttons itself, it'll update the plot automatically). To update the plot this way, I need to have the information calculated in the prior callback, or just recalculate everything again. When I tried the later, it threw a TypeError of TypeError: callback_linevalue() takes 3 positional arguments but 4 were given. I'm not sure (1) what's the best way to do this and (2) why I'm getting that TypeError.

This is what the app looks like currently: enter image description here

Below is the code in my app with some notes to make it easier to follow.

# do date filter and recalculate stds
def calc_stds(d, start_date, end_date):
    df_period = d.query(f"date>=@start_date & date<=@end_date").reset_index(drop=True)
    mean = df_period.spread.mean()
    std = df_period.spread.std()
    std_1up = mean + std
    std_1down = mean - std
    std_2up = mean + 2*std
    std_2down = mean - 2*std

    return df_period, mean, std_1up, std_1down, std_2up, std_2down

# function that creates the graph 
def create_graph(plot_df, mean=0, up1=0, down1=0, up2=0, down2=0, linevalue='meanstd'):
    fig = go.Figure()

    fig.add_trace(go.Scatter(x=plot_df['date'],
                             y=plot_df['spread']))

    fig.update_traces(marker=dict(size=3))

    if linevalue=='meanstd':
        hline_color = "black"  # "#848484"

        fig.add_shape(type='line',
                      x0=plot_df['date'].min(),
                      x1=plot_df['date'].max(),
                      y0=mean,
                      y1=mean,
                      xref='x',
                      yref='y',
                      line=dict(color=hline_color,
                                width=1),
        )

        fig.add_shape(type='line',
                      x0=plot_df['date'].min(),
                      x1=plot_df['date'].max(),
                      y0=up1,
                      y1=up1,
                      xref='x',
                      yref='y',
                      line=dict(color=hline_color,
                                width=1,
                                dash='dash'),
                      )

        fig.add_shape(type='line',
                      x0=plot_df['date'].min(),
                      x1=plot_df['date'].max(),
                      y0=up2,
                      y1=up2,
                      xref='x',
                      yref='y',
                      line=dict(color=hline_color,
                                width=1,
                                dash='dot'),
                      )

        fig.add_shape(type='line',
                      x0=plot_df['date'].min(),
                      x1=plot_df['date'].max(),
                      y0=down1,
                      y1=down1,
                      xref='x',
                      yref='y',
                      line=dict(color=hline_color,
                                width=1,
                                dash='dash'),
                      )

        fig.add_shape(type='line',
                      x0=plot_df['date'].min(),
                      x1=plot_df['date'].max(),
                      y0=down2,
                      y1=down2,
                      xref='x',
                      yref='y',
                      line=dict(color=hline_color,
                                width=1,
                                dash='dot'),
                      )

        # annotations
        annotations = []

        points = [up2, up1, mean, down1, down2]
        labels = ["+2\u03C3", "+1\u03C3", "mean", "-1\u03C3", "-2\u03C3"]

        for p, l in zip(points, labels):
            annotations.append(dict(xref='paper',
                                    x=1.005,
                                    y=p,
                                    xanchor='left',
                                    yanchor='middle',
                                    align='left',
                                    text=f"{p:.2f} ({l})",
                                    showarrow=False,
                                    font=dict(size=12, color=hline_color)
                                    ))

        fig.update_layout(annotations=annotations)

    fig.add_annotation(x=0.03,
                       y=mean + .2,
                       xref='paper',
                       yref='y',
                       xanchor='left',
                       align='left',
                       borderpad=5,
                       text="Cheaper",
                       axref='pixel',
                       ayref='y',
                       ax=0.25,
                       ay=mean + 3.8,
                       arrowhead=1,
                       arrowsize=1.,
                       arrowside='start',
                       arrowwidth=1.5,
                       arrowcolor="#767676",
                       showarrow=True,
                       font=dict(size=15,
                                 color="#767676"
                                 ))

    fig.add_annotation(x=0.03,
                       y=mean - .2,
                       xref='paper',
                       yref='y',
                       xanchor='left',
                       align='left',
                       borderpad=5,
                       text="More Expensive",
                       axref='pixel',
                       ayref='y',
                       ax=0.25,
                       ay=mean - 3.8,
                       arrowhead=1,
                       arrowsize=1.,
                       arrowside='start',
                       arrowwidth=1.5,
                       arrowcolor="#767676",
                       showarrow=True,
                       font=dict(size=15,
                                 color="#767676"
                                 ))



    fig.update_xaxes(showgrid=False if linevalue=='meanstd' else True,
                     range=[plot_df.date.min(), plot_df.date.max()])

    fig.update_yaxes(showgrid=False if linevalue=='meanstd' else True,
                     zeroline=False,
                     title='Equity Risk Premium',
                     ticksuffix="  ",
                     range=[-.5, np.ceil(plot_df.spread.max())])

    fig.update_layout(font_family="Avenir",
                      font_color="#4c4c4c",
                      font_size=14,
                      showlegend=False,
                      template=plot_settings.dockstreet_template,
                      margin=dict(t=10, b=10)
                      )

    fig.update_layout(
        xaxis=dict(
            rangeselector=dict(
                buttons=list([
                    dict(count=1,
                         label="1m",
                         step="month",
                         stepmode="backward"),
                    dict(count=6,
                         label="6m",
                         step="month",
                         stepmode="backward"),
                    dict(count=1,
                         label="YTD",
                         step="year",
                         stepmode="todate"),
                    dict(count=1,
                         label="1y",
                         step="year",
                         stepmode="backward"),
                    dict(count=2,
                         label="2y",
                         step="year",
                         stepmode="backward"),
                    dict(step="all")
                ])
            ),
            rangeslider=dict(
                visible=True,
                range=[plot_df.date.min(), plot_df.date.max()]
            ),
            type="date"
        )
    )

    for ser in fig['data']:
        ser['hovertemplate'] = "%{x|%b %-d, %Y}, %{y:.2f}<extra></extra>"

    return fig

# DO INITIAL CALCULATION WITH DEFAULT DATE RANGE
df_date_filter, m, u1, d1, u2, d2 = calc_stds(df, df.date.min(), df.date.max())

# CREATE INITIAL GRAPH WITH DEFAULT DATE RANGE
updated_figure = create_graph(df_date_filter, mean=m, up1=u1, down1=d1, up2=u2, down2=d2, linevalue='meanstd')

# APP.LAYOUT CODE IN HERE THAT I'VE REMOVED FOR BREVITY:
# daterangepicker id='my_date_range'
# submit button for daterangepicker id='submit_button'
# graph id='my_graph'
# radio buttons id='plot_lines'

@app.callback(Output('my_graph','figure'),
              [Input('submit_button','n_clicks')],
              [
                  State('my_date_range','start_date'),
                  State('my_date_range','end_date'),
                  State('plot_lines', 'value'),
              ])
def callback_dates(n_clicks, start_date, end_date, linevalue):
    # when it gets passed into the input, converts it to a string
    start = datetime.strptime(start_date[:10], '%Y-%m-%d')
    end = datetime.strptime(end_date[:10], '%Y-%m-%d')

    df_date_filter, m, u1, d1, u2, d2 = calc_stds(df, start, end)

    updated_figure = create_graph(df_date_filter, mean=m, up1=u1, down1=d1, up2=u2, down2=d2, linevalue=linevalue)

    return updated_figure

@app.callback(Output('my_graph','figure'),
    [Input('plot_lines','value'),
     Input('my_date_range','start_date'),
     Input('my_date_range','end_date')]
)
def callback_linevalue(linevalue, start_date, end_date):
    start = datetime.strptime(start_date[:10], '%Y-%m-%d')
    end = datetime.strptime(end_date[:10], '%Y-%m-%d')

    df_date_filter, m, u1, d1, u2, d2 = calc_stds(df, start, end)

    updated_figure = create_graph(df_date_filter, mean=m, up1=u1, down1=d1, up2=u2, down2=d2, linevalue=linevalue)

    return updated_figure

0 Answers0