Skip to content

Commit 697e7e4

Browse files
authored
MersenneTwister (#223)
* Nograd for MersenneTwister * MersenneTwister frule * Bumps patch version * Bump patch version
1 parent ce39b65 commit 697e7e4

File tree

5 files changed

+41
-1
lines changed

5 files changed

+41
-1
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.7.3"
3+
version = "0.7.4"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
89
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
910
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1011
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

src/ChainRules.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using Reexport
66
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
77
using LinearAlgebra
88
using LinearAlgebra.BLAS
9+
using Random
910
using Requires
1011
using Statistics
1112

@@ -35,6 +36,8 @@ include("rulesets/LinearAlgebra/dense.jl")
3536
include("rulesets/LinearAlgebra/structured.jl")
3637
include("rulesets/LinearAlgebra/factorization.jl")
3738

39+
include("rulesets/Random/random.jl")
40+
3841
# Note: The following is only required because package authors sometimes do not
3942
# declare their own rules using `ChainRulesCore.jl`. For arguably good reasons.
4043
# So we define them here for them.

src/rulesets/Random/random.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
frule(Δargs, ::typeof(MersenneTwister), args...) = MersenneTwister(args...), Zero()
2+
3+
function rrule(::typeof(MersenneTwister), args...)
4+
function MersenneTwister_rrule(ΔΩ)
5+
return (NO_FIELDS, map(_ -> Zero(), args)...)
6+
end
7+
return MersenneTwister(args...), MersenneTwister_rrule
8+
end

test/rulesets/Random/random.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
@testset "random" begin
2+
@testset "MersenneTwister" begin
3+
@testset "no args" begin
4+
rng, dΩ = frule((5.0,), MersenneTwister)
5+
@test rng isa MersenneTwister
6+
@testisa Zero
7+
8+
rng, pb = rrule(MersenneTwister)
9+
@test rng isa MersenneTwister
10+
@test first(pb(10)) isa Zero
11+
end
12+
@testset "unary" begin
13+
rng, dΩ = frule((5.0, 4.0), MersenneTwister, 123)
14+
@test rng isa MersenneTwister
15+
@testisa Zero
16+
17+
rng, pb = rrule(MersenneTwister, 123)
18+
@test rng isa MersenneTwister
19+
@test all(map(x -> x isa Zero, pb(10)))
20+
end
21+
end
22+
end

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ println("Testing ChainRules.jl")
4141

4242
print(" ")
4343

44+
@testset "Random" begin
45+
include(joinpath("rulesets", "Random", "random.jl"))
46+
end
47+
48+
print(" ")
49+
4450
@testset "packages" begin
4551
include(joinpath("rulesets", "packages", "NaNMath.jl"))
4652
include(joinpath("rulesets", "packages", "SpecialFunctions.jl"))

0 commit comments

Comments
 (0)