ChainRulesTestUtils crashes with a custom regularization function on a Flux Dense layer #266

Open
@TPU22

I have been trying to write a custom reverse rule (rrule) for a simple regularization function on a Flux Dense layer and to check it with ChainRulesTestUtils. Taking the gradient with Zygote seems to work fine with these rules, but ChainRulesTestUtils crashes. The code below runs without problems until it reaches the test_rrule calls. The first test_rrule, which checks the one-layer regularization function, raises this error:

Got exception outside of a @test 
MethodError: no method matching zero(::typeof(tanh))

The second test_rrule, which checks the regularization over a whole Chain, crashes with:

Got exception outside of a @test
return type Tuple{NoTangent, Tangent{Chain{Tuple{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}}}}} does not match inferred return type Tuple{NoTangent, Tangent{Chain{Tuple{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}}}

Any idea what could be the issue here? A bug somewhere?

using ChainRulesCore
using Flux
using Random 
using ChainRulesTestUtils

Flux.trainable(nn::Dense) = (nn.weight, nn.bias,)

function weightregularization(nn::Dense)
    return sum((nn.weight).^2.0)
end

function ChainRulesCore.rrule(::typeof(weightregularization), nn::Dense)
    y = weightregularization(nn)
    project_w = ProjectTo(nn.weight)
    function weightregularization_pullback(ȳ)
        # Only the weight gets a non-trivial cotangent; bias and activation do not contribute.
        pullb = Tangent{Dense}(weight = project_w(ȳ * 2.0 * nn.weight), bias = ZeroTangent(), σ = NoTangent())
        return NoTangent(), pullb
    end
    return y, weightregularization_pullback
end


function totalregularization(ch::Chain{T}) where T<:Tuple{Vararg{Dense}}
    a = 0.0
    for i in ch
        a = a + sum(i.weight.^2.0)
    end
    return a
end

function ChainRulesCore.rrule(::typeof(totalregularization), ch::Chain{T}) where T<:Tuple{Vararg{Dense}}
    y = totalregularization(ch)
    function totalregularization_pullback(ȳ)
        # Collect one NamedTuple of cotangents per layer, then wrap them in a Tangent for the Chain.
        totalpullback = []
        N = length(ch)
        for i = 1:N
            project_w = ProjectTo(ch[i].weight)
            push!(totalpullback, (weight = project_w(ȳ * 2.0 * ch[i].weight), bias = ZeroTangent(), σ = NoTangent()))
        end
        pullb = Tangent{Chain{T}}(layers = Tuple(totalpullback))
        return NoTangent(), pullb
    end
    return y, totalregularization_pullback
end
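
# Single layer: the Zygote gradient runs fine.
nn = Dense(randn(1,2), randn(1), tanh)
gr1 = gradient(weightregularization, nn)

# Two-layer chain: the Zygote gradient runs fine.
l1 = Dense(randn(2,2), randn(2), tanh)
l2 = Dense(randn(1,2), randn(1), tanh)
ch = Chain(l1, l2)
gr2 = gradient(totalregularization, ch)

# Both calls below fail with the errors quoted above.
test_rrule(weightregularization, nn)  # MethodError: no method matching zero(::typeof(tanh))
test_rrule(totalregularization, ch)   # return type does not match inferred return type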
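
For reference, the derivative that the two pullbacks above are meant to implement can be written out as follows. The notation is assumed here for illustration only: $W$ is a layer's weight matrix, $b$ its bias, $W^{(k)}$ the weight of the $k$-th layer of the Chain, and $\bar{y}$ the incoming cotangent of the scalar output.

```latex
% One Dense layer: R(W) = sum of squared weights.
R(W) = \sum_{i,j} W_{ij}^{2},
\qquad
\frac{\partial R}{\partial W_{ij}} = 2\,W_{ij}
\quad\Longrightarrow\quad
\bar{W} = \bar{y}\cdot 2\,W,
\qquad
\bar{b} = 0 .

% Chain of N Dense layers: the regularization is the sum over layers,
% so each layer's weight cotangent has the same form.
R_{\text{total}} = \sum_{k=1}^{N} \sum_{i,j} \bigl(W^{(k)}_{ij}\bigr)^{2},
\qquad
\bar{W}^{(k)} = \bar{y}\cdot 2\,W^{(k)},
\qquad
\bar{b}^{(k)} = 0 .
```

This matches the `ȳ * 2.0 * nn.weight` and `ȳ * 2.0 * ch[i].weight` terms in the rules, with `ZeroTangent()` standing in for the zero bias cotangent.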
