
log1pmx and logmxp1 cannot be differentiated by ForwardDiff #44

Closed
simsurace opened this issue Apr 27, 2022 · 7 comments · Fixed by #45 or JuliaDiff/DiffRules.jl#82

Comments

@simsurace
Contributor

using LogExpFunctions

using ForwardDiff
using Zygote
using Test

log1mexpm(x) = log1mexp(-x)
log2mexpm(x) = log2mexp(-x)

function test_ad_vs_zygote(AD)
    @testset verbose = true "LogExpFunctions" begin
        @testset "Single-argument functions" begin
            @testset "$f" for f in (                      
                log1pmx,
                logexpm1,
                logsumexp,
                xexpx,
                log1psq,
                logistic,
                invsoftplus,
                log2mexpm,
                logit,
                log1mexpm,
                logmxp1,
                xlogx,
                log1pexp,
                logcosh
            )
                for _ in 1:100
                    par = rand()
                    @test AD.derivative(f, par) ≈ only(only(Zygote.gradient(f ∘ only, [par])))
                end
            end
        end
        @testset "Two-argument functions" begin
            @testset "$f" for f in (                      
                xexpy,
                xlogy,
                xlog1py,
                logaddexp,
                logsubexp,
            )
                for _ in 1:100
                    par = rand(2)
                    @test AD.gradient(x -> f(x...), par) ≈ only(Zygote.gradient(x -> f(x...), par))
                end
            end
        end
        @testset "Vector-argument functions" begin
            @testset "$f" for f in (
                softmax,
            )
                for _ in 1:100
                    par = rand(3)
                    @test AD.jacobian(f, par) ≈ only(Zygote.jacobian(f, par))
                end
            end
        end
    end
    return nothing
end

test_ad_vs_zygote(ForwardDiff)

yields

Test Summary:               | Pass  Error  Total
LogExpFunctions             | 1800    200   2000
  Single-argument functions | 1200    200   1400
    log1pmx                 |         100    100
    logexpm1                |  100           100
    logsumexp               |  100           100
    xexpx                   |  100           100
    log1psq                 |  100           100
    logistic                |  100           100
    logexpm1                |  100           100
    log2mexpm               |  100           100
    logit                   |  100           100
    log1mexpm               |  100           100
    logmxp1                 |         100    100
    xlogx                   |  100           100
    log1pexp                |  100           100
    logcosh                 |  100           100
  Two-argument functions    |  500           500
  Vector-argument functions |  100           100
ERROR: Some tests did not pass: 1800 passed, 0 failed, 200 errored, 0 broken.
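The errors arise because `log1pmx` and `logmxp1` lack methods that accept ForwardDiff's `Dual` numbers. A minimal, hypothetical sketch of the kind of generic `::Real` fallback that #45 adds (the merged PR uses a less naive heuristic; the helper names here are illustrative, not the package's actual code):

```julia
# Hypothetical naive fallbacks: because ForwardDiff.Dual <: Real, methods
# like these make both functions differentiable by ForwardDiff.
naive_log1pmx(x::Real) = log1p(x) - x     # log(1 + x) - x
naive_logmxp1(x::Real) = log(x) - x + 1   # log(x) - x + 1

naive_log1pmx(0.5)   # == log(1.5) - 0.5
```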
simsurace added a commit to simsurace/LogExpFunctions.jl that referenced this issue Apr 27, 2022
This looks like it would make sense, given that the other functions dispatch on `Real` argument, and it also fixes JuliaStats#44.
@devmotion
Member

I assume they are missing in DiffRules?

@simsurace
Contributor Author

#45 is sufficient to fix this.

@devmotion
Member

It will not use the optimized Float64 implementations, though, which would be used if rules are added in DiffRules. Hence, for ForwardDiff support, #45 is a bit suboptimal.

@simsurace
Contributor Author

Sure. However, the above tests pass, i.e. there are no significant differences from Zygote. So #45 is an intermediate, if suboptimal, fix. I suppose that defining a diffrule would also require writing an optimized version of the derivative, i.e. calculating the analytical derivative and trying to find an optimal implementation.

In any case, when someone has the time and capacity to do this, it will not conflict with the fallback in #45. So ForwardDiff support does not conflict with merging #45, does it?

@devmotion
Member

> I suppose that defining a diffrule would require one to write an optimized version of the derivative as well, i.e. to calculate the analytical derivative, and try to find an optimal implementation.

IMO the initial implementation does not have to be hyperoptimized. My point was that the computation of the primal will be suboptimal since it will use the generic fallback instead of the Float64 method. I did not want to refer to the computation of the derivative (even though possibly it can be optimized).

The PR does not conflict with such DiffRules additions but in my opinion it's not the correct fix for this issue here, and hence the motivation is a bit different.

Regardless of that, I think it is reasonable to extend them to ::Real. My only concern is that the proposed implementation might not be sufficiently accurate for Float32 and Float16 for which we provide optimized implementations for other functions. It would be good to check that it is sufficiently accurate for these types, and in particular the numerically problematic inputs.
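A sketch of the suggested accuracy check (using the same hypothetical naive fallback as a stand-in): compare low-precision results against a `BigFloat` reference near x ≈ 0, where `log1p(x) - x` loses most of its significant bits to cancellation, since `log1pmx(x) ≈ -x^2/2` there.

```julia
# Hypothetical naive fallback used only for illustration.
naive_log1pmx(x::Real) = log1p(x) - x

# Compare against a BigFloat reference at a numerically problematic input.
for T in (Float32, Float16)
    x = T(1e-3)
    ref = T(log1p(big(x)) - big(x))   # reference value ≈ -x^2/2 ≈ -5e-7
    relerr = abs(naive_log1pmx(x) - ref) / abs(ref)
    println(T, ": relative error ≈ ", relerr)
end
```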

@simsurace
Contributor Author

I'm working on a PR to DiffRules.jl already.
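The analytic derivatives are simple, since `log1pmx(x) = log(1 + x) - x` and `logmxp1(x) = log(x) - x + 1`. A hedged sketch of what such DiffRules definitions might look like (the actual rules landed in JuliaDiff/DiffRules.jl#82 and may differ):

```julia
using DiffRules, LogExpFunctions

# d/dx [log(1 + x) - x] = 1/(1 + x) - 1 = -x/(1 + x)
DiffRules.@define_diffrule LogExpFunctions.log1pmx(x) = :(-$x / (1 + $x))
# d/dx [log(x) - x + 1] = 1/x - 1 = (1 - x)/x
DiffRules.@define_diffrule LogExpFunctions.logmxp1(x) = :((1 - $x) / $x)
```

With rules defined this way, ForwardDiff differentiates through the optimized Float64 primal instead of a generic fallback.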

@devmotion
Member

Will be released in DiffRules 1.11.0: JuliaRegistries/General#59302

devmotion added a commit that referenced this issue May 2, 2022
* Add fallbacks for `log1pmx` and `logmxp1`

This looks like it would make sense, given that the other functions dispatch on `Real` argument, and it also fixes #44.

* Use less naive heuristic

Co-authored-by: David Widmann <[email protected]>

* Add some tests

* Bump version

Co-authored-by: David Widmann <[email protected]>