I need to get the distribution of words for each topic found by Mallet, in Java (not via the CLI, as asked in "how to get a probability distribution for a topic in mallet?"). As an example of what I mean, from Introduction to Latent Dirichlet Allocation:
Topic A: 30% broccoli, 15% bananas, 10% breakfast, 10% munching, … (at which point, you could interpret topic A to be about food)
Topic B: 20% chinchillas, 20% kittens, 20% cute, 15% hamster, … (at which point, you could interpret topic B to be about cute animals)
Mallet provides raw token "weights" per topic, and in http://comments.gmane.org/gmane.comp.ai.mallet.devel/2064 somebody attempted to write a method that returns the distribution of words per topic.
I modified that method so that all weights are divided by their sum, as discussed in the mailing-list thread above. Since each weight is the word's count in the topic plus the smoothing parameter beta, dividing by the total should give the smoothed probability p(w|t).
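For reference, the raw weights can already be read without modifying Mallet, through getSortedWords(). A minimal sketch, assuming model is an already-trained ParallelTopicModel (the variable name and the top-5 cutoff are mine; alphabet and numTopics are public fields of ParallelTopicModel):

import java.util.ArrayList;
import java.util.TreeSet;
import cc.mallet.types.IDSorter;

// print the top raw (unnormalized) token weights per topic;
// "model" is an already-trained ParallelTopicModel
ArrayList<TreeSet<IDSorter>> sortedWords = model.getSortedWords();
for (int topic = 0; topic < model.numTopics; topic++) {
    int rank = 0;
    for (IDSorter idWeight : sortedWords.get(topic)) {
        if (rank++ == 5) break; // top 5 words only
        // getWeight() is a raw token count here, not a probability
        System.out.printf("topic %d: %s = %.0f%n", topic,
                model.alphabet.lookupObject(idWeight.getID()), idWeight.getWeight());
    }
}

These weights are unnormalized counts, which is why I need the method below.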
Does the following method (when added to ParallelTopicModel.java) correctly calculate the distribution of words per topic p(w|t) in Mallet?
/**
 * Get the normalized topic-word weights (the weights sum to 1.0).
 *
 * @param topic the topic
 * @return a list of {type index, weight} pairs whose weights sum to 1.0
 */
public ArrayList<double[]> getNormalizedTopicWordWeights(int topic) {
    ArrayList<double[]> tokenWeights = new ArrayList<double[]>();
    for (int type = 0; type < numTypes; type++) {
        int[] topicCounts = typeTopicCounts[type];
        // start from beta so that words unseen in this topic get non-zero mass
        double weight = beta;
        // each entry of topicCounts packs a count in its upper bits and a
        // topic index in its lower topicBits bits; entries are sorted by
        // count in descending order, so a zero entry terminates the list
        int index = 0;
        while (index < topicCounts.length && topicCounts[index] > 0) {
            int currentTopic = topicCounts[index] & topicMask;
            if (currentTopic == topic) {
                weight += topicCounts[index] >> topicBits;
                break;
            }
            index++;
        }
        double[] tokenAndWeight = { (double) type, weight };
        tokenWeights.add(tokenAndWeight);
    }
    // normalize: the sum of all weights is tokensPerTopic[topic] + beta * numTypes
    double sum = 0;
    for (double[] tokenAndWeight : tokenWeights) {
        sum += tokenAndWeight[1];
    }
    // divide each weight by the sum
    ArrayList<double[]> normalizedTokenWeights = new ArrayList<double[]>();
    for (double[] tokenAndWeight : tokenWeights) {
        tokenAndWeight[1] = tokenAndWeight[1] / sum;
        normalizedTokenWeights.add(tokenAndWeight);
    }
    return normalizedTokenWeights;
}
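For completeness, this is how I would call the method and sanity-check the result once the model is trained (model, topic 0, and the top-10 cutoff are placeholders of mine):

// call the new method on a trained model and inspect topic 0
ArrayList<double[]> weights = model.getNormalizedTopicWordWeights(0);
// sort by weight, highest first (each entry is { type index, weight })
weights.sort((a, b) -> Double.compare(b[1], a[1]));
double sum = 0;
for (double[] typeAndWeight : weights) {
    sum += typeAndWeight[1];
}
for (int i = 0; i < 10 && i < weights.size(); i++) {
    double[] typeAndWeight = weights.get(i);
    System.out.printf("%s: %.4f%n",
            model.alphabet.lookupObject((int) typeAndWeight[0]), typeAndWeight[1]);
}
System.out.println("sum = " + sum); // should be 1.0 up to floating-point error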