Developing a Single-Pass Weighted LogSumExp Function
Recently I’ve been thinking about the LogSumExp trick since it is used in the integration step of nested sampling. I won’t go over too much of the math here, but the reason this trick exists is to greatly increase the numerical stability of the operation
$$ \log \sum_i \exp x_i $$
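via the identity
$$ \log \sum_i \exp x_i = a + \log \sum_i \exp\left(x_i - a\right) $$
where a is usually taken to be the maximum of the x_i, so that every exponent is at most zero and exp cannot overflow.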
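The identity holds for any finite shift a, which a short derivation makes explicit. Factoring exp(a) out of the sum gives

$$ \log \sum_i \exp x_i = \log\left( \exp(a) \sum_i \exp\left(x_i - a\right) \right) = a + \log \sum_i \exp\left(x_i - a\right) $$

so the shift changes only the size of the intermediate terms, not the result.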
Naive implementations
In Julia we can implement a naive logsumexp with
logsumexp_naive(X) = log(sum(exp, X))
Let’s test the numerical accuracy against Julia’s BigFloat for some very large numbers
using Random
rng = Random.seed!(55215)
X = 1000 .* rand(rng, 100)
logsumexp_naive(X)
# output
Inf
logsumexp_naive(big.(X))
# output
992.4854574035180795285906086509468594208066712019734540787001118090835360846258
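The Inf above comes from Float64 overflow inside exp, not from the sum itself. As a small illustration (the variable names here are mine, not part of the original code), this sketch shows where exp stops being representable and how shifting by the maximum avoids it:

```julia
# Float64 tops out near 1.8e308, so exp(x) overflows once x exceeds
# log(floatmax(Float64)), which lies between 709 and 710.
threshold = log(floatmax(Float64))

finite_case = exp(709.0)     # still representable
overflow_case = exp(710.0)   # overflows to Inf

# Shifting by the maximum keeps every exponent at or below zero,
# so each exp call returns a value in [0, 1].
X_demo = [1000.0, 999.0, 998.0]
shifted = X_demo .- maximum(X_demo)
stable = maximum(X_demo) + log(sum(exp, shifted))

println((threshold, finite_case, overflow_case, stable))
```

Since X is drawn from 1000 .* rand, many of its entries sit above this threshold, and a single one is enough to turn the naive sum into Inf.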
Now let’s compare that to a version using the shift, which requires two passes through the collection X: first to find the maximum, then to accumulate the sum
function logsumexp_twopass(X)
    a = maximum(X)
    return a + log(sum(x -> exp(x - a), X))
end
logsumexp_twopass(X)
# output
992.4854574035181
Let’s use BenchmarkTools.jl to do some timing tests
using BenchmarkTools
@btime logsumexp_naive($X)
# output
437.848 ns (0 allocations: 0 bytes)
Inf
@btime logsumexp_twopass($X)
# output
912.282 ns (0 allocations: 0 bytes)
992.4854574035181
We can see that the extra pass over the collection roughly doubles our runtime (about 912 ns versus 438 ns).
Single-pass (streaming) implementation
From “Streaming Log-sum-exp Computation” by Sebastian Nowozin, we can actually find the maximum and accumulate the sum in a single pass through the collection.
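The single-pass update can be motivated with a short derivation. This is a sketch of the reasoning in my own notation, not a quotation from the paper. Suppose that after k elements we hold the running maximum and the correspondingly shifted sum:

$$ a_k = \max_{i \le k} x_i, \qquad r_k = \sum_{i \le k} \exp\left(x_i - a_k\right) $$

If the next element satisfies $x_{k+1} \le a_k$, the maximum is unchanged and we simply add one more term:

$$ r_{k+1} = r_k + \exp\left(x_{k+1} - a_k\right) $$

Otherwise $x_{k+1}$ becomes the new maximum and the old sum has to be rescaled to the new shift:

$$ r_{k+1} = \sum_{i \le k} \exp\left(x_i - x_{k+1}\right) + 1 = r_k \exp\left(a_k - x_{k+1}\right) + 1 $$

In both cases the final answer after n elements is $a_n + \log r_n$.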
function logsumexp_onepass(X)
    a = -Inf
    r = zero(eltype(X))
    for x in X
        if x ≤ a
            # standard computation
            r += exp(x - a)
        else
            # if new value is higher than current max
            r *= exp(a - x)
            r += one(x)
            a = x
        end
    end
    return a + log(r)
end
@btime logsumexp_onepass($X)
# output
632.657 ns (0 allocations: 0 bytes)
992.4854574035181
So it is not quite as fast as the naive implementation, but it is still faster than the two-pass version.
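One edge case is worth knowing about if the inputs are log-probabilities: an element equal to -Inf that arrives before any finite value makes the one-pass loop evaluate exp(-Inf - (-Inf)), which is NaN. The following is a minimal sketch of a guarded variant. The name logsumexp_onepass_guarded is hypothetical, and skipping -Inf entries is just one way to handle this.

```julia
# One-pass kernel as written in the post.
function logsumexp_onepass(X)
    a = -Inf
    r = zero(eltype(X))
    for x in X
        if x ≤ a
            r += exp(x - a)
        else
            r *= exp(a - x)
            r += one(x)
            a = x
        end
    end
    return a + log(r)
end

# Hypothetical guarded variant: an element equal to -Inf contributes
# exp(-Inf) = 0 to the sum, so it can be skipped outright.
function logsumexp_onepass_guarded(X)
    a = -Inf
    r = zero(eltype(X))
    for x in X
        if x == -Inf
            continue
        elseif x ≤ a
            r += exp(x - a)
        else
            r *= exp(a - x)
            r += one(x)
            a = x
        end
    end
    return a + log(r)
end

X_demo = [-Inf, 0.0]
unguarded = logsumexp_onepass(X_demo)         # NaN, from exp(-Inf - (-Inf))
guarded = logsumexp_onepass_guarded(X_demo)   # log(exp(-Inf) + exp(0)) = 0
println((unguarded, guarded))
```

The extra comparison adds a branch to the hot loop, so it may cost some of the speed measured above. I have not benchmarked it.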
Extending to weighted sum
The SciPy implementation of logsumexp extends the algorithm to a weighted sum (through its b argument):
$$ \log \sum_i{ w_i \exp x_i} $$
This is straightforward enough to implement using our single pass algorithm above
function logsumexp_onepass(X, w)
    a = -Inf
    r = zero(eltype(X))
    for (x, wi) in zip(X, w)
        if x ≤ a
            # standard computation
            r += wi * exp(x - a)
        else
            # if new value is higher than current max
            r *= exp(a - x)
            r += wi
            a = x
        end
    end
    return a + log(r)
end
# when w = 1 it should be equivalent to logsumexp
logsumexp_onepass(X, ones(length(X)))
# output
992.4854574035181
w = rand(rng, length(X))
@btime logsumexp_onepass($X, $w)
# output
777.162 ns (0 allocations: 0 bytes)
991.6805331462472
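A quick way to sanity-check the weighted version is to note that, for strictly positive weights, w_i exp(x_i) = exp(x_i + log w_i), so the weights can be folded into the exponents and passed to an unweighted two-pass computation. The sketch below does that. The name logsumexp_weighted_reference and the demo values are mine, and the reference assumes every weight is strictly positive.

```julia
# Weighted one-pass version from the post.
function logsumexp_onepass(X, w)
    a = -Inf
    r = zero(eltype(X))
    for (x, wi) in zip(X, w)
        if x ≤ a
            r += wi * exp(x - a)
        else
            r *= exp(a - x)
            r += wi
            a = x
        end
    end
    return a + log(r)
end

# Hypothetical reference: fold the weights into the exponents, then apply
# the usual two-pass shift. Assumes every weight is strictly positive.
function logsumexp_weighted_reference(X, w)
    Y = X .+ log.(w)
    a = maximum(Y)
    return a + log(sum(y -> exp(y - a), Y))
end

X_demo = [1000.0, 998.5, 990.0, 750.0]
w_demo = [0.2, 0.5, 0.1, 0.9]

onepass = logsumexp_onepass(X_demo, w_demo)
reference = logsumexp_weighted_reference(X_demo, w_demo)
println(isapprox(onepass, reference; rtol=1e-12))
```

The reference allocates a temporary array and calls log on every weight, so it is meant as a correctness check rather than a competitor on speed.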
What’s missing?
While the weighted logsumexp function above works for many different types of iterators and arrays, it doesn’t support reducing over arbitrary dimensions in a multi-dimensional array. A two-pass, broadcasting version does:
function logsumexp_twopass(X, w; dims)
    a = maximum(X; dims=dims)
    r = sum(w .* exp.(X .- a); dims=dims)
    return a .+ log.(r)
end
logsumexp_twopass([X X], [w w]; dims=1)
# output
1×2 Matrix{Float64}:
991.681 991.681
Unfortunately, this method is quite slow
@btime logsumexp_twopass($([X X]), $([w w]); dims=1)
# output
2.969 μs (7 allocations: 2.16 KiB)
1×2 Matrix{Float64}:
991.681 991.681
This is roughly 2.3 times slower (2.969 μs versus 1.310 μs) than the multi-dimensional array implementation of logsumexp in LogExpFunctions.jl, which does not support weights
using LogExpFunctions
@btime LogExpFunctions.logsumexp($([X X]); dims=1)
# output
1.310 μs (2 allocations: 208 bytes)
1×2 Matrix{Float64}:
992.485 992.485
I leave an open challenge to any readers who can write a weighted logsumexp that matches the quality of the logsumexp implementation in LogExpFunctions.jl. The biggest obstacle to transcribing that implementation is that LogExpFunctions.jl leverages reduce, which doesn’t (currently) take multiple array arguments. Neither does mapslices, so a custom loop would have to be written over the array indices, as far as I can tell.
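As one possible starting point for such a custom loop, here is a minimal sketch that reduces over the first dimension of a matrix by pairing up the columns of X and w and feeding each pair to the weighted one-pass kernel. The helper name logsumexp_cols is hypothetical, it covers only the dims=1 matrix case, and I have not benchmarked it against LogExpFunctions.jl.

```julia
# Weighted one-pass kernel from the post.
function logsumexp_onepass(X, w)
    a = -Inf
    r = zero(eltype(X))
    for (x, wi) in zip(X, w)
        if x ≤ a
            r += wi * exp(x - a)
        else
            r *= exp(a - x)
            r += wi
            a = x
        end
    end
    return a + log(r)
end

# Hypothetical helper: weighted logsumexp over the first dimension of a matrix.
# eachcol yields column views, so the inputs are not copied.
function logsumexp_cols(X::AbstractMatrix, w::AbstractMatrix)
    size(X) == size(w) || throw(DimensionMismatch("X and w must have the same size"))
    out = [logsumexp_onepass(x, wi) for (x, wi) in zip(eachcol(X), eachcol(w))]
    return reshape(out, 1, :)   # mimic the 1×n shape produced by dims=1
end

X_demo = [1000.0 10.0; 999.0 20.0]
w_demo = [0.5 1.0; 0.5 2.0]
result = logsumexp_cols(X_demo, w_demo)
println(result)
```

Generalizing this to arbitrary dims would mean iterating over matching slices of both arrays (for example with eachslice), which is where most of the remaining work lies.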