0

I am trying to train a two-state Hidden Markov model with a scaled Baum-Welch, but I noticed when my emission sequence is too small. My probabilities turn to NaN in java. Is this normal? I have posted my code in java below:

import java.util.ArrayList;
/*
Scaled Baum-Welch Algorithm implementation
author: Ricky Chang
*/

public class HMModeltest {

public static double[][] stateTransitionMatrix = new double[2][2]; // State Transition Matrix
public static double[][] emissionMatrix; // Emission Probability Matrix
public static double[] pi = new double[2]; // Initial State Distribution

double[] scaler; // This is used for scaling to prevent underflow
private static int emissions_id = 1; // To identify if the emissions are for the price changes or spread changes
private static int numEmissions = 0; // The amount of emissions
private static int numStates = 2; // The number of states in hmm
public static double improvementVar; // Used to assess how much the model has improved
private static double genState; // Generated state, it is used to generate observations below

// Create an ArrayList to store the emissions
public static ArrayList<Integer> eSequence = new ArrayList<Integer>();


// Initialize H, emission_id: 1 is price change, 2 are spreads; count is for the amount of different emissions
public HMModeltest(int id, int count){
    emissions_id = id;
    numEmissions = count;

    stateTransitionMatrix = set2DValues(numStates,numStates); // Give the STM  row stochastic values
    emissionMatrix = new double[numStates][numEmissions];
    emissionMatrix = set2DValues(numStates,numEmissions); // Give the Emission probability matrix row stochastic values
    pi = set1DValues(numStates); // Give the initial matrix row stochastic values 
}

// Categorize the price change emissions; I may want to put these in the Implementation.
private int identifyE1(double e){

    if( e == 0) return 4;
    if( e > 0){
        if(e == 1) return 5;
        else if(e == 3) return 6;
        else if(e == 5) return 7;
        else return 8;
    }
    else{
        if(e == -1) return 3;
        else if(e == -3) return 2;
        else if(e == -5) return 1;
        else return 0;
    }
}

// Categorize the spread emissions
private int identifyE2(double e){

    if(e == 1) return 0;
    else if(e == 3) return 1;
    else return 2;
}

public void updateE(int emission){
    if(emissions_id == 1) eSequence.add( identifyE1(emission) );
    else eSequence.add( identifyE2(emission) );
}

// Used to intialize random row stochastic values to vectors
private double[] set1DValues(int col){
    double sum = 0;
    double temp = 0;
    double [] returnVector = new double[col];

    for(int i = 0; i < col; i++){
        temp = Math.round(Math.random() * 1000);
        returnVector[i] = temp;
        sum = sum + temp;
    }
    for(int i = 0; i < col; i++){
        returnVector[i] = returnVector[i] / sum;
    }

    return returnVector;
}

// Used to initialize random row stochastic values to matrices
public double[][] set2DValues(int row, int col){
    double sum = 0;
    double temp = 0;
    double[][] returnMatrix = new double[row][col];

    for(int i = 0; i < row; i++){
        for(int j = 0; j < col; j++){
            temp = Math.round(Math.random() * 1000);
            returnMatrix[i][j] = temp;
            sum = sum + temp;
        }
        for(int j = 0; j < col; j++){
            returnMatrix[i][j] = returnMatrix[i][j] / sum;
        }

        sum = 0;
    }

    return returnMatrix;
}

// Use forward algorithm to calculate alpha for all states and times
public double[][] forwardAlgo(int time){
    double alpha[][] = new double[numStates][time];
    scaler = new double[time];

    // Intialize alpha for time 0
    scaler[0] = 0; // c0 is for scaling purposes to avoid underflow
    for(int i = 0; i < numStates; i ++){
        alpha[i][0] = pi[i] * emissionMatrix[i][eSequence.get(0)];
        scaler[0] = scaler[0] + alpha[i][0];
    }

    // Scale alpha_0
    scaler[0] = 1 / scaler[0];
    for(int i = 0; i < numStates; i++){
        alpha[i][0] = scaler[0] * alpha[i][0];
    }

    // Use recursive method to calculate alpha
    double tempAlpha = 0;
    for(int t = 1; t < time; t++){
        scaler[t] = 0;
        for(int i = 0; i < numStates; i++){
            for(int j = 0; j < numStates; j++){
                tempAlpha = tempAlpha + alpha[j][t-1] * stateTransitionMatrix[j][i];
            }
            alpha[i][t] = tempAlpha * emissionMatrix[i][eSequence.get(t)];
            scaler[t] = scaler[t] + alpha[i][t];
            tempAlpha = 0;
        }

        scaler[t] = 1 / scaler[t];
        for(int i = 0; i < numStates; i++){
            alpha[i][t] = scaler[t] * alpha[i][t];
        }
    }

    System.out.format("scaler: ");
    for(int t = 0; t < time; t++){
        System.out.format("%f, ", scaler[t]);
    }
    System.out.print('\n');
    return alpha;
}

// Use backward algorithm to calculate beta for all states
public double[][] backwardAlgo(int time){
    double beta[][] = new double[2][time];

    // Intialize beta for current time
    for(int i = 0; i < numStates; i++){
        beta[i][time-1] = scaler[time-1];
    }

    // Use recursive method to calculate beta
    double tempBeta = 0;
    for(int t = time-2; t >= 0; t--){
        for(int i = 0; i < numStates; i++){
            for(int j = 0; j < numStates; j++){
                tempBeta = tempBeta + (stateTransitionMatrix[i][j] * emissionMatrix[j][eSequence.get(t+1)] * beta[j][t+1]);
            }
            beta[i][t] = tempBeta;
            beta[i][t] = scaler[t] * beta[i][t];
            tempBeta = 0;
        }
    }

    return beta;
}

// Calculate the probability of emission sequence given the model (it is also the denominator to calculate gamma and digamma)
public double calcP(int t, double[][] alpha, double[][] beta){

    double p = 0;

    for(int i = 0; i < numStates; i++){
        for(int j = 0; j < numStates; j++){
            p = p + (alpha[i][t] * stateTransitionMatrix[i][j] * emissionMatrix[j][eSequence.get(t+1)] * beta[j][t+1]);
        }
    }
    return p;
}

// Calculate digamma; i and j are both states
public double calcDigamma(double p, int t, int i, int j, double[][] alpha, double[][] beta){
    double digamma = (alpha[i][t] * stateTransitionMatrix[i][j] * emissionMatrix[j][eSequence.get(t+1)] * beta[j][t+1]) / p;
    return digamma;
}

public void updatePi(double[][] gamma){
    for(int i = 0; i < numStates; i++){
        pi[i] = gamma[i][0];
    }
}

public void updateAll(){
    int time = eSequence.size();
    double alpha[][] = forwardAlgo(time);
    double beta[][] = backwardAlgo(time);
    double initialp = calcLogEProb(time);
    double nextState0, nextState1;

    double p = 0;
    double[][][] digamma = new double[numStates][numStates][time];
    double[][] gamma = new double[numStates][time];

    for(int t = 0; t < time-1; t++){
        p = calcP(t, alpha, beta);
        for(int i = 0; i < numStates; i++){
            gamma[i][t] = 0;
            for(int j = 0; j < numStates; j++){
                digamma[i][j][t] = calcDigamma(p, t, i, j, alpha, beta);
                gamma[i][t] = gamma[i][t] + digamma[i][j][t];
            }
        }
    }

    updatePi(gamma);
    updateA(digamma, gamma);
    updateB(gamma);

    alpha = forwardAlgo(time);
    double postp = calcLogEProb(time);
    improvementVar = postp - initialp;
}

// Update the state transition matrix
public void updateA(double[][][] digamma, double[][] gamma){
    int time = eSequence.size();
    double num = 0;
    double denom = 0;

    for(int i = 0; i < numStates; i++){
        for(int j = 0; j < numStates; j++){
            for(int t = 0; t < time-1; t++){
                num = num + digamma[i][j][t];
                denom = denom + gamma[i][t];
            }
            stateTransitionMatrix[i][j] = num/denom;
            num = 0;
            denom = 0;
        }
    }
}

public void updateB(double[][] gamma){
    int time = eSequence.size();
    double num = 0;
    double denom = 0;

    // k is an observation, j is a state, t is time
    for(int i = 0; i < numStates; i++){
        for(int k = 0; k < numEmissions; k++){
            for(int t = 0; t < time-1; t++){
                if( eSequence.get(t) == k) num = num + gamma[i][t];
                denom = denom + gamma[i][t];
            }
            emissionMatrix[i][k] = num/denom;
            num = 0;
            denom = 0;
        }
    }
}

public double calcLogEProb(int time){
    double logProb = 0;

    for(int t = 0; t < time; t++){
        logProb = logProb + Math.log(scaler[t]);
    }

    return -logProb;
}

public double calcNextState(int time, int state, double[][] gamma){
    double p = 0;
    for(int i = 0; i < numStates; i++){
        for(int j = 0; j < numStates; j++){
            p = p + gamma[i][time-2] * stateTransitionMatrix[i][j] * stateTransitionMatrix[j][state];
        }
    }

    return p;
}

// Print parameters
public void print(){
    System.out.println("Pi:");
    System.out.print('[');
    for(int i = 0; i < 2; i++){
        System.out.format("%f, ", pi[i]);
    }
    System.out.print(']');
    System.out.print('\n');

    System.out.println("A:");
    for(int i = 0; i < 2; i++){
        System.out.print('[');
        for(int j = 0; j < 2; j++){
            System.out.format("%f, ", stateTransitionMatrix[i][j]);
        }
        System.out.print(']');
        System.out.print('\n');
    }

    System.out.println("B:");
    for(int i = 0; i < 2; i++){
        System.out.print('[');
        for(int j = 0; j < 9; j++){
            System.out.format("%f, ", emissionMatrix[i][j]);
        }
        System.out.print(']');
        System.out.print('\n');
    }
    System.out.print('\n');
}

/* Generate sample data to test HMM training with the following params:
 * [ .3, .7 ]
 * [ .8, .2 ]                       [ .45 .1  .08 .05 .03 .02 .05 .2 .02 ]
 *                                  [ .36 .02 .06 .15 .04 .05  .2 .1 .02 ]
 * With these as observations:        {-10, -5, -3, -1, 0, 1, 3, 5, 10}
 */
public static int sampleDataGen(){
    double rand = 0;

    rand = Math.random();
    if(genState == 1){
        if(rand < .3) genState = 1;
        else genState = 2;

        rand = Math.random();
        if(rand < .45) return -10;
        else if(rand < .55) return -5;
        else if(rand < .63) return -3;
        else if(rand < .68) return -1;
        else if(rand < .71) return 0;
        else if(rand < .73) return 1;
        else if(rand < .78) return 3;
        else if(rand < .98) return 5;
        else return 10;
    }
    else {
        if(rand < .8) genState = 1;
        else genState = 2;

        rand = Math.random();
        if(rand < .36) return -10;
        else if(rand < .38) return -5;
        else if(rand < .44) return -3;
        else if(rand < .59) return -1;
        else if(rand < .63) return 0;
        else if(rand < .68) return 1;
        else if(rand < .88) return 3;
        else if(rand < .98) return 5;
        else return 10;
    }
}


public static void main(String[] args){
    HMModeltest test = new HMModeltest(1,9);
    test.print();

    System.out.print('\n');
    for(int i = 0; i < 20; i++){
        test.updateE(sampleDataGen());
    }

    test.updateAll();
    System.out.print('\n');
    test.print();
    System.out.print('\n');


    for(int i = 0; i < 10; i++){
        test.updateE(sampleDataGen());
    }
    test.updateAll();
    System.out.print('\n');
    test.print();
    System.out.print('\n');
}

}

My guess is that since the sample is too small, sometimes the probabilities don't exist for some observations. But it would be nice to have some confirmation.

  • Calculate the probabilities in log space then – Thomas Jungblut Feb 23 '14 at 22:54
  • My feeling is the same as of Thomas's. I don't see that your code is using log probabilities but as I understand that would be the standard way to approach a problem like this. So multiplication would become addition and you would not have to worry about running out of floating point precision. Also what do you mean by non existing probabilities? – jhegedus Feb 25 '14 at 11:17

1 Answers1

1

You could refer the "Scaling" section in Rabiner's paper, which solves the underflow problem.

You could also do the calculations in log space, that's what HTK and R do. Multiplication and division become addition and subtraction. For the other two, look at the LAdd/ LSub and logspace_add/logspace_sub functions in the respective toolkits.

The log-sum-exp trick might be helpful too.

max
  • 4,248
  • 2
  • 25
  • 38