4

I'd like to plot a convergence process of the MLE algorithm with the plotly library.

Requirements:

  • the points have to be colored in the colors of their clusters, and the colors must change accordingly on each iteration
  • the centroids of the clusters should be plotted on each iteration.

A plot of a single iteration may be produced by Code 1, with the desired output shown in Figure 1:

Code 1

import plotly.graph_objects as go
import numpy as np

# Toy data: 15 random 2-D points and one integer centroid per cluster.
A = np.random.randn(30).reshape((15, 2))
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = [1, 2, 3, 4, 5]
colors = ['red', 'green', 'blue', 'yellow', 'magenta']

# Split the points evenly between the clusters (15 points / 5 clusters = 3).
points_per_cluster = len(A) // len(clusters)

fig = go.Figure()

# One trace per cluster. The slices are disjoint, so every point belongs to
# exactly one cluster and all 15 rows of A are drawn (the previous A[i:i+3]
# windows overlapped and never reached rows 7-14).
for c, color in zip(clusters, colors):
    start = (c - 1) * points_per_cluster
    members = A[start:start + points_per_cluster]
    fig.add_trace(
        go.Scatter(
            x=members[:, 0],
            y=members[:, 1],
            mode='markers',
            name=f'cluster {c}',
            marker_color=color,
        )
    )

# One single-point trace per centroid, drawn as an "x" in its cluster's color.
for c, color in zip(clusters, colors):
    cx, cy = centroids[c - 1]
    fig.add_trace(
        go.Scatter(
            x=[cx],
            y=[cy],
            name=f'centroid of cluster {c}',
            mode='markers',
            marker_color=color,
            marker_symbol='x',
        )
    )
fig.show()

Figure 1

Figure 1

I've seen this tutorial, but it seems that you can plot only a single trace in a graph_objects.Frame(). Code 2 is a simple example that produces an animated scatter plot of all the points, where each frame plots either the points of a different cluster or one of the centroids:

Code 2

import plotly.graph_objects as go
import numpy as np

# Toy data: 15 random 2-D points and one integer centroid per cluster.
A = np.random.randn(30).reshape((15, 2))
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = [1, 2, 3, 4, 5]
colors = ['red', 'green', 'blue', 'yellow', 'magenta']

# Trace shown before the animation starts: the first three points as cluster 1.
first_rows = A[:3]
initial_trace = go.Scatter(
    x=first_rows[:, 0],
    y=first_rows[:, 1],
    mode='markers',
    name='cluster 1',
    marker_color=colors[0],
)

# Frames 1-4: one single-trace frame per cluster 2..5, each with its own rows.
point_frames = [
    go.Frame(data=[go.Scatter(
        x=rows[:, 0],
        y=rows[:, 1],
        mode='markers',
        name=f'cluster {k}',
        marker_color=colors[k - 1],
    )])
    for k, rows in zip(clusters[1:], (A[:3], A[3:5], A[5:8], A[8:]))
]

# Frames 5-9: one single-trace frame per centroid, drawn as an "x".
centroid_frames = [
    go.Frame(data=[go.Scatter(
        x=[cx],
        y=[cy],
        mode='markers',
        name=f'centroid of cluster {k}',
        marker_color=colors[k - 1],
        marker_symbol='x',
    )])
    for k, (cx, cy) in zip(clusters, centroids)
]

# Fixed axes plus a single "Play" button that runs through all frames.
play_button = {"label": "Play", "method": "animate", "args": [None]}
layout = go.Layout(
    xaxis={"range": [-10, 10], "autorange": False},
    yaxis={"range": [-10, 10], "autorange": False},
    title="Start Title",
    updatemenus=[{"type": "buttons", "buttons": [play_button]}],
)

fig = go.Figure(
    data=[initial_trace],
    layout=layout,
    frames=point_frames + centroid_frames,
)
fig.show()

Why Code 2 does not fit my needs:

  • I need to plot all the frames produced by Code 2 in a single frame for each iteration of the algorithm (i.e. each frame of the desired solution should look like Figure 1)

What I have tried:

  • I have tried producing a graph_objects.Figure(), and adding it to a graph_objects.Frame() as shown in Code 3, but have gotten Error 1.

Code 3:

import plotly.graph_objects as go
import numpy as np

A = np.random.randn(30).reshape((15, 2))  # 15 random 2-D points
centroids = np.random.randint(10, size=10).reshape((5, 2))  # one integer (x, y) row per cluster
clusters = [1, 2, 3, 4, 5]
colors = ['red', 'green', 'blue', 'yellow', 'magenta']

fig = go.Figure()  # static figure: 5 point traces followed by 5 centroid traces
for i in range(5):
    fig.add_trace(
        go.Scatter(
            x=A[i:i+3][:, 0],  # rows i..i+2, so consecutive clusters share points
            y=A[i:i+3][:, 1],
            mode='markers',
            name=f'cluster {i+1}',
            marker_color=colors[i]
        )
    )
    
for c in clusters:
    fig.add_trace(
        go.Scatter(
            x=[centroids[c-1][0]],  # clusters are 1-based, centroid rows 0-based
            y=[centroids[c-1][1]],
            name=f'centroid of cluster {c}',
            mode='markers',
            marker_color=colors[c-1],
            marker_symbol='x'
        )
    )

animated_fig = go.Figure(  # attempt: reuse the whole static figure as one animation frame
    data=[go.Scatter(x=A[:3][:, 0], y=A[:3][:, 1], mode='markers', name=f'cluster 0', marker_color=colors[0])],
    layout=go.Layout(
        xaxis=dict(range=[-10, 10], autorange=False),
        yaxis=dict(range=[-10, 10], autorange=False),
        title="Start Title",
        updatemenus=[dict(
            type="buttons",
            buttons=[dict(label="Play",
                          method="animate",
                          args=[None])])]
    ),
    frames=[go.Frame(data=[fig])]  # raises ValueError (Error 1): frame data must be traces, not a Figure
)

animated_fig.show()  # never reached: the constructor above raises first

Error 1:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-681-11264f38e6f7> in <module>
     43                           args=[None])])]
     44     ),
---> 45     frames=[go.Frame(data=[fig])]
     46 )
     47 

~\Anaconda3\lib\site-packages\plotly\graph_objs\_frame.py in __init__(self, arg, baseframe, data, group, layout, name, traces, **kwargs)
    241         _v = data if data is not None else _v
    242         if _v is not None:
--> 243             self["data"] = _v
    244         _v = arg.pop("group", None)
    245         _v = group if group is not None else _v

~\Anaconda3\lib\site-packages\plotly\basedatatypes.py in __setitem__(self, prop, value)
   3973                 # ### Handle compound array property ###
   3974                 elif isinstance(validator, (CompoundArrayValidator, BaseDataValidator)):
-> 3975                     self._set_array_prop(prop, value)
   3976 
   3977                 # ### Handle simple property ###

~\Anaconda3\lib\site-packages\plotly\basedatatypes.py in _set_array_prop(self, prop, val)
   4428         # ------------
   4429         validator = self._get_validator(prop)
-> 4430         val = validator.validate_coerce(val, skip_invalid=self._skip_invalid)
   4431 
   4432         # Save deep copies of current and new states

~\Anaconda3\lib\site-packages\_plotly_utils\basevalidators.py in validate_coerce(self, v, skip_invalid, _validate)
   2671 
   2672             if invalid_els:
-> 2673                 self.raise_invalid_elements(invalid_els)
   2674 
   2675             v = to_scalar_or_list(res)

~\Anaconda3\lib\site-packages\_plotly_utils\basevalidators.py in raise_invalid_elements(self, invalid_els)
    298                     pname=self.parent_name,
    299                     invalid=invalid_els[:10],
--> 300                     valid_clr_desc=self.description(),
    301                 )
    302             )

ValueError: 
    Invalid element(s) received for the 'data' property of frame
        Invalid elements include: [Figure({
    'data': [{'marker': {'color': 'red'},
              'mode': 'markers',
              'name': 'cluster 1',
              'type': 'scatter',
              'x': array([-1.30634452, -1.73005459,  0.58746435]),
              'y': array([ 0.15388112,  0.47452796, -1.86354483])},
             {'marker': {'color': 'green'},
              'mode': 'markers',
              'name': 'cluster 2',
              'type': 'scatter',
              'x': array([-1.73005459,  0.58746435, -0.27492892]),
              'y': array([ 0.47452796, -1.86354483, -0.20329897])},
             {'marker': {'color': 'blue'},
              'mode': 'markers',
              'name': 'cluster 3',
              'type': 'scatter',
              'x': array([ 0.58746435, -0.27492892,  0.21002816]),
              'y': array([-1.86354483, -0.20329897,  1.99487636])},
             {'marker': {'color': 'yellow'},
              'mode': 'markers',
              'name': 'cluster 4',
              'type': 'scatter',
              'x': array([-0.27492892,  0.21002816, -0.0148647 ]),
              'y': array([-0.20329897,  1.99487636,  0.73484184])},
             {'marker': {'color': 'magenta'},
              'mode': 'markers',
              'name': 'cluster 5',
              'type': 'scatter',
              'x': array([ 0.21002816, -0.0148647 ,  1.13589386]),
              'y': array([1.99487636, 0.73484184, 2.08810809])},
             {'marker': {'color': 'red', 'symbol': 'x'},
              'mode': 'markers',
              'name': 'centroid of cluster 1',
              'type': 'scatter',
              'x': [9],
              'y': [6]},
             {'marker': {'color': 'green', 'symbol': 'x'},
              'mode': 'markers',
              'name': 'centroid of cluster 2',
              'type': 'scatter',
              'x': [0],
              'y': [5]},
             {'marker': {'color': 'blue', 'symbol': 'x'},
              'mode': 'markers',
              'name': 'centroid of cluster 3',
              'type': 'scatter',
              'x': [8],
              'y': [6]},
             {'marker': {'color': 'yellow', 'symbol': 'x'},
              'mode': 'markers',
              'name': 'centroid of cluster 4',
              'type': 'scatter',
              'x': [7],
              'y': [1]},
             {'marker': {'color': 'magenta', 'symbol': 'x'},
              'mode': 'markers',
              'name': 'centroid of cluster 5',
              'type': 'scatter',
              'x': [6],
              'y': [2]}],
    'layout': {'template': '...'}
})]

    The 'data' property is a tuple of trace instances
    that may be specified as:
      - A list or tuple of trace instances
        (e.g. [Scatter(...), Bar(...)])
      - A single trace instance
        (e.g. Scatter(...), Bar(...), etc.)
      - A list or tuple of dicts of string/value properties where:
        - The 'type' property specifies the trace type
            One of: ['area', 'bar', 'barpolar', 'box',
                     'candlestick', 'carpet', 'choropleth',
                     'choroplethmapbox', 'cone', 'contour',
                     'contourcarpet', 'densitymapbox', 'funnel',
                     'funnelarea', 'heatmap', 'heatmapgl',
                     'histogram', 'histogram2d',
                     'histogram2dcontour', 'image', 'indicator',
                     'isosurface', 'mesh3d', 'ohlc', 'parcats',
                     'parcoords', 'pie', 'pointcloud', 'sankey',
                     'scatter', 'scatter3d', 'scattercarpet',
                     'scattergeo', 'scattergl', 'scattermapbox',
                     'scatterpolar', 'scatterpolargl',
                     'scatterternary', 'splom', 'streamtube',
                     'sunburst', 'surface', 'table', 'treemap',
                     'violin', 'volume', 'waterfall']

        - All remaining properties are passed to the constructor of
          the specified trace type

        (e.g. [{'type': 'scatter', ...}, {'type': 'bar, ...}])
  • I've managed to get all the points present in each frame by using the plotly.express module, as shown in Code 4, but the only thing missing there is for the centroids to be marked with x symbols.

Code 4:

import plotly.express as px
import numpy as np
import pandas as pd

# Toy data: 100 random 2-D points, 20 per iteration, each with a random
# cluster label in 1..5, plus one integer centroid per cluster.
A = np.random.randn(200).reshape((100, 2))
iteration = np.array([1, 2, 3, 4, 5]).repeat(20)
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = np.random.randint(1, 6, size=100)
colors = ['red', 'green', 'blue', 'yellow', 'magenta']

# Long-format table: one row per point, tagged with its cluster and iteration.
df = pd.DataFrame(dict(
    x1=A[:, 0],
    x2=A[:, 1],
    type='point',
    cluster=pd.Series(clusters, dtype='str'),
    iteration=iteration,
))
# Same columns for the centroids: centroid k is attached to iteration k.
centroid_df = pd.DataFrame(dict(
    x1=centroids[:, 0],
    x2=centroids[:, 1],
    type='centroid',
    cluster=[1, 2, 3, 4, 5],
    iteration=[1, 2, 3, 4, 5],
))
# DataFrame.append was removed in pandas 2.0; pd.concat is the supported
# equivalent and yields the same stacked, re-indexed frame.
df = pd.concat([df, centroid_df], ignore_index=True)
px.scatter(
    df,
    x="x1",
    y="x2",
    animation_frame="iteration",
    color="cluster",
    hover_name="cluster",
    range_x=[-10, 10],
    range_y=[-10, 10],
)

I'd appreciate any help for achieving the desired result. Thanks.

rpanai
  • 12,515
  • 2
  • 42
  • 64
Michael
  • 2,167
  • 5
  • 23
  • 38

1 Answer

4

You can add two traces per frame, but apparently you need to define these two traces in the figure's initial data too. I added the first two traces again as a frame so that they are visible in subsequent plays. Here is the full code:

import plotly.graph_objects as go
import numpy as np

# Toy data: 15 random 2-D points and one integer centroid per cluster.
A = np.random.randn(30).reshape((15, 2))
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = [1, 2, 3, 4, 5]
colors = ['red', 'green', 'blue', 'yellow', 'magenta']

# Rows of A drawn for each cluster, in cluster order. Clusters 1 and 2 both
# use rows 0-2, exactly as in the original hand-written frames.
cluster_rows = [slice(0, 3), slice(0, 3), slice(3, 5), slice(5, 8), slice(8, None)]


def _cluster_traces(k):
    """Return the [points, centroid] trace pair for the cluster at index k.

    k is 0-based; the traces are labelled with the 1-based cluster number
    and share the cluster's color. The centroid is drawn as an "x".
    """
    rows = A[cluster_rows[k]]
    cx, cy = centroids[k]
    return [
        go.Scatter(x=rows[:, 0],
                   y=rows[:, 1],
                   mode='markers',
                   name=f'cluster {clusters[k]}',
                   marker_color=colors[k]),
        go.Scatter(x=[cx],
                   y=[cy],
                   mode='markers',
                   name=f'centroid of cluster {clusters[k]}',
                   marker_color=colors[k],
                   marker_symbol='x'),
    ]


play_button = dict(label="Play",
                   method="animate",
                   args=[None])
# The first argument must be the one-element list [None]: that is what tells
# plotly.js to stop the running animation. A bare None means "animate all
# frames", which would make this button behave like Play.
pause_button = dict(label="Pause",
                    method="animate",
                    args=[[None],
                          {"frame": {"duration": 0, "redraw": False},
                           "mode": "immediate",
                           "transition": {"duration": 0}}])

fig = go.Figure(
    # The initial data already holds both traces (points + centroid) so that
    # each two-trace frame has matching traces to update.
    data=_cluster_traces(0),
    layout=go.Layout(
        xaxis=dict(range=[-10, 10], autorange=False),
        yaxis=dict(range=[-10, 10], autorange=False),
        title="Start Title",
        updatemenus=[dict(type="buttons",
                          buttons=[play_button, pause_button])]
    ),
    # Cluster 1 is repeated as the first frame so it is shown again when the
    # animation is replayed.
    frames=[go.Frame(data=_cluster_traces(k)) for k in range(len(clusters))]
)

fig.show()

enter image description here

rpanai
  • 12,515
  • 2
  • 42
  • 64