What you actually want is to draw a color table, or something close to a Gantt chart. I don't really know of a convenient way to do this in Python.
One solution is to use matplotlib.pyplot.grid (doc); this discussion shows one way to do it.
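To give an idea of the grid-based approach, here is a minimal sketch. It is my own illustration, not the code from the linked discussion. The cells are drawn as an RGB image with imshow, and minor ticks placed on the cell borders let ax.grid outline every cell. The names members, positions and class_colors are hypothetical and simply mirror the data used further down.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import to_rgb

classes = ["A", "B", "C"]
positions = list(range(2, 8))  # x positions shown on the axis: 2..7
members = {"A": [3, 5, 7], "B": [4, 5, 7], "C": [2, 3, 4, 5, 6, 7]}
class_colors = {"A": "b", "B": "g", "C": "y"}
white = (1.0, 1.0, 1.0)

# One row per class, one column per position; each cell holds an RGB triple
image = np.array([
    [to_rgb(class_colors[c]) if p in members[c] else white for p in positions]
    for c in classes
])

fig, ax = plt.subplots(figsize=(8, 3))
ax.imshow(image, aspect="auto")

# Major ticks label the cell centers
ax.set_xticks(range(len(positions)))
ax.set_xticklabels(positions)
ax.set_yticks(range(len(classes)))
ax.set_yticklabels(classes)

# Minor ticks sit on the cell borders; the grid is drawn only on them
ax.set_xticks(np.arange(-0.5, len(positions)), minor=True)
ax.set_yticks(np.arange(-0.5, len(classes)), minor=True)
ax.grid(which="minor", color="black", linewidth=1)
ax.tick_params(which="minor", length=0)  # hide the minor tick marks themselves

plt.show()
```

Using an image means every cell is exactly one data unit wide, so the grid only has to be offset by half a unit to land on the borders.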
Another solution is to look into the plotly package, which produces really nice-looking tables and Gantt charts (doc).
Here, I will show you a similar output using barh from matplotlib. The main idea is to rebuild a grid of cells. Each row represents a class (e.g. A, B or C), every row is made of the same number of cells, and every cell has the same width. The color of a cell is given by the binary (0/1) translation of your data: a cell is colored when its position appears in the class's list and left white otherwise.
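The binary translation itself does not need matplotlib. As a minimal sketch (the helper to_binary_row and the dictionaries are hypothetical names), each class list becomes one 0/1 flag per position, and each flag becomes a cell color: white for 0 and the class color for 1, which is the same mapping colors_def encodes in the code below.

```python
classes = ["A", "B", "C"]
members = {"A": [3, 5, 7], "B": [4, 5, 7], "C": [2, 3, 4, 5, 6, 7]}
class_colors = {"A": "b", "B": "g", "C": "y"}
positions = range(2, 8)  # the positions that appear on the x-axis

def to_binary_row(values, positions):
    """Return 1 for each position present in values, 0 otherwise."""
    present = set(values)
    return [1 if p in present else 0 for p in positions]

rows = {c: to_binary_row(members[c], positions) for c in classes}
# White for a 0 flag, the class color for a 1 flag
cell_colors = {c: [class_colors[c] if bit else "w" for bit in rows[c]] for c in classes}

for c in classes:
    print(c, rows[c], cell_colors[c])
# A [0, 1, 0, 1, 0, 1] ['w', 'b', 'w', 'b', 'w', 'b']
# B [0, 0, 1, 1, 0, 1] ['w', 'w', 'g', 'g', 'w', 'g']
# C [1, 1, 1, 1, 1, 1] ['y', 'y', 'y', 'y', 'y', 'y']
```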
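The x-axis labels are aligned by shifting the bars manually: each bar starts half a cell width before its tick, so every tick ends up at the center of its cell.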
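The shift is plain arithmetic. Assuming, as in the code, that the cell for position k should be centered on the tick at k * cell_width, its bar has to start half a cell earlier. The hypothetical helper cell_edges below computes those edges and the x-limits that follow from them.

```python
cell_width = 1000
first, last = 2, 7  # first and last positions that get a tick

def cell_edges(k, width=cell_width):
    """Left and right edge of the bar drawn for position k."""
    left = k * width - width / 2
    return left, left + width

for k in range(first, last + 1):
    left, right = cell_edges(k)
    # The center of each bar falls exactly on its tick at k * cell_width
    print(k, left, right, (left + right) / 2)

# Extending the limits by half a cell on each side shows the outer cells completely
xlim = (first * cell_width - cell_width / 2, last * cell_width + cell_width / 2)
print(xlim)  # (1500.0, 7500.0)
```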
# Import modules
import seaborn
import matplotlib.pyplot as plt
import numpy as np
##########################
# Your input #
##########################
A = [3, 5, 7]
B = [4, 5, 7]
C = [2, 3, 4, 5, 6, 7]
##########################
# Prepare the data #
##########################
my_list = [A, B, C]
segments = 8
cell_width = 1000
nb_x_range = [i * cell_width for i in range(2, segments)]
classes = ["A", "B", "C"]
colors_def = {'A': {0: "w", 1: "b"},
              'B': {0: "w", 1: "g"},
              'C': {0: "w", 1: "y"},
              }
def create_data(my_list):
    # One row per position (0..segments), one column per class
    data = np.zeros((segments + 1, len(classes)))
    for i, sub_list in enumerate(my_list):
        for elt in sub_list:
            data[elt, i] = 1
    return data
data = create_data(my_list)
print(data)
# [[0. 0. 0.]
# [0. 0. 0.]
# [0. 0. 1.]
# [1. 0. 1.]
# [0. 1. 1.]
# [1. 1. 1.]
# [0. 0. 1.]
# [1. 1. 1.]
# [0. 0. 0.]]
y_pos = np.arange(len(classes))
# Start half a cell left of zero so that cell k is centered on x = k * cell_width
left = np.zeros(len(my_list)) - cell_width/2
##########################
# Create plot #
##########################
# Use seaborn styling for a prettier graph
seaborn.set()
# create figure
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111)
# Set the x-axis limits, extended by half a cell on each side
ax.set_xlim((min(nb_x_range)-cell_width/2, max(nb_x_range)+cell_width/2))
# For each column of cells (one position along the x-axis)
for d in data:
    # One color per class: white for 0, the class color for 1
    colors = [colors_def[classes[i]][int(bi)] for i, bi in enumerate(d)]
    # e.g. colors = [colors_def["A"][d[0]], colors_def["B"][d[1]], colors_def["C"][d[2]]]
    # Draw one cell per class at the current x offset
    ax.barh(y_pos, cell_width,
            height=1,       # Height of the horizontal bars
            color=colors,   # One color per bar
            left=left)      # Left edge of each bar
    # Move to the next column
    left += cell_width
# Axis graduation
ax.set_yticks(y_pos)
ax.set_xticks(nb_x_range)
ax.set_yticklabels(classes)
ax.set_xlabel('Stacked bar')
plt.show()
The output looks like this:

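If you only need the colored cells, matplotlib's broken_barh can draw each row in a single call from a list of (left edge, width) pairs, which removes the loop over columns and the white filler bars. This is a variation I'm suggesting, not a drop-in replacement for the code above: nothing is drawn for the 0 cells, so they simply show the axes background.

```python
import matplotlib.pyplot as plt

classes = ["A", "B", "C"]
members = {"A": [3, 5, 7], "B": [4, 5, 7], "C": [2, 3, 4, 5, 6, 7]}
class_colors = {"A": "b", "B": "g", "C": "y"}
cell_width = 1000
positions = range(2, 8)

fig, ax = plt.subplots(figsize=(10, 4))
for row, c in enumerate(classes):
    # One (left edge, width) pair per listed position, centered on k * cell_width
    xranges = [(k * cell_width - cell_width / 2, cell_width) for k in members[c]]
    # yrange is (bottom, height): the row spans row - 0.5 .. row + 0.5
    ax.broken_barh(xranges, (row - 0.5, 1), facecolors=class_colors[c])

# Same half-cell margins as in the barh version
ax.set_xlim(min(positions) * cell_width - cell_width / 2,
            max(positions) * cell_width + cell_width / 2)
ax.set_xticks([k * cell_width for k in positions])
ax.set_yticks(range(len(classes)))
ax.set_yticklabels(classes)
plt.show()
```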