
Suppose I have the following pandas data frame:

import pandas as pd
d = {'Person': ['Bob']*9 + ['Alice']*9,
    'Time': ['Morning']*3 + ['Noon']*3 + ['Evening']*3 + ['Morning']*3 + ['Noon']*3 + ['Evening']*3,
    'Color': ['Red','Blue','Green']*6,
    'Energy': [1,5,4,7,3,6,8,4,2,9,8,5,2,6,7,3,8,1]}
df = pd.DataFrame(d)

How can I create a plot like this? [image: sketch of the desired plot] (Excuse the crude plotting.)

I've tried tricking scatter, strip and box plots into this, but with no success.

Thank you!

petezurich
soungalo

4 Answers

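  • Generate one scatter trace per Person.
  • Map each Time to a numeric x position and shift it slightly per person so the people are offset. Because x is now numeric, the hover text and x-axis tick labels are set explicitly.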
import pandas as pd
import plotly.graph_objects as go

xbase = pd.Series(df["Time"].unique()).reset_index().rename(columns={"index":"x",0:"Time"})
dfp = df.merge(xbase, on="Time").set_index("Person")

go.Figure(
    [
        go.Scatter(
            name=p,
            x=dfp.loc[p, "x"] + i/10,
            y=dfp.loc[p, "Energy"],
            text=dfp.loc[p, "Time"],
            mode="markers",
            marker={"color": dfp.loc[p, "Color"], "symbol":i, "size":10},
            hovertemplate="(%{text},%{y})"
        )
        for i, p in enumerate(dfp.index.get_level_values("Person").unique())
    ]
).update_layout(xaxis={"tickmode":"array", "tickvals":xbase["x"], "ticktext":xbase["Time"]})

Rob Raymond

You've already received some great suggestions, but since you're still wondering about:

What if I also want the colors to show in the legend?

I'd just like to chip in that px.scatter comes really close to being an optimal approach right out of the box. The only thing that's missing is jitter. Still, the plot below can be produced by these few lines of code:

fig = px.scatter(df, x = 'Time', y = 'Energy', color = 'Color', symbol = 'Person')

fig.for_each_trace(lambda t: t.update(marker_color = t.name.split(',')[0],
                                      name = t.name.split(',')[1], x = [1,2,3]))

fig.for_each_trace(lambda t: t.update(x=tuple([x + 0.2 for x in list(t.x)])) if t.name == ' Alice' else ())
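
The comparison against `' Alice'` (with a leading space) works only because the trace names are split on `','` without trimming. Assuming the combined trace names look like `"Red, Alice"`, which is what the `split(',')` calls above imply, a small helper could return both parts with the whitespace stripped so that later comparisons can use plain `'Alice'`. This is just a sketch; `split_trace_name` is a hypothetical name and not part of plotly.

```python
def split_trace_name(name):
    """Split a combined 'Color, Person' trace name into its two parts.

    Assumes the name contains exactly one comma, e.g. "Red, Alice".
    """
    color, person = (part.strip() for part in name.split(","))
    return color, person


print(split_trace_name("Red, Alice"))  # ('Red', 'Alice')
```

With the names stripped this way, the jitter step could test `person == 'Alice'` instead of relying on the leading space.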

Complete code:

import pandas as pd
import plotly.express as px
import plotly.graph_objs as go

# data
d = {'Person': ['Bob']*9 + ['Alice']*9,
    'Time': ['Morning']*3 + ['Noon']*3 + ['Evening']*3 + ['Morning']*3 + ['Noon']*3 + ['Evening']*3,
    'Color': ['Red','Blue','Green']*6,
    'Energy': [1,5,4,7,3,6,8,4,2,9,8,5,2,6,7,3,8,1]}
df = pd.DataFrame(d)

# figure setup
fig = px.scatter(df, x = 'Time', y = 'Energy', color = 'Color', symbol = 'Person')

# some customizations in order to get to the desired result:
fig.for_each_trace(lambda t: t.update(marker_color = t.name.split(',')[0],
                                      name = t.name.split(',')[1],
                                     x = [1,2,3]))
# jitter
fig.for_each_trace(lambda t: t.update(x=tuple([x + 0.2 for x in list(t.x)])) if t.name == ' Alice' else ())


# layout
fig.update_layout(xaxis={"tickmode":"array","tickvals":[1,2,3],"ticktext":df.Time.unique()})
    
fig.show()

Room for improvement:

Some elements of the snippet above could undoubtedly be made more dynamic, like x = [1,2,3] which should take into account a varying number of elements on the x-axis. The same goes for the number of people and the arguments used for jitter. But I can look into that too if this is something you can use.
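
As a rough sketch of what "more dynamic" could look like, the x positions and the per-person jitter can both be derived from the data instead of being hard-coded. The helper names `category_positions` and `centered_offsets`, and the default `width` of 0.2, are assumptions made for illustration; they are plain Python and not part of plotly or pandas.

```python
def category_positions(values):
    """Map each distinct value to 1, 2, 3, ... in order of first appearance."""
    positions = {}
    for value in values:
        positions.setdefault(value, len(positions) + 1)
    return positions


def centered_offsets(groups, width=0.2):
    """Give each distinct group an x offset, spaced `width` apart and centered on 0."""
    distinct = list(dict.fromkeys(groups))  # de-duplicate, keep first-seen order
    middle = (len(distinct) - 1) / 2
    return {g: (i - middle) * width for i, g in enumerate(distinct)}


times = ['Morning'] * 3 + ['Noon'] * 3 + ['Evening'] * 3
people = ['Bob', 'Alice']

positions = category_positions(times)  # {'Morning': 1, 'Noon': 2, 'Evening': 3}
offsets = centered_offsets(people)     # Bob is shifted left, Alice right
```

With these, `tickvals` could be `list(positions.values())`, `ticktext` could be `list(positions.keys())`, and each point's x value would be `positions[time] + offsets[person]`, whatever the number of time slots or people.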

vestland

You can go through each row of the DataFrame using itertuples (better performance than iterrows), map the 'Morning', 'Noon', and 'Evening' values to 1, 2, and 3, respectively, and then jitter the x-values by mapping 'Bob' to -0.05 and 'Alice' to 0.05 and adding these offsets to the x-values. You can also pass the 'Color' information to the marker_color argument.

Then map the tick values 1, 2, 3 back to 'Morning', 'Noon', and 'Evening', and use a legendgroup so that only one Bob marker and one Alice marker appear in the legend (instead of one legend entry per trace).

import pandas as pd
import plotly.graph_objects as go

d = {'Person': ['Bob']*9 + ['Alice']*9,
    'Time': ['Morning']*3 + ['Noon']*3 + ['Evening']*3 + ['Morning']*3 + ['Noon']*3 + ['Evening']*3,
    'Color': ['Red','Blue','Green']*6,
    'Energy': [1,5,4,7,3,6,8,4,2,9,8,5,2,6,7,3,8,1]}
df = pd.DataFrame(d)

shapes = {'Bob': 'circle', 'Alice': 'diamond'}
time = {'Morning':1, 'Noon':2, 'Evening':3}
jitter = {'Bob': -0.05, 'Alice': 0.05}

fig = go.Figure()
# row[0] is the index, row[1] Person, row[2] Time, row[3] Color, row[4] Energy
s = df.Person.shift() != df.Person
name_changes = s[s].index.values  # index of the first row for each person
for row in df.itertuples():
    fig.add_trace(go.Scatter(
        x=[time[row[2]] + jitter[row[1]]],
        y=[row[4]],
        legendgroup=row[1],
        name=row[1],
        mode='markers',
        marker_symbol=shapes[row[1]],
        marker_color=row[3],
        # show a legend entry only for the first trace of each person
        showlegend=row[0] in name_changes
    ))

fig.update_traces(marker=dict(size=12,line=dict(width=2,color='DarkSlateGrey')))
fig.update_layout(
    xaxis=dict(
        tickmode='array',
        tickvals=list(time.values()),
        ticktext=list(time.keys())
    )
)
fig.show() 
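
The loop above adds one trace per row (18 traces for this data). As a possible alternative, the rows could first be grouped per person so that each person needs only a single trace. The sketch below does the grouping in plain Python on the dictionary `d`; `points_by_person` is a hypothetical helper name, and `time` and `jitter` are the same mappings as above.

```python
def points_by_person(d, time, jitter):
    """Collect jittered x values, y values and colors for each person."""
    grouped = {}
    rows = zip(d['Person'], d['Time'], d['Color'], d['Energy'])
    for person, t, color, energy in rows:
        pts = grouped.setdefault(person, {'x': [], 'y': [], 'color': []})
        pts['x'].append(time[t] + jitter[person])
        pts['y'].append(energy)
        pts['color'].append(color)
    return grouped


d = {'Person': ['Bob']*9 + ['Alice']*9,
     'Time': (['Morning']*3 + ['Noon']*3 + ['Evening']*3) * 2,
     'Color': ['Red', 'Blue', 'Green']*6,
     'Energy': [1, 5, 4, 7, 3, 6, 8, 4, 2, 9, 8, 5, 2, 6, 7, 3, 8, 1]}
time = {'Morning': 1, 'Noon': 2, 'Evening': 3}
jitter = {'Bob': -0.05, 'Alice': 0.05}

grouped = points_by_person(d, time, jitter)
print(grouped['Bob']['y'])  # [1, 5, 4, 7, 3, 6, 8, 4, 2]
```

Each entry could then go into a single `go.Scatter(x=pts['x'], y=pts['y'], marker_color=pts['color'], marker_symbol=shapes[person], name=person, mode='markers')`, which should also make the `legendgroup` and `showlegend` bookkeeping unnecessary.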

Derek O

If you want to stay with matplotlib and avoid any extra dependencies, here is some sample code. (The pandas operations, such as the groupbys, are left for you to optimize.)

import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.transforms as transforms
from matplotlib.lines import Line2D

df = pd.DataFrame(
    {
        'Person': ['Bob'] * 9 + ['Alice'] * 9,
        'Time': ['Morning'] * 3
        + ['Noon'] * 3
        + ['Evening'] * 3
        + ['Morning'] * 3
        + ['Noon'] * 3
        + ['Evening'] * 3,
        'Color': ['Red', 'Blue', 'Green'] * 6,
        'Energy': [1, 5, 4, 7, 3, 6, 8, 4, 2, 9, 8, 5, 2, 6, 7, 3, 8, 1],
    }
)

plt.figure()

x = ['Morning', 'Noon', 'Evening']

# Transform function
offset = lambda p: transforms.ScaledTranslation(
    p / 72.0, 0, plt.gcf().dpi_scale_trans
)
trans = plt.gca().transData

# Use this to center transformation
start_offset = -len(df['Person'].unique()) // 2

# Define as many markers as people you have
markers = ['o', '^']

# Use this for custom legend
custom_legend = []

# Do this if you need to aggregate
df = df.groupby(['Person', 'Time', 'Color'])['Energy'].sum().reset_index()

df = df.set_index('Time')
for i, [person, pgroup] in enumerate(df.groupby('Person')):
    pts = (i + start_offset) * 10
    marker = markers[i]
    transform = trans + offset(pts)

    # This is for legend, not plotted
    custom_legend.append(
        Line2D(
            [0],
            [0],
            color='w',
            markerfacecolor='black',
            marker=marker,
            markersize=10,
            label=person,
        )
    )

    for color, cgroup in pgroup.groupby('Color'):
        mornings = cgroup.loc[cgroup.index == 'Morning', 'Energy'].values[0]
        noons = cgroup.loc[cgroup.index == 'Noon', 'Energy'].values[0]
        evenings = cgroup.loc[cgroup.index == 'Evening', 'Energy'].values[0]

        # This stupid if is because you need to define at least one non
        # transformation scatter be it first or whatever.
        if pts == 0:
            plt.scatter(
                x,
                [mornings, noons, evenings],
                c=color.lower(),
                s=25,
                marker=marker,
            )
        else:
            plt.scatter(
                x,
                [mornings, noons, evenings],
                c=color.lower(),
                s=25,
                marker=marker,
                transform=transform,
            )

plt.ylabel('Energy')
plt.xlabel('Time')
plt.legend(handles=custom_legend)
plt.margins(x=0.5)
plt.show()

tchar