
Check for repeated unthunking #262

Open
@mcabbott


`test_rrule` already checks that the rule's pullback accepts thunks.

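Should it also check that these thunks aren't accidentally unthunked more than once, for example by passing in something like `@thunk (COUNT[] += 1; val)` and checking the count afterwards? Unthunking twice is a mistake, though not one that gives wrong answers: it just repeats the work and makes the rule slower, as seen in JuliaDiff/ChainRules.jl#670.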

```julia
julia> using ChainRulesCore, ChainRulesTestUtils

julia> plus(x,y) = x+y;

julia> unplus(z) = (NoTangent(), z, z);

julia> unplus(::AbstractThunk) = error("fails");

julia> ChainRulesCore.rrule(::typeof(plus), x, y) = x+y, unplus;

julia> test_rrule(plus, [1.0], [2.0])
test_rrule: plus on Vector{Float64},Vector{Float64}: Error During Test at /Users/me/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:193
  Got exception outside of a @test
  fails

julia> unplus(z::AbstractThunk) = (NoTangent(), unthunk(@show z), @show unthunk(z))

julia> test_rrule(plus, [1.0], [2.0]);
z = Thunk(ChainRulesTestUtils.var"#61#65"{Vector{Float64}}([1.73]))
unthunk(z) = [1.73]
z = Thunk(ChainRulesTestUtils.var"#62#66"{Vector{Float64}}([1.73]))
unthunk(z) = [1.73]
Test Summary:                                       | Pass  Total  Time
test_rrule: plus on Vector{Float64},Vector{Float64} |    9      9  0.0s
```
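The `@show` session makes the repeated work visible only by eye, and `test_rrule` still passes. The counting idea from the question could look roughly like the following minimal sketch. `count_unthunks`, `unplus_twice` and `unplus_once` are hypothetical names made up for illustration; nothing here is existing ChainRulesTestUtils API. The sketch only assumes that `unthunk` on a `Thunk` calls the wrapped closure each time.

```julia
using ChainRulesCore: Thunk, NoTangent, unthunk

# Hypothetical helper, not part of ChainRulesTestUtils: call `pullback` on a
# Thunk wrapping `val` and return how many times that Thunk was evaluated.
function count_unthunks(pullback, val)
    calls = Ref(0)
    z = Thunk(() -> (calls[] += 1; val))  # every unthunk(z) runs this closure
    pullback(z)
    return calls[]
end

# Wasteful pullback: unthunks its input once per output slot.
unplus_twice(z) = (NoTangent(), unthunk(z), unthunk(z))

# Better pullback: unthunks once and reuses the value.
function unplus_once(z)
    dz = unthunk(z)
    return (NoTangent(), dz, dz)
end

println(count_unthunks(unplus_twice, [1.73]))  # 2
println(count_unthunks(unplus_once, [1.73]))   # 1
```

A tester built along these lines could then require the count to be at most 1 for each thunked input it passes in.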

Going further, should it also reject (or warn about) a pullback like `unplus(z::AbstractThunk) = (NoTangent(), z, z)`? If the returned tuple contains the very thunk it was given in more than one place, later steps are likely to unthunk it once per copy.
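One possible shape for that second check is sketched below. It assumes the check only needs to inspect the top-level entries of the returned tuple. `repeats_input_thunk` and `unplus_lazy` are hypothetical names for illustration, not existing ChainRulesTestUtils functions. Because `Thunk` is an immutable struct, `===` treats two thunks as identical when they cannot be told apart, which is the case that leads to duplicate work.

```julia
using ChainRulesCore: AbstractThunk, Thunk, NoTangent

# Hypothetical check, not part of ChainRulesTestUtils: does the tuple `dx`
# returned by a pullback hold its input thunk `z` in more than one slot?
# Only top-level entries are inspected; nested tangents are not searched.
function repeats_input_thunk(z::AbstractThunk, dx::Tuple)
    return count(t -> t === z, dx) > 1
end

unplus_lazy(z) = (NoTangent(), z, z)  # passes the same thunk on twice

z = Thunk(() -> [1.73])
println(repeats_input_thunk(z, unplus_lazy(z)))            # true
println(repeats_input_thunk(z, (NoTangent(), z, [1.73])))  # false
```

Since forwarding a thunk unchanged once is fine, a warning rather than a hard failure may be the more appropriate response when this returns `true`.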

Metadata

Assignees: No one assigned

Labels: enhancement (New feature or request)

Type: No type

Projects: No projects

Milestone: No milestone

Relationships: None yet

Development: No branches or pull requests