1

I have found and adapted the following code snippets for generating diagnostic plots for linear regression. This is currently done using the following functions:

def residual_plot(some_values):
    # Question's version: creates figure 1, draws a seaborn residual plot,
    # sets title/labels and shows it.
    # NOTE(review): `some_values` is never used -- sns.residplot() is called
    # with no data -- and there is no `return`, so callers receive None.
    plot_lm_1 = plt.figure(1)    
    # NOTE(review): rebinding plot_lm_1 discards the figure object created above.
    plot_lm_1 = sns.residplot()
    plot_lm_1.axes[0].set_title('title')
    plot_lm_1.axes[0].set_xlabel('label')
    plot_lm_1.axes[0].set_ylabel('label')
    # Shows the plot as soon as this function runs.
    plt.show()


def qq_plot(residuals):
    # Question's version: Q-Q plot of `residuals` via ProbPlot(...).qqplot(),
    # then sets title/labels on its first axes and shows it.
    # NOTE(review): there is no `return`, so callers receive None; the plot
    # lives on its own figure rather than on an axes supplied by the caller.
    QQ = ProbPlot(residuals)
    plot_lm_2 = QQ.qqplot()    
    plot_lm_2.axes[0].set_title('title')
    plot_lm_2.axes[0].set_xlabel('label')
    plot_lm_2.axes[0].set_ylabel('label')
    # Shows the plot as soon as this function runs.
    plt.show()

which are called with something like:

# NOTE(review): neither function returns anything, so plot1..plot4 are all
# None; each call displays its own plot via plt.show().
plot1 = residual_plot(value_set1)
plot2 = qq_plot(value_set1)
plot3 = residual_plot(value_set2)
plot4 = qq_plot(value_set2)

How can I create subplots so that these 4 plots are displayed in a 2x2 grid?
I have tried using:

# NOTE(review): `axes[0,0].plot1` looks up an attribute literally named
# `plot1` on the Axes object; it does not refer to the variable plot1 --
# hence the AttributeError shown below. (The extra indentation on the
# following lines is presumably a paste artifact; Python would reject it
# as an unexpected indent.)
fig, axes = plt.subplots(2,2)
    axes[0,0].plot1
    axes[0,1].plot2
    axes[1,0].plot3
    axes[1,1].plot4
    plt.show()

but receive the error:

AttributeError: 'AxesSubplot' object has no attribute 'plot1'

Should I set up the axes from within the functions, or somewhere else?

Andreuccio
  • 1,053
  • 2
  • 18
  • 32

1 Answer

4

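Your functions each draw on their own figure and return nothing (`None`), so there is no plot object to attach to a subplot afterwards. `axes[0,0].plot1` fails because it looks up an attribute literally named `plot1` on the Axes. Instead, create a single figure with four subplot axes and pass each one to your custom plot functions as the axes to draw on, as follows: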

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import probplot


def residual_plot(x, y, axes=None):
    """Draw a residual plot of ``y`` regressed on ``x`` with seaborn.

    Parameters
    ----------
    x, y : array-like
        Predictor and response values forwarded to ``sns.residplot``.
    axes : matplotlib Axes, optional
        Axes to draw on. When omitted, a new figure with a single subplot
        is created so the function also works standalone.

    Returns
    -------
    The object returned by ``sns.residplot`` (the Axes containing the plot).
    """
    if axes is None:
        axes = plt.figure().add_subplot(1, 1, 1)
    # x and y must be passed by keyword: seaborn 0.12+ makes them
    # keyword-only, so the positional form raises TypeError.
    p = sns.residplot(x=x, y=y, ax=axes)
    axes.set_xlabel("Data")
    axes.set_ylabel("Residual")
    axes.set_title("Residuals")
    return p


def qq_plot(x, axes=None, xlim=(-3, 3)):
    """Draw a normal Q-Q (probability) plot of ``x`` using SciPy's ``probplot``.

    Parameters
    ----------
    x : array-like
        Sample values, e.g. regression residuals.
    axes : matplotlib Axes, optional
        Axes to draw on. When omitted, a new figure with a single subplot
        is created so the function also works standalone.
    xlim : (float, float) or None, optional
        Limits applied to the theoretical-quantile (x) axis. Defaults to
        ``(-3, 3)``, the range previously hard-coded here. Pass ``None`` to
        leave the limits untouched, e.g. for large samples whose extreme
        theoretical quantiles fall outside +/-3 and would otherwise be clipped.

    Returns
    -------
    tuple
        The value returned by ``scipy.stats.probplot``:
        ``((osm, osr), (slope, intercept, r))``.
    """
    if axes is None:
        axes = plt.figure().add_subplot(1, 1, 1)
    # probplot draws the ordered data and the fitted line onto `axes`.
    p = probplot(x, plot=axes)
    if xlim is not None:
        axes.set_xlim(*xlim)
    return p


if __name__ == "__main__":
    # Synthetic data: a straight line plus two independent noise draws.
    x = np.arange(100)
    y = 0.5 * x
    y1 = y + np.random.randn(100)
    y2 = y + np.random.randn(100)

    # One figure holding a 2x2 grid of axes; each helper draws into the
    # axes it is given instead of creating its own figure.
    fig, axes = plt.subplots(2, 2, figsize=(8, 8), facecolor="white")
    for row, y_obs in zip(axes, (y1, y2)):
        residual_plot(y, y_obs, row[0])  # left column: residuals
        qq_plot(y_obs, row[1])           # right column: Q-Q plot

    fig.tight_layout()
    # plt.show() runs the GUI event loop and blocks until the window is
    # closed; Figure.show() does not, so in a plain script the window may
    # flash and disappear or never appear at all.
    plt.show()

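I do not know which `ProbPlot` you are using, so I used SciPy's `probplot` instead. (If it is `statsmodels.graphics.gofplots.ProbPlot`, its `qqplot` method also accepts an `ax=` argument, so the same approach applies.)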

Kefeng91
  • 802
  • 6
  • 10