I have this code for displaying the first 10 columns of a pandas DataFrame:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline


## DATAFRAME ##

np.random.seed(102030)              # These 2 lines must ALWAYS be run together.
df_np = np.random.randint(-100,100,(20,10))
df_index = list(range(1,21))
df_col = ['A','B','C','D','E','F','G','H','I','J']

df = pd.DataFrame(data = df_np, 
                  index = df_index, 
                  columns = df_col)


## PLOT ##

f_sample = 1     # First SAMPLE
l_sample = 20    # Last SAMPLE

data_color = 'red'     # FILL COLOR
data_alpha = 0.5    # (0-1) 0-TRANSPARENT 1-OPAQUE --> alpha parameter

df_nan = df * 0.7

fig = plt.figure(figsize=(12,10)) # Create the figure; plt.subplots would return a (fig, ax) tuple and add an extra axes


ax1 = plt.subplot2grid((1,10), (0,0), rowspan=1, colspan = 1) 
ax2 = plt.subplot2grid((1,10), (0,1), rowspan=1, colspan = 1)
ax3 = plt.subplot2grid((1,10), (0,2), rowspan=1, colspan = 1)
ax4 = plt.subplot2grid((1,10), (0,3), rowspan=1, colspan = 1)
ax5 = plt.subplot2grid((1,10), (0,4), rowspan=1, colspan = 1)
ax6 = plt.subplot2grid((1,10), (0,5), rowspan=1, colspan = 1)
ax7 = plt.subplot2grid((1,10), (0,6), rowspan=1, colspan = 1)
ax8 = plt.subplot2grid((1,10), (0,7), rowspan=1, colspan = 1)
ax9 = plt.subplot2grid((1,10), (0,8), rowspan=1, colspan = 1)
ax10 = plt.subplot2grid((1,10), (0,9), rowspan=1, colspan = 1)

axes = [ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8, ax9, ax10]


for i, ax in enumerate(axes):
    ax.plot(df_nan.iloc[:,i], df_nan.index, lw=0)
    ax.set_ylim(l_sample, f_sample)
    ax.set_xlim(0, 1)
    ax.set_title(df_nan.columns[i])                                                        
    ax.set_facecolor('whitesmoke')
    ax.fill_betweenx(df_nan.index, 0, df_nan.iloc[:,i], facecolor= data_color, alpha= data_alpha )       
    plt.setp(ax.get_xticklabels(), visible = False)
    if i > 0:
        plt.setp(ax.get_yticklabels(), visible = False)
        
ax1.set_ylabel('Sample', fontsize=18)
plt.subplots_adjust(wspace=0)
                                      
plt.show()

It works as long as the DataFrame has at least 10 columns.

I would like code that displays all the columns of any DataFrame while keeping the same layout as the plot above.

Thank you!!

grojar

1 Answer

UPDATED ANSWER

Solution

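Create the axes inside a for loop, sizing the subplot grid from the number of columns. Changing the range of the loop changes how many columns are displayed, and the grid adapts to any column count. In the example, df_col holds 13 labels, so extend that list if you raise numCols beyond 13.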

Code

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

## DATAFRAME ##

numCols = 13    # THIS IS THE NUMBER OF COLUMNS IN THE DF WHICH WILL BE DISPLAYED

np.random.seed(102030)              # These 2 lines must ALWAYS be run together.
df_np = np.random.randint(-100,100,(20,numCols))
df_index = list(range(1,21))
df_col = ['A','B','C','D','E','F','G','H','I','J','K','L','M']

df = pd.DataFrame(data = df_np, 
                  index = df_index, 
                  columns = df_col[:numCols])


## PLOT ##

f_sample = 1     # First SAMPLE
l_sample = 20    # Last SAMPLE

data_color = 'red'     # FILL COLOR
data_alpha = 0.5    # (0-1) 0-TRANSPARENT 1-OPAQUE --> alpha parameter

df_nan = df * 0.7

fig = plt.figure(figsize=(12,10)) # Create the figure; the axes are added in the loop below

for i in range(numCols):
    ax = plt.subplot2grid((1,numCols), (0,i), rowspan=1, colspan = 1)
    ax.plot(df_nan.iloc[:,i], df_nan.index, lw=0)
    ax.set_ylim(l_sample, f_sample)
    ax.set_xlim(0, 1)
    ax.set_title(df_nan.columns[i])                                                        
    ax.set_facecolor('whitesmoke')
    ax.fill_betweenx(df_nan.index, 0, df_nan.iloc[:,i], facecolor= data_color, alpha= data_alpha )       
    plt.setp(ax.get_xticklabels(), visible = False)
    if i > 0:
        plt.setp(ax.get_yticklabels(), visible = False)
    else:    # Sets the label for the leftmost subplot
        ax.set_ylabel('Sample', fontsize=18)
    
plt.subplots_adjust(wspace=0)                                   
plt.show()
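
The same loop can be wrapped in a function that reads the column count from the DataFrame itself, so nothing has to be set by hand. This is only a sketch, not part of the original answer: the name plot_all_columns and its parameters are hypothetical, and it swaps subplot2grid for plt.subplots with sharey=True, which hides the inner y tick labels automatically.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def plot_all_columns(df, x_min=0, x_max=1, color='red', alpha=0.5):
    """Draw one filled strip per column of df, all sharing the sample axis.

    Hypothetical helper: the name and parameters are illustrative only.
    """
    n_cols = df.shape[1]
    # squeeze=False keeps `axes` two-dimensional even for a single column
    fig, axes = plt.subplots(nrows=1, ncols=n_cols, sharey=True,
                             figsize=(12, 10), squeeze=False)
    for i, ax in enumerate(axes[0]):
        ax.fill_betweenx(df.index, 0, df.iloc[:, i],
                         facecolor=color, alpha=alpha)
        ax.set_xlim(x_min, x_max)
        ax.set_title(str(df.columns[i]))
        ax.set_facecolor('whitesmoke')
        ax.tick_params(labelbottom=False)   # hide the x tick labels
    # The y axis is shared, so setting it once applies to every strip.
    # Passing (max, min) puts the first sample at the top.
    axes[0, 0].set_ylim(df.index.max(), df.index.min())
    axes[0, 0].set_ylabel('Sample', fontsize=18)
    fig.subplots_adjust(wspace=0)
    return fig, axes[0]


# Demo with a 17-column frame of random placeholder values
rng = np.random.default_rng(102030)
demo = pd.DataFrame(rng.integers(-100, 100, (20, 17)),
                    index=range(1, 21)) * 0.7
fig, axes = plot_all_columns(demo)
```

In a notebook the figure is displayed when the cell finishes; in a script, call plt.show() afterwards.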

ORIGINAL ANSWER

Solution

allData = [data1, data2, ... , data10]

fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(4,3))

for ax, data in zip(axes.flat, allData):
    ax.imshow(data)

Explanation

allData is any list-like collection in which each element holds the data for one subplot. plt.subplots returns a Figure (fig) and a NumPy array of Axes (axes); nrows and ncols set the shape of the grid, and therefore the number and position of the subplots. Because axes is two-dimensional here, iterate over axes.flat so that zip pairs each subplot with one element of allData. Note that zip stops at the shorter of the two sequences, so extra subplots or extra data items are silently skipped.
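
For reference, here is a minimal runnable version of that pattern. The ten random 8x8 arrays are placeholders I am assuming in place of data1 through data10, and the names rng and all_data are illustrative only.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
# Placeholder inputs: ten random 8x8 arrays stand in for the real data
all_data = [rng.random((8, 8)) for _ in range(10)]

fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(10, 4))
print(axes.shape)  # (2, 5): a two-dimensional NumPy array of Axes

# axes.flat walks the grid row by row, so each array gets its own subplot
for ax, data in zip(axes.flat, all_data):
    ax.imshow(data)
```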

Link to documentation: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.subplots.html

trent