
I would like to visualize the steps the EM algorithm takes when fitting a Gaussian mixture model (GMM), but I don't know how to go about it.

I've generated some synthetic data and fitted a model:

import numpy as np
import pandas as pd
from sklearn.mixture import GaussianMixture

a = np.random.normal(loc=[2,2,2], scale=1.0, size=(100,3))
b = np.random.normal(loc=[5,5,5], scale=1.0, size=(100,3))
c = np.random.normal(loc=[7,7,7], scale=1.0, size=(100,3))

data = np.concatenate((a,b,c), axis=0)

df = pd.DataFrame(data, columns=['x', 'y', 'z'])

gm = GaussianMixture(n_components=3, random_state=213).fit(df)

res = gm.predict(df)  # cluster labels; fit_predict would refit the model from scratch
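Before animating anything, it can help to know how many EM iterations the fit actually ran. A fitted GaussianMixture exposes this through its documented n_iter_, converged_ and lower_bound_ attributes. The sketch below regenerates data like the snippet above so it runs on its own; the fixed seed (np.random.default_rng(0)) is an added assumption for reproducibility.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)  # seeded so the run is reproducible (assumption, not in the question)
data = np.concatenate([
    rng.normal(loc=[2, 2, 2], scale=1.0, size=(100, 3)),
    rng.normal(loc=[5, 5, 5], scale=1.0, size=(100, 3)),
    rng.normal(loc=[7, 7, 7], scale=1.0, size=(100, 3)),
])

gm = GaussianMixture(n_components=3, random_state=213).fit(data)

# Number of EM iterations the best initialization used before stopping
print("EM iterations:", gm.n_iter_)
# Whether EM met its tolerance (tol) before hitting max_iter
print("converged:", gm.converged_)
# Lower bound on the log-likelihood of the training data for the best fit
print("lower bound:", gm.lower_bound_)
```

The value of n_iter_ gives a natural upper limit for how many frames an animation of this fit needs.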

I've used graspologic (a package for graph statistics) to visualize the end result, but I would like to see how the EM algorithm updates the clusters from one iteration to the next.

Any thoughts on how I can implement this?

Jean-Paul Azzopardi

1 Answer


If you want to keep using scikit-learn's GaussianMixture class, you can use its max_iter parameter to cap how many EM iterations are run. Fitting repeatedly with max_iter = 1, 2, 3, and so on, and predicting after each fit, shows the model's state after each number of iterations. Because random_state is fixed, every fit starts from the same initialization, so the frames follow a single EM run.

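import numpy as np
import pandas as pd
from PIL import Image as im
import matplotlib.pyplot as plt
import warnings
from sklearn.exceptions import ConvergenceWarning
from sklearn.mixture import GaussianMixture

# Stopping EM early on purpose triggers a ConvergenceWarning on every fit; silence it
warnings.filterwarnings('ignore', category=ConvergenceWarning)

# Means only 1 apart with unit scale, so the three clusters overlap heavily
a = np.random.normal(loc=[2,2,2], scale=1.0, size=(200,3))
b = np.random.normal(loc=[3,3,3], scale=1.0, size=(200,3))
c = np.random.normal(loc=[4,4,4], scale=1.0, size=(200,3))

data = np.concatenate((a,b,c), axis=0)

df = pd.DataFrame(data, columns=['x', 'y', 'z'])
TEMPFILE = 'temp.png'

def snap(data, labels, size):
    """Plot the first two coordinates colored by cluster and return the figure as a PIL image."""
    fig, ax = plt.subplots()
    ax.scatter(data[:, 0], data[:, 1], c=labels, cmap='viridis', s=size)
    fig.savefig(TEMPFILE)
    plt.close(fig)
    # Copy the pixels into memory so TEMPFILE can be overwritten by the next frame
    return im.fromarray(np.asarray(im.open(TEMPFILE)))

images = []
# Start at max_iter=1: some scikit-learn versions reject max_iter=0.
# The fixed random_state gives every fit the same initialization,
# so the i-th frame shows the model after i EM iterations.
for i in range(1, 21):
    gm = GaussianMixture(n_components=3, random_state=213, max_iter=i).fit(df)
    probs = gm.predict_proba(df)
    labels = probs.argmax(axis=1)  # hard assignment: most probable component
    size = 50 * probs.max(axis=1) ** 8  # shrink points whose assignment is uncertain
    images.append(snap(data, labels, size))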
images[0].save(
    'gmm.gif',
    optimize=False,
    save_all=True,
    append_images=images[1:],
    loop=0,
    duration=100
)

which produces the following animation:

[Animation: gmm.gif]
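Refitting from scratch for every frame repeats all of the earlier iterations each time. An alternative, sketched here and not part of the original answer, uses warm_start=True together with max_iter=1: per the scikit-learn documentation, a warm-started model uses the previous fit's solution as the initialization for the next call to fit(), so each call should advance EM by one iteration. The sketch records the component means and the log-likelihood lower bound after every step; the names history and n_steps, the seed, and the two-panel plot are illustrative choices.

```python
import warnings

import matplotlib
matplotlib.use("Agg")  # render off-screen; remove this line to display plots interactively
import matplotlib.pyplot as plt
import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)  # fixed seed (assumption) for a reproducible example
data = np.concatenate([
    rng.normal(loc=m, scale=1.0, size=(200, 3)) for m in ([2, 2, 2], [3, 3, 3], [4, 4, 4])
])

# Each fit() stops after one EM iteration on purpose, which raises a ConvergenceWarning
warnings.filterwarnings("ignore", category=ConvergenceWarning)

# warm_start=True: every fit() after the first continues from the previous parameters
gm = GaussianMixture(n_components=3, random_state=213, max_iter=1, warm_start=True)

n_steps = 20
history = []  # (component means, lower bound) recorded after each EM iteration
for step in range(n_steps):
    gm.fit(data)
    history.append((gm.means_.copy(), gm.lower_bound_))

paths = np.stack([means for means, _ in history])  # shape (n_steps, n_components, n_features)

# Left: path of each component mean (first two coordinates) across iterations.
# Right: log-likelihood lower bound after each iteration.
fig, (ax_paths, ax_lb) = plt.subplots(1, 2, figsize=(10, 4))
ax_paths.scatter(data[:, 0], data[:, 1], s=5, c="lightgray")
for k in range(paths.shape[1]):
    ax_paths.plot(paths[:, k, 0], paths[:, k, 1], marker="o", markersize=3)
ax_paths.set_title("Component means per EM iteration")
ax_lb.plot(range(1, n_steps + 1), [lb for _, lb in history], marker="o")
ax_lb.set_xlabel("EM iteration")
ax_lb.set_ylabel("lower bound")
fig.savefig("gmm_em_steps.png")
plt.close(fig)
```

Because each EM iteration is computed only once, the total work grows linearly with the number of frames, whereas refitting with max_iter = 1, 2, ..., N runs 1 + 2 + ... + N iterations in total.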

Galletti_Lance