
I am trying to make a multicolored line plot using matplotlib. The color of the line should change depending on the value in one column (state) of my dataframe:

time v1 v2 state
0 3.5 8 0
1 3.8 8.5 0
2 4.2 9 1
3 5 12 0
4 8 10 2
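To reproduce the plots, the sample table can be typed in as a DataFrame. This is a minimal sketch that assumes all four columns are numeric:

```python
import pandas as pd

# Sample data from the table above: two value columns plus an integer "state"
df = pd.DataFrame({
    "time": [0, 1, 2, 3, 4],
    "v1": [3.5, 3.8, 4.2, 5, 8],
    "v2": [8, 8.5, 9, 12, 10],
    "state": [0, 0, 1, 0, 2],
})
```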

My code so far, which just displays the plot normally, without the colors:

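import matplotlib.pyplot as plt

cols = ['v1', 'v2']
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(15, 15))
# One subplot per column in cols
df.plot(x='time', y=cols, subplots=True, ax=axes)
plt.legend()
plt.xticks(rotation=45)  # only rotates the ticks of the current (last) axes
plt.show()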

The result should look something like this (2nd graph): the line changes color according to the state column, using 3 distinct colors (red, blue, green).

[Image: the desired result]

Art
You would just create a separate plot for each interval and overlay them. You can either split your dataframe accordingly, or work directly on the selected data by using `plt.plot(x, y, color=...)` in a for loop. I prefer the second method. – 9769953 Aug 19 '22 at 08:55
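A minimal sketch of the approach from the comment: consecutive rows with the same state are grouped into one interval, and each interval is drawn with its own `plot(x, y, color=...)` call in a loop (the Axes method `ax.plot` here, so each column gets its own subplot). The `plot_by_state` helper and the `STATE_COLORS` mapping are hypothetical names, and assigning red/blue/green to states 0/1/2 is an assumption. Each interval is extended by one row so it connects to the next interval, which means a segment takes the color of the state at its start.

```python
import matplotlib.pyplot as plt
import pandas as pd

df = pd.DataFrame({
    "time": [0, 1, 2, 3, 4],
    "v1": [3.5, 3.8, 4.2, 5, 8],
    "v2": [8, 8.5, 9, 12, 10],
    "state": [0, 0, 1, 0, 2],
})

# Hypothetical state -> color mapping; unknown states fall back to black
STATE_COLORS = {0: "red", 1: "blue", 2: "green"}


def plot_by_state(df, cols, x="time", state="state"):
    """One subplot per column; one plot() call per run of equal state."""
    df = df.reset_index(drop=True)  # positional index, so labels == positions
    # A new run starts wherever the state differs from the previous row
    run_id = (df[state] != df[state].shift()).cumsum()
    fig, axes = plt.subplots(nrows=len(cols), ncols=1, squeeze=False)
    axes = axes[:, 0]
    for ax, col in zip(axes, cols):
        for _, run in df.groupby(run_id):
            # +2: include the first row of the next run so the line has no gap
            seg = df.iloc[run.index[0]:run.index[-1] + 2]
            color = STATE_COLORS.get(run[state].iloc[0], "black")
            ax.plot(seg[x], seg[col], color=color)
        ax.set_ylabel(col)
    return fig, axes
```

Calling `fig, axes = plot_by_state(df, ["v1", "v2"])` and then `plt.show()` displays the figure. A final interval that is a single row (state 2 at time 4 here) has no following point, so it produces no visible segment.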

2 Answers

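import matplotlib.pyplot as plt

# Draw one two-point segment per pair of consecutive rows,
# colored by the state of the segment's end row (cur)
for state, prev, cur in zip(df['state'].iloc[1:], df.index[:-1], df.index[1:]):
    if state == 0:
        color = 'blue'
    elif state == 1:
        color = 'orange'
    else:
        color = 'green'
    # v1 and v2 are both drawn on the same axes
    plt.plot([df["time"][prev], df["time"][cur]], df.loc[[prev, cur], ['v1', 'v2']], c=color)
plt.xticks(rotation=45)
plt.show()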
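The loop above draws v1 and v2 on the same axes. To get the two stacked subplots from the question instead, the same per-segment idea can be repeated once per axis. This is a sketch that assumes the question's red/blue/green colors for states 0/1/2 (`STATE_COLORS` is a hypothetical name) and keeps the answer's rule that a segment takes the state of its end row:

```python
import matplotlib.pyplot as plt
import pandas as pd

df = pd.DataFrame({
    "time": [0, 1, 2, 3, 4],
    "v1": [3.5, 3.8, 4.2, 5, 8],
    "v2": [8, 8.5, 9, 12, 10],
    "state": [0, 0, 1, 0, 2],
})

# Hypothetical state -> color mapping; unknown states fall back to black
STATE_COLORS = {0: "red", 1: "blue", 2: "green"}

fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(15, 15), sharex=True)
for ax, col in zip(axes, ["v1", "v2"]):
    # One two-point segment per pair of consecutive rows
    for i in range(1, len(df)):
        prev, cur = df.iloc[i - 1], df.iloc[i]
        # Color by the state of the segment's end row, as in the answer
        color = STATE_COLORS.get(cur["state"], "black")
        ax.plot([prev["time"], cur["time"]], [prev[col], cur[col]], color=color)
    ax.set_ylabel(col)
```

Call `plt.show()` afterwards to display the figure.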
Art

If you want to avoid for loops:

import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection

# pair every row with the next one: (time, v) -> (time_shift, v_shift)
df[["time_shift", "v1_shift", "v2_shift"]] = df.shift(-1)[["time", "v1", "v2"]]
# the last row has no successor and is dropped; each segment takes the state of its start row
df = df.dropna()

# build separate line segments, each a pair of (x, y) points
lines1 = list(zip(df[["time", "v1"]].values, df[["time_shift", "v1_shift"]].values))
lines2 = list(zip(df[["time", "v2"]].values, df[["time_shift", "v2_shift"]].values))

# map "state" to RGB values, use black if mapping does not exist
color_map = {
    0: (0.8, 0.1, 0.1),
    1: (0.1, 0.8, 0.1),
    2: (0.1, 0.1, 0.8),
}
colors = df["state"].apply(lambda x: color_map.get(x, (0, 0, 0))).tolist()

fig, ax = plt.subplots(nrows=2, ncols=1)
ax[0].add_collection(LineCollection(lines1, linestyle="solid", colors=colors))
ax[1].add_collection(LineCollection(lines2, linestyle="solid", colors=colors))
# fit the view to all segments, including the end point of the last one
# (limits taken from df after dropna would cut that segment off)
for a in ax:
    a.autoscale_view()
plt.show()
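Another loop-free variant, following the pattern of matplotlib's multicolored-line gallery example, is to give the LineCollection a colormap and a `BoundaryNorm` and let `set_array` pick each segment's color from the state values, so no shifted columns or explicit color list are needed. This sketch assumes the states are exactly 0, 1 and 2; `state_segments` is a hypothetical helper name.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.collections import LineCollection
from matplotlib.colors import BoundaryNorm, ListedColormap

df = pd.DataFrame({
    "time": [0, 1, 2, 3, 4],
    "v1": [3.5, 3.8, 4.2, 5, 8],
    "v2": [8, 8.5, 9, 12, 10],
    "state": [0, 0, 1, 0, 2],
})

# One color per state; bin edges so that 0 -> red, 1 -> blue, 2 -> green
cmap = ListedColormap(["red", "blue", "green"])
norm = BoundaryNorm([-0.5, 0.5, 1.5, 2.5], cmap.N)


def state_segments(x, y):
    """Return an (N-1, 2, 2) array of segments joining consecutive points."""
    points = np.column_stack([x, y]).reshape(-1, 1, 2)
    return np.concatenate([points[:-1], points[1:]], axis=1)


fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True)
for ax, col in zip(axes, ["v1", "v2"]):
    lc = LineCollection(state_segments(df["time"], df[col]), cmap=cmap, norm=norm)
    # Color each segment by the state of its start row
    lc.set_array(df["state"].to_numpy()[:-1])
    ax.add_collection(lc)
    ax.autoscale_view()  # fit the view to the segments
    ax.set_ylabel(col)
```

Because the colors come from the norm, supporting another state means adding one color to the colormap and one boundary to the norm.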
bzu