
The logsumexp algorithm is a technique for calculating the following expression:

log( sum( exp( v ) ) ),

where the exp function is applied element-wise to the vector v. The typical implementation of the algorithm is the straightforward identity:

max(v) + log( sum( exp( v - max(v) ) ) ),

where max(v) is the largest element in v. Subtracting max(v) keeps every exp from overflowing. The gain in numerical accuracy comes from being able to use log1p on the sum over the elements other than the maximum when sum( exp( v - max(v) ) ) is between 1 and 2. The problem arises when v has many elements and none of them dominates: the sum is then much larger than 2, and log1p no longer helps. This inspired me to write the following recursive algorithm for logsumexp:

logsumexp(v) = logsumexp(greaterhalf(v)) + log1p( sum( exp( lesserhalf(v) - logsumexp(greaterhalf(v)) ) ) ),

with greaterhalf(v) returning the larger half of the elements of v, rounded up (so it contains ceil(n/2) of the n elements), and lesserhalf(v) returning the elements greaterhalf(v) does not. The recursion terminates when v has only one element, in which case logsumexp(v) = v. There are, of course, many possible optimizations, including sorting the elements of v once, caching exp(v), and replacing the recursion with a loop.
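
A direct, unoptimized transcription of the recursive definition may help make it concrete. The sketch below uses only the Python standard library; the function names are hypothetical, it assumes finite inputs, and greaterhalf(v) is taken to be the ceil(n/2) largest elements after sorting. The max-shifted identity is included alongside it for comparison.

```python
import math


def logsumexp_maxshift(v):
    """max(v) + log1p( sum( exp(x - max(v)) ) ) over all but one copy of max(v)."""
    s = sorted(v)
    m = s[-1]
    return m + math.log1p(sum(math.exp(x - m) for x in s[:-1]))


def logsumexp_recursive(v):
    """Recursive definition with no optimizations (finite inputs assumed)."""
    if len(v) == 1:
        return float(v[0])          # base case: logsumexp of one element
    s = sorted(v)                   # ascending order
    n_lesser = len(s) // 2          # lesserhalf gets floor(n/2) elements
    lesser, greater = s[:n_lesser], s[n_lesser:]
    lse_greater = logsumexp_recursive(greater)
    # log(A + B) = log(A) + log1p(B / A), where
    # B / A = sum( exp(x - log(A)) ) over the lesser half
    return lse_greater + math.log1p(
        sum(math.exp(x - lse_greater) for x in lesser))
```

For example, logsumexp_recursive([1000.0, 1000.0]) gives about 1000.693 (1000 + log 2), whereas the unshifted math.log(sum(math.exp(x) for x in v)) raises OverflowError on the same input. This version re-sorts at every level; sorting once, as suggested above, avoids that cost.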

Is it possible to demonstrate that this recursion-based algorithm, which admits many optimizations, is numerically optimal in some sense? Specifically, that it minimizes rounding errors, and that it underflows or overflows only when the arbitrary-precision result would underflow or overflow upon conversion to finite precision at the end?
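
The second part can at least be probed empirically, though not proven, by comparing an implementation against a high-precision reference. A minimal sketch, assuming Python's standard decimal module stands in for the arbitrary-precision calculation (60 significant digits is an arbitrary choice) and using hypothetical helper names, rounds the reference to a double only at the end and reports the error in units in the last place (ulps):

```python
import math
from decimal import Decimal, localcontext


def logsumexp_reference(v, digits=60):
    """log(sum(exp(v))) in high precision, rounded to a double only at the end."""
    with localcontext() as ctx:
        ctx.prec = digits
        # Decimal(x) converts a float exactly; exp, + and ln use ctx.prec digits
        total = sum((Decimal(x).exp() for x in v), Decimal(0))
        return float(total.ln())


def ulp_error(approx, v):
    """Distance from approx to the reference, in units in the last place."""
    ref = logsumexp_reference(v)
    return abs(approx - ref) / math.ulp(ref)
```

Tabulating ulp_error over many random inputs for each candidate implementation would show whether the recursive scheme beats the max-shifted one in practice. Inputs of very large magnitude can exceed decimal's default exponent range and raise decimal.Overflow.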

Here is a concrete implementation in C that makes a performance-increasing trade-off, computing logsumexp(v) = logsumexp(greaterhalf(v)) + log1p( sum(exp(lesserhalf(v))) / sum(exp(greaterhalf(v))) ):

#include <math.h>
#include <string.h>
#include <stdlib.h>
#include <stdio.h>

//Arguments may be evaluated twice; pass only side-effect-free expressions
#define MAX(a, b) ( (a) > (b) ? (a) : (b) )
#define MIN(a, b) ( (a) <= (b) ? (a) : (b) )

//Ascending comparison for qsort (smallest element first)
int doubleComp( const void *p1, const void *p2 ) {
    const double *d1 = (const double *)p1;
    const double *d2 = (const double *)p2;

    if (*d1 < *d2) return -1;
    if (*d1 > *d2) return 1;
    if (*d1 == *d2) return 0;
    return 1;//Unordered (NAN) comparison; never reached, NANs are rejected before sorting
}

double LogSumExp( double *restrict Inputs, size_t lenInputs ) {
    double biggestInput, result;
    size_t i;

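    {
        double *curInput;
        double *endInputs = Inputs + lenInputs;
        int hasPosInf = 0;
        int allNegInf = 1;

        //Preflight for NANs and infinities
        for (curInput = Inputs; curInput < endInputs; ++curInput) {
            if ( isnan(*curInput) ) {
                return NAN;
            }
            if ( isinf(*curInput) && *curInput > 0.0 ) {
                hasPosInf = 1;
            }
            if ( !(isinf(*curInput) && *curInput < 0.0) ) {
                allNegInf = 0;
            }
        }

        if ( hasPosInf ) {
            return INFINITY;
        }
        //Every exp() term is 0, so the log of the sum is -inf; without this
        //check, -inf - (-inf) below would produce NAN
        if ( lenInputs > 0 && allNegInf ) {
            return -INFINITY;
        }
    }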

    if ( lenInputs == 2 ) {
        biggestInput = MAX( Inputs[0], Inputs[1] );
        result = biggestInput + log1p( exp( MIN(Inputs[0], Inputs[1]) - biggestInput ) );
    }

    else if ( lenInputs == 3 ) {
        double sortedInputs[3];
        double middleInput, smallestInput;
        double lsebig2;
    
        //Sort the inputs, without disturbing the actual Inputs
        memcpy( sortedInputs, Inputs, 3 * sizeof (double) );
        qsort( sortedInputs, 3, sizeof (double), doubleComp );
    
        smallestInput = sortedInputs[0];
        middleInput = sortedInputs[1];
        biggestInput = sortedInputs[2];
    
        lsebig2 = biggestInput + log1p( exp( middleInput - biggestInput ) );
    
        result = lsebig2 + log1p( exp( smallestInput - lsebig2 ) );
    
    }

    else if ( lenInputs > 3 ) {
        double *restrict sortedInputs, *restrict Sums, *restrict curSum;
        size_t lenSums;
        double bigpart;
        size_t iSum, startpoint, stoppoint;

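        //Allocate needed memory: one partial sum per chunk, and halving the
        //remaining inputs each time gives ceil(log2(lenInputs)) + 1 chunks
        lenSums = (size_t) ceil( log2((double) lenInputs) ) + 1;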
        sortedInputs = (double *) malloc( lenInputs * sizeof (double) );
        Sums = (double *) malloc( lenSums * sizeof (double) );
        if ( sortedInputs == NULL || Sums == NULL ) {
            fprintf(stderr, "Memory allocation failed in LogSumExp.\n");
            abort();
        }

        //Sort the inputs, without disturbing the actual Inputs
        memcpy( sortedInputs, Inputs, lenInputs * sizeof (double) );
        qsort( sortedInputs, lenInputs, sizeof (double),
            doubleComp );

        //Subtract the biggest input to control possible overflow
        biggestInput = sortedInputs[lenInputs - 1];
        for ( i = 0; i < lenInputs; ++i ) {
            sortedInputs[i] -= biggestInput;
        }

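        //Produce the intermediate Sums: each chunk is the smaller half (at
        //least one element) of the inputs not yet consumed, so the last
        //chunk is the biggest input alone and its sum is exp(0) == 1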
        stoppoint = 0;
        for ( iSum = 0; iSum < lenSums; iSum++ ) {
            curSum = &( Sums[iSum] );
            *curSum = 0.0;

            startpoint = stoppoint;
            stoppoint = lenInputs - startpoint;
            stoppoint = ( stoppoint >> 1 );
            stoppoint = MAX( stoppoint, 1 );
            stoppoint += startpoint;

            for ( i = startpoint; i < stoppoint; i++ ) {
                *curSum += exp( sortedInputs[i] );
            }
        }

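        //Digest the Sums: log1p( Sums[j] / (Sums[j+1] + ... + Sums[last]) )
        //summed over j telescopes to log( total / Sums[last] ) = log( total )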
        result = 0.0;
        for ( iSum = 0; iSum < lenSums - 1; iSum++ ) {
            bigpart = 0.0;
            for ( i = iSum + 1; i < lenSums; i++ ) {
                bigpart += Sums[i];
            }

            if ( Sums[iSum] > 0.0 ) {
                result += log1p( Sums[iSum] / bigpart );
            }
        }

        free( Sums );
        free( sortedInputs );

        result += biggestInput;
    }

    else if ( lenInputs == 1 ) {
        result = Inputs[0];
    }

    else {
        result = NAN;
    }

    return result;
}
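
In the C version, the size of Sums relies on the chunking loop producing exactly ceil(log2(lenInputs)) + 1 chunks. That can be checked exhaustively for small sizes with a Python mirror of the startpoint/stoppoint loop; the helper names below are hypothetical, and the check is a sanity test rather than a proof:

```python
import math


def chunk_bounds(n):
    """[start, stop) ranges produced by the startpoint/stoppoint loop in LogSumExp."""
    bounds = []
    stop = 0
    while stop < n:
        start = stop
        # the smaller half of what remains, but never an empty chunk
        stop = start + max((n - start) // 2, 1)
        bounds.append((start, stop))
    return bounds


def lensums_c(n):
    """The C allocation: (size_t) ceil( log2((double) n) ) + 1."""
    return math.ceil(math.log2(n)) + 1


for n in range(4, 5000):
    b = chunk_bounds(n)
    assert len(b) == lensums_c(n) == (n - 1).bit_length() + 1
    assert b[-1] == (n - 1, n)   # last chunk is the biggest input alone
```

The integer form (n - 1).bit_length() + 1 gives the same count without relying on floating-point log2.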

And a Python implementation of the same algorithm (depends on NumPy):

import numpy as np

def LogSumExp( inputs ):
    #Accept scalars, lists and arrays; work in double precision
    inputs = np.atleast_1d( np.asarray( inputs, dtype=float ) )

    if np.any( np.isnan(inputs) ):
        result = float("nan")

    elif np.any( np.isposinf(inputs) ):
        result = float("inf")

    elif len(inputs) > 0 and np.all( np.isneginf(inputs) ):
        #Every exp() term is 0, so the log of their sum is -inf
        result = float("-inf")

    elif len(inputs) == 1:
        result = float( inputs[0] )

    elif len(inputs) == 2:
        smallval, bigval = ( inputs.min(), inputs.max() )
        result = bigval + np.log1p( np.exp( smallval - bigval ) )

    elif len(inputs) > 2:
        srtInputs = np.sort( inputs )    #ascending copy
        bigval = srtInputs[-1]
        srtInputs -= bigval              #largest term becomes exp(0) == 1
        expInputs = np.exp( srtInputs )

        result = 0.0
        startpoint = 0
        endpoint = len(inputs) // 2
        #Each chunk is the smaller half of the terms not yet consumed
        while startpoint < len(inputs) - 1:
            smallpart = np.sum( expInputs[startpoint:endpoint] )
            if smallpart > 0.0:
                result += np.log1p( smallpart / \
                                    np.sum( expInputs[endpoint:] ) )

            startpoint = endpoint
            endpoint = len(inputs) - startpoint
            endpoint = endpoint // 2
            endpoint = max( endpoint, 1 )
            endpoint += startpoint

        result += bigval

    else:
        result = float("nan")

    return float(result)
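
Both implementations finish by adding up log1p terms over the partial sums. A minimal sketch (the helper name digest is hypothetical) shows why this recovers the logsumexp: each term equals log((S_j + ... + S_k) / (S_{j+1} + ... + S_k)), so the terms telescope to log(total / S_k), and S_k = exp(0) = 1 because the last chunk holds only the shifted maximum.

```python
import math


def digest(sums):
    """Sum of log1p( S_j / (S_{j+1} + ... + S_k) ) over j < k, skipping zero S_j."""
    result = 0.0
    for j in range(len(sums) - 1):
        if sums[j] > 0.0:
            result += math.log1p(sums[j] / sum(sums[j + 1:]))
    return result


# With S_k == 1 (the shifted maximum), the result equals log(S_0 + ... + S_k)
sums = [0.25, 0.5, 0.75, 1.0]
print(digest(sums), math.log(sum(sums)))
```

Skipping a zero S_j loses nothing, since log1p(0) is 0.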
Sean Lake
    I can't really help you to formally analyze your algorithm, but I do have some comments. There already exist implementations of `logsumexp` in [Python](http://docs.scipy.org/doc/scipy/reference/generated/scipy.misc.logsumexp.html) and in [C](https://github.com/rmcgibbo/logsumexp/blob/master/src/logsumexp.c). A good start would be to compare the runtime, precision, and memory usage of your implementation against those, using some known test cases. – Joey Dumont Sep 20 '16 at 18:43
