Developing a Single-Pass Weighted LogSumExp Function


Recently I’ve been thinking about the LogSumExp trick since it is used in the integration step of nested sampling. I won’t go over too much of the math here, but the reason this trick exists is to greatly increase the numerical stability of the operation

$$ \log \sum_i \exp x_i $$

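via the identity

$$ \log \sum_i \exp x_i = a + \log \sum_i \exp\left(x_i - a\right) $$

which holds for any constant a; in practice a is taken to be the largest element of the collection.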
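For readers who want the missing step, here is a short derivation. It is standard algebra rather than anything specific to nested sampling: factor exp(a) out of the sum and split the logarithm.

$$ \log \sum_i \exp x_i = \log\left( e^{a} \sum_i \exp\left(x_i - a\right) \right) = a + \log \sum_i \exp\left(x_i - a\right) $$

With a set to the maximum, every exponent is non-positive, so no term can overflow. The largest term is exactly 1, so the sum cannot underflow to zero either.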

Naive implementations

In Julia we can implement a naive logsumexp with

logsumexp_naive(X) = log(sum(exp, X))

Let’s test the numerical accuracy against Julia’s BigFloat for some very large numbers:

using Random
rng = Random.seed!(55215)
X = 1000 .* rand(rng, 100)


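logsumexp_naive(X)        # Float64: exp overflows for inputs above about 709.78, giving Inf
logsumexp_naive(big.(X))  # BigFloat reference: stays finite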
Now let’s compare that to a version using the shift, which requires 2 passes through the collection X, first to find the maximum and again to accumulate the sum

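function logsumexp_twopass(X)
    # first pass: find the shift
    a = maximum(X)
    # second pass: accumulate the shifted sum
    return a + log(sum(x -> exp(x - a), X))
end

logsumexp_twopass(X)  # stays finite because every shifted exponent is ≤ 0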

Let’s use BenchmarkTools.jl to do some timing tests

using BenchmarkTools
@btime logsumexp_naive($X)

# output
  437.848 ns (0 allocations: 0 bytes)
@btime logsumexp_twopass($X)

# output
  912.282 ns (0 allocations: 0 bytes)

We can see the extra pass over the collection roughly doubles our runtime.

Single-pass (streaming) implementation

Following “Streaming Log-sum-exp Computation” by Sebastian Nowozin, we can both find the maximum and accumulate the sum in a single pass through the collection.

function logsumexp_onepass(X)
    a = -Inf
    r = zero(eltype(X))
    for x in X
        if x ≤ a
            # standard computation
            r += exp(x - a)
        else
            # if new value is higher than current max,
            # rescale the running sum to the new max
            r *= exp(a - x)
            r += one(x)
            a = x
        end
    end
    return a + log(r)
end

@btime logsumexp_onepass($X)

# output
  632.657 ns (0 allocations: 0 bytes)

So it is not quite as fast as the naive implementation, but still faster than the two-pass version.
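To see why the rescaling branch is correct, it may help to write down the invariant the loop maintains. The subscripted notation is mine, introduced only for this sketch. After the first k elements have been processed, the running maximum and running sum are

$$ a_k = \max_{j \le k} x_j, \qquad r_k = \sum_{j \le k} \exp\left(x_j - a_k\right) $$

and the next element updates the sum as

$$ r_{k+1} = \begin{cases} r_k + \exp\left(x_{k+1} - a_k\right), & x_{k+1} \le a_k \\ r_k \exp\left(a_k - x_{k+1}\right) + 1, & x_{k+1} > a_k \end{cases} $$

In the second case the maximum becomes the new element. Multiplying by the exponential factor re-expresses every earlier term relative to the new maximum, and the added 1 is the new element’s own term, exp(0).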

Extending to weighted sum

SciPy’s implementation of logsumexp (scipy.special.logsumexp) supports a weighted extension of the algorithm through its b argument:

$$ \log \sum_i{ w_i \exp x_i} $$

This is straightforward enough to implement using our single-pass algorithm above:

function logsumexp_onepass(X, w)
    a = -Inf
    r = zero(eltype(X))
    for (x, wi) in zip(X, w)
        if x ≤ a
            # standard computation
            r += wi * exp(x - a)
        else
            # if new value is higher than current max,
            # rescale the running sum to the new max
            r *= exp(a - x)
            r += wi
            a = x
        end
    end
    return a + log(r)
end

# when w = 1 it should be equivalent to logsumexp
logsumexp_onepass(X, ones(length(X)))

w = rand(rng, length(X))
@btime logsumexp_onepass($X, $w)

# output
  777.162 ns (0 allocations: 0 bytes)
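As a sanity check on the weighted version, here is a minimal sketch that compares it against a direct BigFloat evaluation of the weighted sum. The helper name logsumexp_reference, the MersenneTwister generator, and the seed are my own choices for illustration, and the kernel is restated so the snippet runs on its own.

```julia
using Random

# Weighted one-pass kernel, restated so this snippet is self-contained
function logsumexp_onepass(X, w)
    a = -Inf
    r = zero(eltype(X))
    for (x, wi) in zip(X, w)
        if x ≤ a
            r += wi * exp(x - a)
        else
            # new running maximum: rescale the sum, then add the new term
            r *= exp(a - x)
            r += wi
            a = x
        end
    end
    return a + log(r)
end

# Hypothetical reference: evaluate the weighted sum directly in BigFloat,
# whose exponent range is wide enough that exp does not overflow here
logsumexp_reference(X, w) = log(sum(big.(w) .* exp.(big.(X))))

rng = MersenneTwister(1234)  # arbitrary seed, chosen for this sketch
X = 1000 .* rand(rng, 100)
w = rand(rng, 100)

# absolute difference between the Float64 result and the BigFloat reference
err = Float64(abs(logsumexp_onepass(X, w) - logsumexp_reference(X, w)))
println(err)
```

The difference should sit near Float64 rounding error for values of this size. A much larger value would point to a bug in the rescaling branch.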

What’s missing?

While the weighted logsumexp function above works for many different types of iterators and arrays, it doesn’t support reducing over arbitrary dimensions of a multi-dimensional array. A two-pass version built on broadcasting does:

function logsumexp_twopass(X, w; dims)
    # per-slice maximum, kept as an array so it broadcasts against X
    a = maximum(X; dims=dims)
    r = sum(w .* exp.(X .- a); dims=dims)
    return a .+ log.(r)
end

logsumexp_twopass([X X], [w w]; dims=1)

# output
1×2 Matrix{Float64}:
 991.681  991.681

Unfortunately, this method is quite slow:

@btime logsumexp_twopass($([X X]), $([w w]); dims=1)

# output
  2.969 μs (7 allocations: 2.16 KiB)
1×2 Matrix{Float64}:
 991.681  991.681

This is more than twice as slow (roughly 2.3×) as the multi-dimensional array implementation of logsumexp in LogExpFunctions.jl, which does not support weights:

using LogExpFunctions
@btime LogExpFunctions.logsumexp($([X X]); dims=1)

# output
  1.310 μs (2 allocations: 208 bytes)
1×2 Matrix{Float64}:
 992.485  992.485

I leave an open challenge to any readers who can get an implementation for a weighted logsumexp that matches the quality of the logsumexp implementation in LogExpFunctions.jl. The biggest issue with trying to transcribe that implementation is that LogExpFunctions.jl leverages reduce, which doesn’t (currently) take multiple array arguments. Neither does mapslices, so a custom loop would have to be written over the array indices, as far as I can tell.
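As a starting point for that custom loop, here is a rough sketch restricted to matrices and dims=1. The helper logsumexp_cols is hypothetical (it is not part of LogExpFunctions.jl), it simply applies the one-pass kernel to each pair of column views, and I have not benchmarked it, so it is not a claim to have met the challenge.

```julia
# Weighted one-pass kernel, restated so this snippet is self-contained
function logsumexp_onepass(X, w)
    a = -Inf
    r = zero(eltype(X))
    for (x, wi) in zip(X, w)
        if x ≤ a
            r += wi * exp(x - a)
        else
            # new running maximum: rescale the sum, then add the new term
            r *= exp(a - x)
            r += wi
            a = x
        end
    end
    return a + log(r)
end

# Hypothetical helper: weighted logsumexp over dims=1 of a matrix,
# walking the columns of X and w together with non-allocating views
function logsumexp_cols(X::AbstractMatrix, w::AbstractMatrix)
    size(X) == size(w) || throw(DimensionMismatch("X and w must have the same size"))
    out = Matrix{float(eltype(X))}(undef, 1, size(X, 2))
    for (k, j) in enumerate(axes(X, 2))
        out[1, k] = logsumexp_onepass(view(X, :, j), view(w, :, j))
    end
    return out
end

X = [0.0 1.0; 0.0 1.0]
w = [1.0 2.0; 1.0 2.0]
result = logsumexp_cols(X, w)  # 1×2 matrix, one entry per column
```

Supporting arbitrary dims would mean replacing the column loop with a loop over the indices of the non-reduced dimensions, which is where most of the remaining work lies.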