
First of all, my hmmlearn version is 0.3.0b (installed using conda).

I am trying to fit a GMMHMM model with hmmlearn, but I am getting:

ValueError: n_samples=3 should be >= n_clusters=5

To be more specific, my model has 4 states and 5 Gaussian mixture components (clusters), and my input X has shape (20, 3), i.e. (n_samples, n_features) as described in the documentation.

Here is the code that produces the error:

import numpy as np
from hmmlearn import hmm

size = 30
data = np.concatenate((np.random.normal(0,1,size), np.random.normal(5,2,size)))
np.random.shuffle(data)
x = np.reshape(data, (-1, 3))  # 60 values -> 20 samples x 3 features

model = hmm.GMMHMM(n_components=4, n_mix=5)
model.fit(x)

Can anyone explain the reason for this, or is it a bug in the library? I couldn't find any examples of the GMMHMM model online.
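The shapes in the question can be checked with NumPy alone. The 2 * 30 = 60 generated values become 20 samples of 3 features, so the n_samples=3 in the error is not the whole of x but presumably a subset of its rows. One plausible explanation, which is an assumption here and not confirmed against the hmmlearn 0.3.0b source, is that GMMHMM's default initialization first assigns the samples to the 4 states and then clusters each state's samples into n_mix=5 groups; a state that ends up with only 3 samples cannot be split into 5 clusters. The names per_state and min_needed below are illustrative only.

```python
import numpy as np

# Same data generation as in the question, with a seeded generator.
size = 30
rng = np.random.default_rng(0)
data = np.concatenate((rng.normal(0, 1, size), rng.normal(5, 2, size)))
rng.shuffle(data)
x = np.reshape(data, (-1, 3))  # 60 values -> 20 samples x 3 features

n_components, n_mix = 4, 5
n_samples, n_features = x.shape
print(x.shape)  # (20, 3)

# Assumption: initialization splits the rows among the n_components states
# and then clusters each state's rows into n_mix groups.
per_state = n_samples / n_components  # average rows per state if split evenly
min_needed = n_components * n_mix     # rows needed so every state could get n_mix
print(per_state, min_needed)  # 5.0 20
```

Under that assumption, 20 rows is the bare minimum even for a perfectly even split, so more data or a smaller n_mix would be the first things to try.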

mrt

1 Answer


I have the same problem.

It seems as if it wants multiple samples per row on which the Gaussian mixtures would be trained, i.e. a 3-dimensional array where each row is an array of samples and each sample is an array of features of the multivariate Gaussian mixture:

[
  [[a, b], [a, b], [a, b], [a, b]],
  [[a, b], [a, b], [a, b], [a, b]],
  [[a, b], [a, b], [a, b], [a, b]],
  [[a, b], [a, b], [a, b], [a, b]],
  ...
]

I say this because when I reshaped the data into a single row of multiple samples, the error went away. But then I got an error about insufficient n_samples for training the HMM itself, i.e. too few rows:

[
  [a a a a a a a a a],
  [b b b b b b b b b]
]

The problem is that the fit() method only accepts 2D arrays, not 3D ones. So I couldn't make this work, and I have no idea how multivariate Gaussian mixtures in GMMHMM are supposed to be done with hmmlearn.
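For reference, a 3D layout like the one above can be flattened into the 2D (n_samples, n_features) array that fit() accepts: each time step becomes one row holding all of its features, and sequences are stacked end to end. hmmlearn's fit() also takes an optional lengths argument that marks how many rows belong to each concatenated sequence. The sketch below only builds the arrays with NumPy; the sizes (10 sequences, 4 time steps, 2 features a and b) are hypothetical, and the commented-out fit call assumes an hmmlearn model like the one in the question.

```python
import numpy as np

n_sequences, seq_len, n_features = 10, 4, 2  # hypothetical sizes
rng = np.random.default_rng(0)

# 3D layout: one entry per sequence, each holding seq_len samples of [a, b].
X3d = rng.normal(size=(n_sequences, seq_len, n_features))

# 2D layout for fit(): sequences stacked end to end, one row per time step.
X2d = X3d.reshape(-1, n_features)
lengths = [seq_len] * n_sequences  # number of rows belonging to each sequence

print(X2d.shape)     # (40, 2)
print(sum(lengths))  # 40, must equal the number of rows in X2d

# Assumed usage with an hmmlearn model (not executed here):
# model.fit(X2d, lengths)
```

Because NumPy reshapes in row-major order, the first seq_len rows of X2d are exactly the first sequence, which is what lengths describes.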

But because my features were of the same nature, I interleaved them into one continuous stream, each with a different offset, and trained a simple univariate (not multivariate) Gaussian mixture on it:

[
  [a + 1],
  [b - 1],
  [a + 1],
  [b - 1],
  ...
]

Instead of a single "hill" in 3D (a density over two features), the density plotted in 2D has two bumps, one per offset feature, which is good enough for discriminating between different input time series.
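The stacking trick described above can be sketched with NumPy, assuming two equally long feature series a and b (random placeholders here) and the offsets +1 and -1 from the example. The two features are interleaved into a single column, so the result is a 2D array with n_features=1 that a univariate model can be trained on.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 50
a = rng.normal(0, 1, n)  # placeholder series for feature a
b = rng.normal(0, 1, n)  # placeholder series for feature b

# Interleave a+1 and b-1 into one stream: a0+1, b0-1, a1+1, b1-1, ...
stream = np.empty(2 * n)
stream[0::2] = a + 1
stream[1::2] = b - 1

# fit() expects a 2D (n_samples, n_features) array, here with one feature.
X = stream.reshape(-1, 1)
print(X.shape)  # (100, 1)
```

Slice assignment keeps the alternating order explicit; np.column_stack((a + 1, b - 1)).reshape(-1, 1) produces the same array.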

mhrvth