
I have a pandas DataFrame with data I want to plot, and I'd like to color each dot by the sex of the animal. I've tried several ways to get this to work. First, I tried indexing a dictionary with the DataFrame's 'Sex' column:

figure = plt.figure(figsize=(20, 6))
axes = figure.add_subplot(1, 2, 1)
clr = {'M':'firebrick','F':'blueviolet', 'I':'beige'}
axes.scatter( data[ "Whole Weight"], data['Shucked Weight'],color=clr[str(data['Sex'])])
axes.set_ylabel( "Shucked Weight")
axes.set_xlabel( "Whole Weight")
axes.set_title("Whole Weight vs. Shucked Weight")

plt.show()
plt.close()

That raises a KeyError. Next, I tried looping through the DataFrame and setting the column manually based on each row's value:

for x1 in data['Sex']:
    if x1 == 'M':
        print(x1,)
        data['color'] = 'firebrick'
    elif x1 == 'F':
        data['color'] = 'blueviolet'
    else:
        data['color'] = 'bisque1'
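The loop above assigns a single value to the entire `color` column on every pass, so once it finishes every row holds the color chosen for the last row. A minimal fixed sketch collects one color per row and assigns the list once. The small `data` frame is made up, and the fallback uses 'beige' like the other attempts (matplotlib's named colors include 'bisque' but, as far as I know, not 'bisque1').

```python
import pandas as pd

# Made-up stand-in for the real data
data = pd.DataFrame({'Sex': ['M', 'F', 'I', 'F']})

colors = []
for x1 in data['Sex']:
    if x1 == 'M':
        colors.append('firebrick')
    elif x1 == 'F':
        colors.append('blueviolet')
    else:
        colors.append('beige')

# Assign the whole list at once: one color per row, in row order
data['color'] = colors
print(data['color'].tolist())  # ['firebrick', 'blueviolet', 'beige', 'blueviolet']
```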

I tried to build a DataFrame from scratch out of a dictionary holding the values:

weight_dict = pd.DataFrame(dict(whole = data['Whole Weight'], shucked = data['Shucked Weight'], sex = data['Sex'], color= some if statement that choked))
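One way to fill in the `color=` part is to compute the colors before building the frame, for example with a dictionary lookup inside a list comprehension. The sketch below assumes made-up values for the three source columns.

```python
import pandas as pd

clr = {'M': 'firebrick', 'F': 'blueviolet', 'I': 'beige'}

# Made-up stand-in for the real columns
data = pd.DataFrame({
    'Whole Weight': [0.51, 0.23, 0.68],
    'Shucked Weight': [0.22, 0.10, 0.26],
    'Sex': ['M', 'I', 'F'],
})

weight_dict = pd.DataFrame(dict(
    whole=data['Whole Weight'],
    shucked=data['Shucked Weight'],
    sex=data['Sex'],
    # One dictionary lookup per value in the Sex column
    color=[clr[s] for s in data['Sex']],
))
print(weight_dict)
```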

I tried np.where, but I have 3 options for sex (Male, Female, and Infant, abbreviated as M, F, I), and np.where only chooses between two:

data['color'] = np.where(data.Sex == 'M', 'Firebrick', (data.Sex == 'F', 'blueviolet','beige'))
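In the line above, the third argument to `np.where` is a plain tuple rather than another condition, so the nested choice never happens. Two common ways to handle three categories are nesting a second `np.where` call as the "else" branch, or `np.select`, which takes a list of conditions, a matching list of choices, and a default. A sketch with a made-up `data` frame:

```python
import numpy as np
import pandas as pd

# Made-up stand-in for the real data
data = pd.DataFrame({'Sex': ['M', 'F', 'I', 'M']})

# Nested np.where: the "else" branch is itself a np.where call
nested = np.where(data.Sex == 'M', 'firebrick',
                  np.where(data.Sex == 'F', 'blueviolet', 'beige'))

# np.select: conditions are checked in order; default covers the rest
selected = np.select([data.Sex == 'M', data.Sex == 'F'],
                     ['firebrick', 'blueviolet'],
                     default='beige')

data['color'] = selected
print(data['color'].tolist())  # ['firebrick', 'blueviolet', 'beige', 'firebrick']
```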

Finally, I got this to work:

def label_color(row):
    if row['Sex'] == 'M':
        return 'firebrick'
    elif row['Sex'] == 'F':
        return 'blueviolet'
    else:
        return 'beige'
data['color'] = data.apply(lambda row: label_color(row), axis=1)
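The row-wise `apply` works, but since the color depends only on the `Sex` column, the same result can come from looking each value up in a dictionary with `Series.map`, which avoids calling a Python function on every full row. A sketch comparing the two on a made-up frame (`label_color` is copied from above):

```python
import pandas as pd

def label_color(row):
    if row['Sex'] == 'M':
        return 'firebrick'
    elif row['Sex'] == 'F':
        return 'blueviolet'
    else:
        return 'beige'

# Made-up stand-in for the real data
data = pd.DataFrame({'Sex': ['M', 'F', 'I', 'F'],
                     'Whole Weight': [0.5, 0.4, 0.2, 0.45]})

# The function can be passed to apply directly, no lambda needed
by_row = data.apply(label_color, axis=1)

# Dictionary lookup on the single column
clr = {'M': 'firebrick', 'F': 'blueviolet', 'I': 'beige'}
by_map = data['Sex'].map(clr)

print(by_row.tolist() == by_map.tolist())  # True
```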

but I'm not really satisfied with this solution. I really wanted the first approach to work, where I have a custom dictionary and look the colors up in it during the call to axes.scatter, but the errors were baffling.
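A likely reason for the KeyError in the first attempt: `str(data['Sex'])` turns the whole column into one string (the Series' printed table, index included), and that single string is not a key of `clr`. The sketch below uses a small made-up Series in place of `data['Sex']` to show this, then does the per-element lookup with `Series.map`.

```python
import pandas as pd

clr = {'M': 'firebrick', 'F': 'blueviolet', 'I': 'beige'}

# Made-up stand-in for data['Sex']
sex = pd.Series(['M', 'F', 'I', 'M'])

# str() of a Series is its printed table, not the individual values
key = str(sex)
print(repr(key))
print(key in clr)  # False, so clr[key] raises KeyError

# Look up each value separately instead
colors = sex.map(clr)
print(colors.tolist())  # ['firebrick', 'blueviolet', 'beige', 'firebrick']
```

With the real data, passing `color=data['Sex'].map(clr)` to `axes.scatter` should then give one color per point.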

Is there an easier solution to this madness?

Sean Mahoney

1 Answer


I think you had it almost right at the first attempt.

I'd use the dictionary to build a new column that maps sex to color, something like this:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Random stand-in data: whole weight, shucked weight and sex
df = pd.DataFrame({
    "ww": np.random.randn(500),
    "sw": np.random.randn(500),
    "sex": np.random.choice(["M", "F", "I"], size=500),
})

# Map each sex code to a color with the dictionary
clr = {'M':'firebrick','F':'blueviolet', 'I':'yellow'}
df["color"] = df["sex"].apply(lambda x: clr[x])

plt.scatter(df["ww"], df["sw"], color=df["color"], alpha=0.7)
plt.show()


Or, if you don't want a new column (or the dictionary changes between scatter calls), you can do the lookup inline:

plt.scatter(df["ww"], df["sw"], color=df["sex"].apply(lambda x: clr[x]), alpha=0.7)
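`Series.map` with the dictionary performs the same per-value lookup without a lambda. One behavioral difference: `apply(lambda x: clr[x])` raises `KeyError` on a sex code missing from `clr`, while `map` returns NaN for it, which can then be filled with a fallback color. A sketch with made-up codes, where 'U' and the fallback 'gray' are purely illustrative:

```python
import pandas as pd

clr = {'M': 'firebrick', 'F': 'blueviolet', 'I': 'yellow'}

# Made-up codes; 'U' is deliberately missing from clr
sex = pd.Series(['M', 'F', 'I', 'U'])

# map: unknown codes become NaN instead of raising
mapped = sex.map(clr)
print(mapped.isna().tolist())  # [False, False, False, True]

# Fill unknown codes with a fallback color
filled = mapped.fillna('gray')
print(filled.tolist())  # ['firebrick', 'blueviolet', 'yellow', 'gray']

# The lambda version fails on the same input
try:
    sex.apply(lambda x: clr[x])
except KeyError as err:
    print('KeyError:', err)
```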

I'm not sure whether there is a better solution using only dictionaries, but since your data is already in pandas, I'd say it's fine to use it.
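If you also want a legend, one option that still relies only on the dictionary is to call `scatter` once per sex, passing that group's color and a `label`. A sketch with made-up random data (the column names follow the answer; the legend title and the Agg backend line are my own choices so it runs without a display):

```python
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # headless backend; drop this and call plt.show() interactively
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
df = pd.DataFrame({
    "ww": rng.standard_normal(500),
    "sw": rng.standard_normal(500),
    "sex": rng.choice(["M", "F", "I"], size=500),
})
clr = {'M': 'firebrick', 'F': 'blueviolet', 'I': 'yellow'}

fig, ax = plt.subplots()
# One scatter call per sex: a single color and a legend label each
for sex, group in df.groupby("sex"):
    ax.scatter(group["ww"], group["sw"], color=clr[sex], label=sex, alpha=0.7)
ax.legend(title="Sex")
```

Groups come out in sorted key order (F, I, M), so that is also the legend order.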

filippo