
I am using this code to generate a 3D plot:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d


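# Create a single figure of the desired size with 3D axes
fig = plt.figure(figsize=(15, 12))
ax = fig.add_subplot(projection="3d")

# First three principal components as x, y, z
x = x_pca[:, 0]
y = x_pca[:, 1]
z = x_pca[:, 2]

# argmax over axis 1 turns one-hot (or per-class score) rows into integer class indices
y_train_new = y_train_orig.argmax(1)

ax.scatter3D(x, y, z, c=y_train_new, marker='o', alpha=0.6, cmap='inferno')
plt.title("3D Scatterplot: 95% of the variability captured", pad=15)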
ax.set_xlabel('First principal component')
ax.set_ylabel('Second principal component')
ax.set_zlabel('Third principal component')

I would like to change the colors of the classes.

This is the output that I get:

[Image: resulting 3D scatter plot]

Could the problem lie with my y_train_new data? It is an array of dtype int64.

How do I change the colors for the different classes? Thanks!

Data:

feat1 feat2 feat3 feat4 Label
-3.8481877 -0.47685334 0.63422906 1.0396314 1
-2.320888 0.65347993 1.1519914 0.12997247 1
1.5827686 1.4119303 -1.7410104 -4.6962333 1
-0.1337152 0.13315737 -1.6648949 -1.4205348 1
-0.4028037 1.332986 1.3618442 0.3292255 1
-0.015517877 1.346349 1.4083523 0.87017965 1
-0.2669228 0.5478992 -0.06730786 -1.5959451 1
-0.03318152 0.3263167 -2.116833 -5.4616213 1
0.4588691 0.6723614 -1.617398 -4.3511734 1
0.5899199 0.66525555 -1.694493 -3.9452586 1
1.610061 2.4186094 1.8807093 1.3764497 0
1.7985699 2.4387648 1.6306056 1.1184534 0
-9.222036 -9.9776 -9.832 -9.909746 0
0.21364458 -1.0171559 -4.9093766 -6.2154694 0
-0.019955145 -1.1677283 -4.6549516 -5.9503417 0
0.44730473 -0.77167743 -4.7527356 -5.971007 0
-0.16508447 -0.005777468 -1.5020386 -4.49326 0
-0.8654994 -0.54387957 -1.300646 -4.621529 0
-1.7471086 -2.0005553 -1.7533782 -2.6065414 0
-1.5313624 -1.6995796 -1.4394685 -2.600004 0
Joe
  • Hi @Joe - Can you please share some sample data (a few rows) of what is used here? The colors are probably black and yellow because those are the first and last colors of `inferno`. – Redox Jul 02 '22 at 13:24
  • @Redox, I have added some data as per your request. Thanks! – Joe Jul 02 '22 at 14:16
  • Hi @Joe - Is this data `x_pca`, with columns 0-2 as x, y, z and Label as `y_train_new`? – Redox Jul 02 '22 at 16:22
  • Hello @Redox, yes, columns 0-2 are the x, y, z data and the last column holds the classes. – Joe Jul 02 '22 at 19:10
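Going by this clarification, the sample can be loaded into the shapes the question's code expects. A minimal sketch, assuming the table is pasted as whitespace-separated text; `sample` holds only a few rows here for brevity, and `labels` plays the role of `y_train_new` (the sample already stores class indices, so no `argmax` is needed):

```python
import io

import numpy as np

# A few rows of the sample table above (whitespace-separated, one header line)
sample = """feat1 feat2 feat3 feat4 Label
-3.8481877 -0.47685334 0.63422906 1.0396314 1
-2.320888 0.65347993 1.1519914 0.12997247 1
1.610061 2.4186094 1.8807093 1.3764497 0
"""

# Skip the header line and parse the numbers as floats
data = np.loadtxt(io.StringIO(sample), skiprows=1)

x_pca = data[:, :3]               # columns 0-2 are the x, y, z coordinates
labels = data[:, -1].astype(int)  # last column holds the class (0 or 1)

print(x_pca.shape, labels)
```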

1 Answer

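Thanks for the data provided. As you can see from your data, Label is either 0 or 1. By default, scatter3D rescales the values passed to c so that the smallest one maps to the start of the colormap and the largest one to the end. With inferno as the cmap, those two extremes are black and yellow, which is what you are seeing.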
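To see this directly, you can ask the colormap which colors the two labels land on. A minimal sketch, assuming Matplotlib 3.5+ (for `matplotlib.colormaps`) and mimicking the default scaling with `matplotlib.colors.Normalize` set to the minimum and maximum label:

```python
import matplotlib
import numpy as np
from matplotlib.colors import Normalize, to_hex

labels = np.array([1, 1, 0, 0])  # two classes, like y_train_new

# Mimic scatter's default scaling: min(c) -> 0.0 and max(c) -> 1.0
norm = Normalize(vmin=labels.min(), vmax=labels.max())
inferno = matplotlib.colormaps["inferno"]

# Color each class receives: the darkest and the brightest end of inferno
class_colors = {int(v): to_hex(inferno(norm(v))) for v in np.unique(labels)}
print(class_colors)
```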

The easiest way to change the colors is to choose a colormap that contains the colors you need. The set of colormaps is available here, and you should be able to use any of them. I have also modified your code slightly to change the cmap and add a black edge to each marker so the points stand out in the new plot. If you replace your ax.scatter3D(...) call with this...

# Using tab10, which is a qualitative (discrete) colormap
ax.scatter3D(x, y, z, c=y_train_new, marker='o', alpha=0.8, s=100, cmap='tab10', edgecolor='k')

... you will get this plot

[Image: resulting 3D scatter plot]
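One detail of the tab10 version: because the labels 0 and 1 are rescaled to the two ends of the colormap, they receive the first and last of tab10's ten colors (blue and cyan), not the first two. A minimal sketch, assuming Matplotlib 3.5+ for `matplotlib.colormaps`, that checks this and shows how pinning the normalization range changes it:

```python
import matplotlib
from matplotlib.colors import Normalize, to_hex

tab10 = matplotlib.colormaps["tab10"]  # qualitative colormap with 10 colors

# Default scaling for labels 0 and 1: range [0, 1] spans the whole colormap
default_norm = Normalize(vmin=0, vmax=1)
default_colors = [to_hex(tab10(default_norm(v))) for v in (0, 1)]

# Pinning vmax to N - 1 makes label k pick the k-th color of the colormap
pinned_norm = Normalize(vmin=0, vmax=tab10.N - 1)
pinned_colors = [to_hex(tab10(pinned_norm(v))) for v in (0, 1)]

print(default_colors)  # ['#1f77b4', '#17becf'] (first and last tab10 entries)
print(pinned_colors)   # ['#1f77b4', '#ff7f0e'] (first two tab10 entries)
```

In the plot call, the pinned version corresponds to passing `vmin=0, vmax=9` together with `cmap='tab10'` to `ax.scatter3D(...)`.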

If you need specific colors, you can create your own colormap and use that instead. The documentation for this is available here. When creating mycmap below, make sure there is one color for each unique value in your Label column. For example, if you replace your ax.scatter3D() call with this...

from matplotlib.colors import ListedColormap

# One color per class: label 0 -> blue, label 1 -> red
mycmap = ListedColormap(["blue", "red"])
ax.scatter3D(x, y, z, c=y_train_new, marker='o', alpha=0.8, s=100, cmap=mycmap, edgecolor='k')

...your colormap will change to this.

[Image: resulting 3D scatter plot]
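If you also want a legend, an alternative to a colormap is to plot each class with its own scatter3D call and an explicit color. A minimal sketch, assuming a hypothetical `class_colors` dict from label to color and using a few rows of the sample data; it selects the non-interactive Agg backend so it runs without a display:

```python
import matplotlib

matplotlib.use("Agg")  # render off-screen; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np

# A few (x, y, z) rows from the sample data and their labels
x_pca = np.array([
    [-3.8481877, -0.47685334, 0.63422906],
    [-2.320888, 0.65347993, 1.1519914],
    [1.610061, 2.4186094, 1.8807093],
    [1.7985699, 2.4387648, 1.6306056],
])
y_train_new = np.array([1, 1, 0, 0])

# Hypothetical mapping: one explicit color per class label
class_colors = {0: "blue", 1: "red"}

fig = plt.figure(figsize=(15, 12))
ax = fig.add_subplot(projection="3d")

# One scatter call per class, so each class gets its own legend entry
for label, color in class_colors.items():
    mask = y_train_new == label
    ax.scatter3D(x_pca[mask, 0], x_pca[mask, 1], x_pca[mask, 2],
                 color=color, marker="o", alpha=0.8, s=100,
                 edgecolor="k", label=f"Class {label}")

ax.set_xlabel("First principal component")
ax.set_ylabel("Second principal component")
ax.set_zlabel("Third principal component")
ax.legend()
```

This trades the single vectorized call for one call per class, which is usually fine for a handful of classes.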

Redox