diff --git a/src/basicfuns.jl b/src/basicfuns.jl index 5c39104..4a56577 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -190,7 +190,13 @@ function logsumexp(X) isempty(X) && return log(sum(X)) reduce(logaddexp, X) end -function logsumexp(X::AbstractArray{T}) where {T<:Real} +function logsumexp(X::AbstractArray{T}; dims=nothing) where {T<:Real} + dims isa Nothing && return _logsumexp(X) + isempty(X) && return log(zero(T)) + u = maximum(X; dims=dims) + return log.(sum(exp.(X .- u); dims=dims)) .+ u +end +function _logsumexp(X::AbstractArray{T}) where {T<:Real} isempty(X) && return log(zero(T)) u = maximum(X) isfinite(u) || return float(u) @@ -199,7 +205,6 @@ function logsumexp(X::AbstractArray{T}) where {T<:Real} end end - """ softmax!(r::AbstractArray, x::AbstractArray) diff --git a/test/basicfuns.jl b/test/basicfuns.jl index 02e00f4..a2da1f7 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -89,6 +89,9 @@ end @test logsumexp(arguments) ≡ result end end + + x = randn(5, 1) + @test logsumexp(x) == logsumexp(x; dims=1)[1] end @testset "softmax" begin