
I'm following this SO answer on plotting an SVC, but I'd like to tweak it so I know which color is associated with each target value (1 or 0). My initial solution was to iterate through the data and set the marker based on the target value. However, I believe that because I'm using cmap, c is expected to be an array, but I'm passing in the scalar y[i]. I'm trying to figure out how to resolve this.

for i in range(X0.shape[0]):
    ax.scatter(X0[i], X1[i], c=y[i], marker=markers[i], cmap=plt.cm.coolwarm, s=20, edgecolors='k')

I've also tried using a colorbar, but it's black and white.

PCM=ax.get_children()[2] 
plt.colorbar(PCM, ax=ax) 

y: [1 1 1 1 1 0 0 0 0 0]

X0: [ 375 378 186 186 186 69 1048 515 1045 730]

X1: [159 73 272 58 108 373 373 373 373 267]
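A probable cause of the problem, offered as an assumption rather than a confirmed diagnosis: when `scatter` is called once per point, each call autoscales its color normalization to that single value, so `vmin == vmax` and matplotlib maps every point to the same end of the colormap. The sketch below loops over the two classes instead of over individual points and pins `vmin=0, vmax=1` so both calls share one color scale. The per-class `markers` dict is a hypothetical choice, not from the question.

```python
import matplotlib.pyplot as plt
import numpy as np

y = np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0])
X0 = np.array([375, 378, 186, 186, 186, 69, 1048, 515, 1045, 730])
X1 = np.array([159, 73, 272, 58, 108, 373, 373, 373, 373, 267])

fig, ax = plt.subplots()
cmap = plt.cm.coolwarm
markers = {0: 'o', 1: 's'}  # hypothetical marker choice per class

for cls in (0, 1):
    mask = (y == cls)
    # Pin vmin/vmax so every scatter call maps the label onto the same color
    # scale; without them each call autoscales to its own data.
    ax.scatter(X0[mask], X1[mask], c=y[mask], cmap=cmap, vmin=0, vmax=1,
               marker=markers[cls], s=20, edgecolors='k', label=str(cls))

# One legend entry per scatter call, labelled with the target value.
ax.legend(title='target')
```

With the shared scale, class 0 takes the blue end of `coolwarm` and class 1 the red end, and the legend ties each color and marker to its label.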

Brosef

1 Answer


Just to have an alternative to a scatter plot plus colorbar, I'd suggest a line plot with markers but no lines, plus a legend. To do that you have to group your data beforehand, but that only takes a small function:

import matplotlib.pyplot as plt

y = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
X0 = [375, 378, 186, 186, 186, 69, 1048, 515, 1045, 730]
X1 = [159, 73, 272, 58, 108, 373, 373, 373, 373, 267]

def group(x, y):
    # Split x into two lists according to the 0/1 label in y
    groups = [[], []]
    for val, key in zip(x, y):
        groups[key].append(val)
    return groups

# 'rx' = red crosses for class 0, 'bo' = blue circles for class 1 (markers only, no lines)
for v0, v1, lt, label in zip(group(X0, y), group(X1, y), ('rx', 'bo'), ('0', '1')):
    plt.plot(v0, v1, lt, label=label)
plt.legend()
plt.show()
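If you prefer to keep a single `scatter` call with a colormap, here is a minimal sketch, assuming matplotlib 3.1 or later (where `PathCollection.legend_elements` is available). Passing the whole label array as `c` gives one normalization for all points. The legend and the colorbar are then both built from the artist that `scatter` returns, which is more robust than picking an artist out of `ax.get_children()` by index.

```python
import matplotlib.pyplot as plt
import numpy as np

y = np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0])
X0 = np.array([375, 378, 186, 186, 186, 69, 1048, 515, 1045, 730])
X1 = np.array([159, 73, 272, 58, 108, 373, 373, 373, 373, 267])

fig, ax = plt.subplots()
# A single call: c is the full label array, so one norm covers both classes.
sc = ax.scatter(X0, X1, c=y, cmap=plt.cm.coolwarm, s=20, edgecolors='k')

# One proxy handle per distinct value in c, colored through the same cmap/norm.
handles, labels = sc.legend_elements()
ax.legend(handles, labels, title='target')

# Colorbar attached to the scatter's own artist, with ticks at the class values.
cbar = fig.colorbar(sc, ax=ax, ticks=[0, 1])
```

Using `sc` directly guarantees that the colorbar describes the colormap of the points themselves.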



I'd like to build on this idea, even though there are excellent packages that can handle both the grouping of the data AND the plotting.

We can define a function like this:

def scatter_group(keys, values, ax=None, fmt=str):
    """Side effect: scatter-plot values grouped by keys; return the axes.

    keys, N vector-like of integers or strings;
    values, 2xN vector-like of numbers, x = values[0], y = values[1];
    ax, the axes on which we want to plot, by default None (a new figure is created);
    fmt, a function to format the key value in the legend label, by default `str`.
    """

    from matplotlib.pyplot import subplots
    from itertools import product, cycle

    # -> 'or', 'ob', 'ok', 'xr', 'xb', ..., 'or', 'ob', ...
    linetypes = cycle(''.join(mc) for mc in product('ox*^>', 'rbk'))

    if ax is None:
        _, ax = subplots()

    # Collect the (x, y) pairs under their key, in order of first appearance
    d = {}
    for k, *v in zip(keys, *values):
        d.setdefault(k, []).append(v)

    # One marker-only "line" per group, each with its own marker/color
    for k, lt in zip(d, linetypes):
        x, y = zip(*d[k])
        ax.plot(x, y, lt, label=fmt(k))
    ax.legend()
    return ax

and use it as follows:

In [148]: fig, (ax0, ax1) = plt.subplots(1,2)                                             

In [149]: ax00 = scatter_group(y, (X0, X1), ax=ax0)                                               

In [150]: ax0 is ax00                                                                     
Out[150]: True

In [151]: scatter_group(y, (X1, X0), ax=ax1, fmt=lambda k:"The key is %s"%k)                      
Out[151]: <matplotlib.axes._subplots.AxesSubplot at 0x7fc88f57bac8>

In [152]: ax0.set_xlabel('X0')                                                            
Out[152]: Text(0.5, 23.52222222222222, 'X0')

In [153]: ax1.set_xlabel('X1')                                                            
Out[153]: Text(0.5, 23.52222222222222, 'X1')
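The grouping step inside `scatter_group` can be tried on its own, without any plotting. This sketch repeats the same `setdefault`/`zip` idiom on the question's data (the names `groups` and `xs_ys` are illustrative only). Because dicts keep insertion order (Python 3.7+), key 1 comes first here. That is also the order in which `scatter_group` assigns linetypes and legend entries.

```python
keys = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
X0 = [375, 378, 186, 186, 186, 69, 1048, 515, 1045, 730]
X1 = [159, 73, 272, 58, 108, 373, 373, 373, 373, 267]

groups = {}
for k, *v in zip(keys, X0, X1):
    # v is [x, y] for one point; collect the points under their key
    groups.setdefault(k, []).append(v)

# zip(*points) transposes [[x0, y0], [x1, y1], ...] into (xs, ys)
xs_ys = {k: tuple(zip(*pts)) for k, pts in groups.items()}

print(list(groups))  # keys in order of first appearance: [1, 0]
print(xs_ys[1])      # ((375, 378, 186, 186, 186), (159, 73, 272, 58, 108))
```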


gboffi