Skip to content

Commit f49b118

Browse files
committed
handle indexing of GPU arrays
1 parent 9e1aa8c commit f49b118

File tree

6 files changed

+67
-5
lines changed

6 files changed

+67
-5
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
33
version = "1.44.0"
44

55
[deps]
6+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
67
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
78
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
89
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -16,6 +17,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1617
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1718

1819
[compat]
20+
Adapt = "3.4.0"
1921
ChainRulesCore = "1.15.3"
2022
ChainRulesTestUtils = "1.5"
2123
Compat = "3.42.0, 4"
@@ -30,7 +32,6 @@ StructArrays = "0.6.11"
3032
julia = "1.6"
3133

3234
[extras]
33-
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3435
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3536
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
3637
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
@@ -40,4 +41,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
4041
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4142

4243
[targets]
43-
test = ["Adapt", "ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"]
44+
test = ["ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"]

src/ChainRules.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
module ChainRules
22

3+
using Adapt: adapt
34
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
45
using ChainRulesCore
56
using Compat
67
using Distributed
7-
using GPUArraysCore: AbstractGPUArrayStyle
8+
using GPUArraysCore: AbstractGPUArray, AbstractGPUArrayStyle
89
using IrrationalConstants: logtwo, logten
910
using LinearAlgebra
1011
using LinearAlgebra.BLAS

src/rulesets/Base/base.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,23 @@ function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, xs::Tu
243243
map_back(dy::AbstractZero) = (NoTangent(), NoTangent(), ntuple(Returns(NoTangent()), num_xs)...)
244244
return y, map_pullback
245245
end
246+
247+
#####
248+
##### `task_local_storage`
249+
#####
250+
251+
# Called by `@allowscalar` from GPUArrays
252+
253+
ChainRules.@non_differentiable task_local_storage(key::Any)
254+
ChainRules.@non_differentiable task_local_storage(key::Any, value::Any)
255+
256+
# Reverse rule for the three-argument form `task_local_storage(body, key, value)`.
# The zero-argument `body` is differentiated by calling back into the AD system
# (`rrule_via_ad`) from inside `task_local_storage`, so it runs with `key => value`
# set exactly as in the primal call.
# Returns the primal result of `body()` and a pullback. Only `body` receives a
# tangent; the function itself, `key` and `value` all get `NoTangent()`.
function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(task_local_storage), body::Function, key, value)
    differentiate_body = () -> rrule_via_ad(config, body)
    y, back = task_local_storage(differentiate_body, key, value)
    function task_local_storage_pullback(dy)
        # `back` is the pullback of a zero-argument call, so it is expected to
        # return exactly one tangent (for `body` itself); `only` enforces that.
        body_tangents = back(dy)
        dbody = only(body_tangents)
        return (NoTangent(), dbody, NoTangent(), NoTangent())
    end
    return y, task_local_storage_pullback
end

src/rulesets/Base/indexing.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,6 @@ function ∇getindex!(dx::AbstractArray, x::AbstractArray, dy, inds::Integer...)
113113
end
114114
function ∇getindex!(dx::AbstractArray, x::AbstractArray, dy, inds...)
115115
view(dx, inds...) .+= dy
116-
# For GPU arrays, `inds::Union{Integer, Base.Slice}...` is fine, but any other AbstractArray risks overwriting.
117-
# Those should call `NNlib.scatter!`, alla https://github.com/FluxML/Zygote.jl/pull/1131
118116
return dx
119117
end
120118

@@ -134,6 +132,25 @@ function rrule(::typeof(∇getindex), x, dy, inds...)
134132
return z, ∇getindex_pullback
135133
end
136134

135+
# Indexing with repeated indices on a GPU would cause race conditions in ∇getindex, and hence wrong answers.
136+
# To avoid this, copy everything back to the CPU.
137+
# But don't do that for indices which are known to be unique, e.g. in `A[1, 2:3, :]`, where the colon becomes a `Base.Slice`:
138+
139+
# GPU method for all-`Integer` indices: accumulate `dy` into the single selected
# entry of `dx` and return `dx`.
# The update is written as a broadcast into a view, with `dy` wrapped in `Ref` so
# that broadcasting treats it as one value rather than iterating over it.
# `x` is part of the shared `∇getindex!` signature but is not read here.
function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds::Integer...)
    selected = view(dx, inds...)
    selected .+= Ref(dy)
    return dx
end
143+
# GPU method for indices that are each an `Integer`, an `AbstractUnitRange` or a
# `Base.Slice` (what a colon becomes). Such indices cannot repeat an entry, so
# `dy` is accumulated directly on the device by broadcasting into a view of `dx`.
# Returns the mutated `dx`. `x` is not read here.
function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds::Union{Integer, AbstractUnitRange, Base.Slice}...)
    selected = view(dx, inds...)
    selected .+= dy
    return dx
end
147+
# GPU fallback for any other kind of index (e.g. a vector of integers, which may
# contain repeats). Everything is adapted to `Array`, the accumulation is done
# there, and the result is copied back into `dx`.
# Returns the mutated `dx`. `x` is not read here.
function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds...)
    host_dx = adapt(Array, dx)
    host_inds = adapt(Array, inds)
    host_view = view(host_dx, host_inds...)
    host_view .+= adapt(Array, dy)
    # Write the accumulated gradient back into the original array.
    copyto!(dx, host_dx)
    return dx
end
153+
137154
#####
138155
##### first, tail
139156
#####

test/rulesets/Base/indexing.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,25 @@
143143
test_rrule(∇getindex, [rand(2) for _ in 1:3], rand(2), 3; check_inferred=false)
144144
test_rrule(∇getindex, [rand(2) for _ in 1:3], [rand(2), rand(2)], 1:2; check_inferred=false)
145145
end
146+
147+
@testset "GPU" begin
    # 2×3 test array wrapped by `jl`; all checks below index into this array.
    x_23_gpu = jl(rand(2, 3))

    # Scalar indexing, copied from: @macroexpand @allowscalar A[i]
    # This was reported to error under Pkg.test. NOTE(review): it referred to an
    # undefined `x_gpu` (corrected below), which is a likely cause. It is left
    # disabled because `CFG` and `ScalarAllowed` must also be in scope (TODO confirm).
    # y1, bk1 = rrule(CFG, Base.task_local_storage, () -> x_23_gpu[1], :ScalarIndexing, ScalarAllowed)
    # @test y1 == @allowscalar x_23_gpu[1]
    # bk1(1.0) # This is zero, because finite-differencing ignores the function
    # ... but this works, and calls the rule:
    # Zygote.gradient(x -> @allowscalar(x[1]), jl(rand(3)))[1]

    y2, bk2 = rrule(getindex, x_23_gpu, :, 2:3)  # fast path, just broadcast .+=
    @test unthunk(bk2(jl(ones(2,2)))[2]) == jl([0 1 1; 0 1 1])

    y3, bk3 = rrule(getindex, x_23_gpu, 1, [1,1,2])  # slow path, copy to CPU
    # The primal must agree with plain CPU indexing of the same data. This check
    # previously named an undefined `x_gpu` and was skipped for that reason.
    @test Array(y3) == Array(x_23_gpu)[1, [1,1,2]]
    # The repeated index 1 must accumulate: column 1 of row 1 receives 2, not 1.
    @test unthunk(bk3(jl(ones(3)))[2]) == jl([2 1 0; 0 0 0])
end
146165
end
147166

148167
@testset "first & tail" begin
@@ -178,6 +197,7 @@ end
178197
end
179198

180199
@testset "unsafe_getindex" begin
200+
# In real life this is called only on some AbstractRanges, but easier to test on Array:
181201
test_frule(Base.unsafe_getindex, collect(1:0.1:2), 3)
182202
test_rrule(Base.unsafe_getindex, collect(1:0.1:2), 3)
183203
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@ using Test, ChainRulesCore, ChainRulesTestUtils
22

33
@nospecialize
44

5+
using Adapt
56
using Base.Broadcast: broadcastable
67
using ChainRules
78
using ChainRulesCore
89
using ChainRulesTestUtils
910
using ChainRulesTestUtils: rand_tangent, _fdm
1011
using FiniteDifferences
12+
using GPUArraysCore
13+
using JLArrays
1114
using LinearAlgebra
1215
using LinearAlgebra.BLAS
1316
using LinearAlgebra: dot

0 commit comments

Comments
 (0)