
I'm drawing a 3D scatter plot and want to give each marker a color based on its y-axis (country) label. I have the following code, but the markers don't get the colors they should. I think I'm doing something wrong or inefficient in the for loop. Can you point out my mistake?

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np    

src = ['facebook', 'google', 'amazon','facebook','facebook','google']
year = [2014,2014,2013,2013,2012,2013]
country = ['uk','ru','de','us','uk','us']
avg = [154,267,187,312,274,439]
colors = {'uk' : 'b',
          'de' : 'y',
          'ru' : 'r', 
          'us' : 'c'}

unique_src, idx_src = np.unique(src, return_inverse=True)
unique_cty, idx_cty = np.unique(country, return_inverse=True)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
for col, val in zip(colors, country):
    ax.scatter(idx_src, idx_cty, year, s=avg, c=colors[val])
plt.yticks(range(len(unique_cty)), unique_cty, rotation=340)
plt.xticks(range(len(unique_src)), unique_src, rotation=45, horizontalalignment='right')
ax.set_zticks(np.unique(year))
plt.show()

Also, when calling the scatter function, I can't pass the string labels directly:
ax.scatter(src, country, year, s=avg, c=colors[val])

as I get the following error message:

ValueError: could not convert string to float: facebook

Why is that so? I am using matplotlib version 2.1.2

DYZ
SaadH
  • Possible duplicate of [Is there a way to make matplotlib scatter plot marker or color according to a discrete variable in a different column?](https://stackoverflow.com/questions/24297097/is-there-a-way-to-make-matplotlib-scatter-plot-marker-or-color-according-to-a-di) – DYZ Jan 30 '18 at 02:04

1 Answer


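It seems you want to color the points by country. The problem with the loop is that zip(colors, country) stops after four iterations (the dict has only four keys), and each iteration calls ax.scatter with all six points but a single color, so each call paints over the previous one. Instead, build a list with one color per point and call scatter once: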

c = [colors[val] for val in country]
ax.scatter(idx_src, idx_cty, year, s=avg, c=c)
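To see what the original loop actually iterates over, here is a small sketch using the question's data (plain Python, no plotting). The names pairs and last_color are just illustrative.

```python
# Same data as in the question
country = ['uk', 'ru', 'de', 'us', 'uk', 'us']
colors = {'uk': 'b', 'de': 'y', 'ru': 'r', 'us': 'c'}

# Iterating a dict yields its keys, and zip() stops at the shorter input,
# so the original loop body runs len(colors) == 4 times, not 6.
pairs = list(zip(colors, country))
print(len(pairs))  # 4

# val always comes from country, so the 4th (last) iteration uses country[3].
# Every iteration plotted ALL six points with this single color.
last_color = colors[pairs[-1][1]]
print(last_color)  # c

# What scatter needs instead: one color per point, aligned with the data.
c = [colors[val] for val in country]
print(c)  # ['b', 'r', 'y', 'c', 'b', 'c']
```

Note that the col variable in the original loop is never used, which is another hint that the loop is not doing what was intended.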

Complete example:

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np    

src = ['facebook', 'google', 'amazon','facebook','facebook','google']
year = [2014,2014,2013,2013,2012,2013]
country = ['uk','ru','de','us','uk','us']
avg = [154,267,187,312,274,439]
colors = {'uk' : 'b',
          'de' : 'y',
          'ru' : 'r', 
          'us' : 'c'}

unique_src, idx_src = np.unique(src, return_inverse=True)
unique_cty, idx_cty = np.unique(country, return_inverse=True)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# one color per point, looked up from that point's country
c = [colors[val] for val in country]
ax.scatter(idx_src, idx_cty, year, s=avg, c=c)

plt.yticks(range(len(unique_cty)), unique_cty, rotation=340)
plt.xticks(range(len(unique_src)), unique_src, rotation=45, horizontalalignment='right')
ax.set_zticks(np.unique(year))
plt.show()
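On the ValueError: the example above never passes strings to ax.scatter. As far as I know, Axes3D.scatter in matplotlib 2.1.2 treats its coordinates as floats and has no categorical-axis support, so passing src directly fails on 'facebook'. np.unique(..., return_inverse=True) is what turns the labels into numbers; a minimal sketch of what it returns for the question's src list:

```python
import numpy as np

src = ['facebook', 'google', 'amazon', 'facebook', 'facebook', 'google']

# unique_src: sorted distinct labels
# idx_src: for each entry of src, its position in unique_src
unique_src, idx_src = np.unique(src, return_inverse=True)
print(unique_src)  # ['amazon' 'facebook' 'google']
print(idx_src)     # [1 2 0 1 1 2]

# The mapping is reversible: indexing unique_src with idx_src rebuilds src.
assert list(unique_src[idx_src]) == src
```

The same integers serve as tick positions, so plt.xticks(range(len(unique_src)), unique_src) puts each label back at the number its points were plotted at.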


ImportanceOfBeingErnest