The logsumexp algorithm is a technique for calculating the expression log( sum( exp( v ) ) ), where the exp function is applied element-wise to the vector v. The typical implementation of the algorithm is the straightforward identity max(v) + log( sum( exp( v - max(v) ) ) ), where max(v) is the largest element in v. The boost in numerical accuracy of this version comes from the ability to use log1p when sum( exp( v - max(v) ) ) is between 1 and 2. The problem comes in when the number of elements in v is large and none of them is dominant.
is large and none of them are dominant. This inspired me to make the following recursive algorithm for logsumexp
:
logsumexp(v) = logsumexp(greaterhalf(v)) + log1p( sum( exp( lesserhalf(v) - logsumexp(greaterhalf(v)) ) ) )
,
with greaterhalf(v)
returning the half, rounded up, of elements from v
that are greater than the other half (lesserhalf(v)
returns the elements greaterhalf
doesn't). The recursion is terminated when there is only one element in v
, where logsumexp(v) = v
. There are, of course, a lot of optimizations that can be done including: sorting the elements of v
once, caching exp(v)
, and switching to a loop from recursion.
Is it possible to demonstrate that this recursive algorithm, with or without those optimizations, is numerically optimal in some sense? Specifically, that it minimizes rounding errors and only underflows/overflows when the arbitrary-precision calculation would underflow/overflow on conversion to finite precision at the end.
Here is a concrete implementation in C that makes a performance-increasing trade-off, using logsumexp(v) = logsumexp(greaterhalf(v)) + log1p( sum(exp(lesserhalf(v))) / sum(exp(greaterhalf(v))) ) so that cached sums of exponentials can be divided instead of re-exponentiating at every level:
#include <math.h>
#include <string.h>
#include <stdlib.h>
#include <stdio.h>
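//NOTE: MAX/MIN below use GCC statement expressions, a compiler extension (not ISO C)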
#define MAX(a, b) \
({ __typeof__ (a) x = (a); \
__typeof__ (b) y = (b); \
x > y ? x : y; })
#define MIN(a, b) \
({ __typeof__ (a) x = (a); \
__typeof__ (b) y = (b); \
x <= y ? x : y; })
//Standard ascending comparison for qsort
int doubleComp( const void *p1, const void *p2 ) {
double *d1 = (double *)p1;
double *d2 = (double *)p2;
if (*d1 < *d2) return -1;
if (*d1 > *d2) return 1;
if (*d1 == *d2) return 0;
return 1;//Unordered (NAN) case; unreachable here since NANs are filtered out before sorting
}
double LogSumExp( double *restrict Inputs, size_t lenInputs ) {
double biggestInput, result;
size_t i;
{
double *curInput;
double *endInputs = Inputs + lenInputs;
int numinfs = 0;
//Preflight for NANs, infs
for (curInput = Inputs; curInput < endInputs; ++curInput) {
if ( isnan(*curInput) ) {
return NAN;
}
if (isinf(*curInput) && *curInput > 0.0) {
++numinfs;
}
}
if (numinfs > 0) {
return INFINITY;
}
}
if ( lenInputs == 2 ) {
biggestInput = MAX( Inputs[0], Inputs[1] );
result = biggestInput + log1p( exp( MIN(Inputs[0], Inputs[1]) - biggestInput ) );
}
else if ( lenInputs == 3 ) {
double sortedInputs[3];
double middleInput, smallestInput;
double lsebig2;
//Sort the inputs, without disturbing the actual Inputs
memcpy( sortedInputs, Inputs, 3 * sizeof (double) );
qsort( sortedInputs, 3, sizeof (double), doubleComp );
smallestInput = sortedInputs[0];
middleInput = sortedInputs[1];
biggestInput = sortedInputs[2];
lsebig2 = biggestInput + log1p( exp( middleInput - biggestInput ) );
result = lsebig2 + log1p( exp( smallestInput - lsebig2 ) );
}
else if ( lenInputs > 3 ) {
double *restrict sortedInputs, *restrict Sums, *restrict curSum;
size_t lenSums;
double bigpart;
size_t iSum, startpoint, stoppoint;
//Allocate needed memory
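//Halving the remainder each pass consumes ceil(log2(lenInputs)) groups, plus one for the final element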
lenSums = (unsigned int) ceil( log2((double) lenInputs) ) + 1;
sortedInputs = (double *) malloc( lenInputs * sizeof (double) );
Sums = (double *) malloc( lenSums * sizeof (double) );
if ( sortedInputs == NULL || Sums == NULL ) {
fprintf(stderr, "Memory allocation failed in LogSumExp.\n");
abort();
}
//Sort the inputs, without disturbing the actual Inputs
memcpy( sortedInputs, Inputs, lenInputs * sizeof (double) );
qsort( sortedInputs, lenInputs, sizeof (double),
doubleComp );
//Subtract the biggest input to control possible overflow
biggestInput = sortedInputs[lenInputs - 1];
for ( i = 0; i < lenInputs; ++i ) {
sortedInputs[i] -= biggestInput;
}
//Produce the intermediate Sums
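//Each pass sums exp() over the lesser half of what remains (at least one element),
//so the last Sum holds only the biggest input and equals exp(0.0) == 1.0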
stoppoint = 0;
for ( iSum = 0; iSum < lenSums; iSum++ ) {
curSum = &( Sums[iSum] );
*curSum = 0.0;
startpoint = stoppoint;
stoppoint = lenInputs - startpoint;
stoppoint = ( stoppoint >> 1 );
stoppoint = MAX( stoppoint, 1 );
stoppoint += startpoint;
for ( i = startpoint; i < stoppoint; i++ ) {
*curSum += exp( sortedInputs[i] );
}
}
//Digest the Sums into results
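//Telescoping: the log1p terms sum to log( (Sums[0] + ... + Sums[lenSums-1]) / Sums[lenSums-1] ),
//and Sums[lenSums-1] == 1.0, so result is the log of the total shifted sum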
result = 0.0;
for ( iSum = 0; iSum < lenSums - 1; iSum++ ) {
bigpart = 0.0;
for ( i = iSum + 1; i < lenSums; i++ ) {
bigpart += Sums[i];
}
if ( Sums[iSum] > 0.0 ) {
result += log1p( Sums[iSum] / bigpart );
}
}
free( Sums );
free( sortedInputs );
result += biggestInput;
}
else if ( lenInputs == 1 ) {
result = Inputs[0];
}
else {
result = NAN;
}
return result;
}
And a Python implementation of the same (depends on numpy):
from numpy import *
def LogSumExp( inputs ):
    if any( isnan(inputs) ):
        result = float("nan")
    elif any( logical_and( isinf(inputs), inputs > 0.0 ) ):
        result = float("inf")
    elif isinstance( inputs, float ):
        result = inputs
    elif len(inputs) == 1:
        result = inputs[0]
    elif len(inputs) == 2:
        smallval, bigval = ( min(inputs), max(inputs) )
        result = bigval + log1p( exp( smallval - bigval ) )
    elif len(inputs) > 2:
        # Cascade from the smallest elements upward, mirroring the C version:
        # each pass folds the lesser half of what remains into the result
        srtInputs = sort( inputs )  # ascending copy; the caller's inputs are untouched
        bigval = srtInputs[-1]
        srtInputs -= bigval         # shift so the largest exponent is 0
        expInputs = exp( srtInputs )
        result = 0.0
        startpoint = 0
        endpoint = len(inputs) // 2
        while startpoint < len(inputs) - 1:
            smallpart = sum( expInputs[startpoint:endpoint] )
            if smallpart > 0.0:
                result += log1p( smallpart / sum( expInputs[endpoint:] ) )
            startpoint = endpoint
            endpoint = len(inputs) - startpoint
            endpoint = endpoint // 2
            endpoint = max( endpoint, 1 )
            endpoint += startpoint
        result += bigval
    else:
        result = float("nan")  # e.g. empty input
    return result
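For a quick sanity check (my addition, not from the original), the cascaded version can be compared against the plain single-shift formula on a vector with many non-dominant elements:

import numpy as np

rng = np.random.default_rng( 0 )
v = rng.normal( loc=-500.0, scale=1.0, size=10000 )  # large vector, no dominant element
single_shift = v.max() + np.log( np.sum( np.exp( v - v.max() ) ) )
print( LogSumExp( v ), single_shift )

The two results should agree to near machine precision; any difference sits in the trailing bits, which is exactly where the cascaded summation is meant to help.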