`test_rrule` checks that the rule accepts thunks.
Should it also check that these thunks aren't accidentally unthunked more than once, for example by passing in something like `@thunk (COUNT[] += 1; val)`? Unthunking twice is a mistake, although not one that gives wrong answers; it just makes things slow. As seen in JuliaDiff/ChainRules.jl#670:
```julia
julia> using ChainRulesCore, ChainRulesTestUtils
julia> plus(x,y) = x+y;
julia> unplus(z) = (NoTangent(), z, z);
julia> unplus(::AbstractThunk) = error("fails");
julia> ChainRulesCore.rrule(::typeof(plus), x, y) = x+y, unplus;
julia> test_rrule(plus, [1.0], [2.0])
test_rrule: plus on Vector{Float64},Vector{Float64}: Error During Test at /Users/me/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:193
  Got exception outside of a @test
  fails
julia> unplus(z::AbstractThunk) = (NoTangent(), unthunk(@show z), @show unthunk(z))
julia> test_rrule(plus, [1.0], [2.0]);
z = Thunk(ChainRulesTestUtils.var"#61#65"{Vector{Float64}}([1.73]))
unthunk(z) = [1.73]
z = Thunk(ChainRulesTestUtils.var"#62#66"{Vector{Float64}}([1.73]))
unthunk(z) = [1.73]
Test Summary:                                       | Pass  Total  Time
test_rrule: plus on Vector{Float64},Vector{Float64} |    9      9  0.0s
```
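A counting check along the lines proposed above might look roughly like the sketch below. It uses only `Thunk`, `unthunk` and `NoTangent` from ChainRulesCore; the helper `count_unthunks` and the three pullbacks are hypothetical names for illustration, not part of ChainRulesTestUtils. It assumes that downstream code unthunks each returned tangent once, which it simulates with `foreach(unthunk, ...)`. Since a `Thunk` is re-evaluated on every `unthunk`, the counter sees each evaluation.

```julia
using ChainRulesCore

# Hypothetical helper (not part of ChainRulesTestUtils): hands the pullback a
# thunk that counts its own evaluations, then unthunks every returned tangent
# once, as downstream code would, and returns the number of evaluations.
function count_unthunks(pullback, dy)
    nevals = Ref(0)
    counted = Thunk(() -> (nevals[] += 1; dy))
    tangents = pullback(counted)
    foreach(unthunk, tangents)  # simulate the consumers of each tangent
    return nevals[]
end

# Unthunks once and reuses the result: the thunk runs once.
unplus_once(z) = (dz = unthunk(z); (NoTangent(), dz, dz))

# Like the `@show` version in the session: two separate unthunk calls.
unplus_show(z) = (NoTangent(), unthunk(z), unthunk(z))

# Passes the same thunk to both outputs: downstream unthunks it twice.
unplus_same(z) = (NoTangent(), z, z)

count_unthunks(unplus_once, [1.73])  # 1
count_unthunks(unplus_show, [1.73])  # 2
count_unthunks(unplus_same, [1.73])  # 2
```

A tester built on this idea would assert that the count is at most one per call of the pullback, which steers rule authors toward the `unplus_once` shape.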
Going further, should it also reject (or warn about) a pullback like `unplus(z::AbstractThunk) = (NoTangent(), z, z);`? If what's returned contains the very same thunk it received in more than one place, then later steps are likely to unthunk it twice.
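The stricter check suggested here could be sketched as follows. `reuses_input_thunk` is a hypothetical name, not an existing ChainRulesTestUtils function. It compares with `===` and only inspects the top-level entries of the returned tuple, so a thunk nested inside, say, a `Tangent` would not be caught.

```julia
using ChainRulesCore

# Hypothetical check (not part of ChainRulesTestUtils): true if the pullback
# returns the exact thunk object it received in more than one top-level slot,
# which means later steps are likely to unthunk it more than once.
function reuses_input_thunk(pullback, dy::AbstractThunk)
    tangents = pullback(dy)
    return count(t -> t === dy, tangents) > 1
end

dy = Thunk(() -> [1.73])
reuses_input_thunk(z -> (NoTangent(), z, z), dy)                       # true
reuses_input_thunk(z -> (dz = unthunk(z); (NoTangent(), dz, dz)), dy)  # false
```

Returning the input thunk in a single slot is deliberately not flagged, since that alone does not cause a second unthunk.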