I have been trying to write a custom reverse rule (`rrule`) for a simple regularization function on a Flux `Dense` layer and to check it with ChainRulesTestUtils. Computing the gradient with Zygote works fine with these rules, but ChainRulesTestUtils crashes. The code below runs without problems until the `test_rrule` calls. The first `test_rrule` is supposed to check the rule for the single-layer regularization function, but instead it raises an error:

```
Got exception outside of a @test
MethodError: no method matching zero(::typeof(tanh))
```
The second `test_rrule` crashes with:

```
Got exception outside of a @test
return type Tuple{NoTangent, Tangent{Chain{Tuple{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}}}}} does not match inferred return type Tuple{NoTangent, Tangent{Chain{Tuple{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}}}
```
Any idea what could be the issue here? A bug somewhere?
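For reference, both rules encode the gradient of a sum of squared weights. Writing $R(W)$ for the single-layer regularizer and $R_{\text{total}}$ for the sum over the $N$ layers of the chain, the pullback of a scalar cotangent $\bar{y}$ is:

$$
R(W) = \sum_{i,j} W_{ij}^2, \qquad \frac{\partial R}{\partial W_{ij}} = 2\,W_{ij}, \qquad \bar{W} = \bar{y}\,\frac{\partial R}{\partial W} = 2\,\bar{y}\,W
$$

$$
R_{\text{total}} = \sum_{k=1}^{N} \sum_{i,j} \big(W^{(k)}_{ij}\big)^2, \qquad \bar{W}^{(k)} = 2\,\bar{y}\,W^{(k)}, \quad k = 1, \dots, N
$$

The bias and the activation `σ` do not appear in $R$, which is why the rules return `ZeroTangent()` and `NoTangent()` for those fields.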
```julia
using ChainRulesCore
using Flux
using Random
using ChainRulesTestUtils

Flux.trainable(nn::Dense) = (nn.weight, nn.bias,)

function weightregularization(nn::Dense)
    return sum((nn.weight).^2.0)
end

function ChainRulesCore.rrule(::typeof(weightregularization), nn::Dense)
    y = weightregularization(nn)
    project_w = ProjectTo(nn.weight)
    function weightregularization_pullback(ȳ)
        pullb = Tangent{Dense}(weight=project_w(ȳ * 2.0*nn.weight), bias=ZeroTangent(), σ= NoTangent())
        return NoTangent(), pullb
    end
    return y, weightregularization_pullback
end

function totalregularization(ch::Chain{T}) where T<:Tuple{Vararg{Dense}}
    a = 0.0
    for i in ch
        a = a + sum(i.weight.^2.0)
    end
    return a
end

function ChainRulesCore.rrule(::typeof(totalregularization), ch::Chain{T}) where T<:Tuple{Vararg{Dense}}
    y = totalregularization(ch)
    function totalregularization_pullback(ȳ)
        totalpullback = []
        N = length(ch)
        for i = 1:N
            project_w = ProjectTo(ch[i].weight)
            push!(totalpullback, (weight= project_w(ȳ * 2.0*ch[i].weight), bias = ZeroTangent(), σ= NoTangent()))
        end
        pullb = Tangent{Chain{T}}(layers=Tuple(totalpullback))
        return NoTangent(), pullb
    end
    return y, totalregularization_pullback
end

nn = Dense(randn(1,2), randn(1), tanh)
gr1 = gradient(weightregularization,nn)

l1 = Dense(randn(2,2), randn(2), tanh)
l2 = Dense(randn(1,2), randn(1), tanh)
ch = Chain(l1,l2)
gr2 = gradient(totalregularization,ch)

test_rrule(weightregularization,nn)
test_rrule(totalregularization,ch)
```
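One possibly relevant detail for the second failure, offered as an assumption rather than a confirmed diagnosis: `totalpullback = []` creates a `Vector{Any}`, so the element types of `Tuple(totalpullback)`, and hence the backing type of the returned `Tangent`, cannot be inferred, while `test_rrule` compares the actual return type with the inferred one. The self-contained sketch below needs no Flux; `untyped_grads` and `typed_grads` are hypothetical helpers that only mimic the two ways of collecting per-layer gradients:

```julia
# Hypothetical helpers for illustrating type inference only; not part of Flux or ChainRules.

# Mirrors the pattern in `totalregularization_pullback`: push into an untyped `[]`
function untyped_grads(ws)
    acc = []                      # Vector{Any}: element types are lost
    for w in ws
        push!(acc, 2.0 .* w)
    end
    return Tuple(acc)             # length and element types unknown to inference
end

# Same values built with `map` over a tuple, which preserves each element type
typed_grads(ws) = map(w -> 2.0 .* w, ws)

ws = (randn(2, 2), randn(1, 2))

untyped_T = only(Base.return_types(untyped_grads, (typeof(ws),)))
typed_T = only(Base.return_types(typed_grads, (typeof(ws),)))

println(untyped_T, "  concrete: ", isconcretetype(untyped_T))
println(typed_T, "  concrete: ", isconcretetype(typed_T))
```

If this is what trips the check, building the layer tangents with `map` over `ch.layers` (instead of `push!` into `[]`) may keep the `Tangent` type inferable; `test_rrule` also accepts `check_inferred=false` to skip the inference check entirely.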