Skip to content

Commit f18e7ae

Browse files
committed
Add GumbelRV JAX implementation
1 parent 77dc5bc commit f18e7ae

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

aesara/link/jax/dispatch/random.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def sample_fn(rng, size, dtype, *parameters):
110110

111111

112112
@jax_sample_fn.register(aer.CauchyRV)
113+
@jax_sample_fn.register(aer.GumbelRV)
113114
@jax_sample_fn.register(aer.LaplaceRV)
114115
@jax_sample_fn.register(aer.LogisticRV)
115116
@jax_sample_fn.register(aer.NormalRV)

tests/link/jax/test_random.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,22 @@ def test_random_updates(rng_ctor):
123123
"gamma",
124124
lambda a, b: (a, 0.0, b),
125125
),
126+
(
127+
aer.gumbel,
128+
[
129+
set_test_value(
130+
at.lvector(),
131+
np.array([1, 2], dtype=np.int64),
132+
),
133+
set_test_value(
134+
at.dscalar(),
135+
np.array(1.0, dtype=np.float64),
136+
),
137+
],
138+
(2,),
139+
"gumbel_r",
140+
lambda *args: args,
141+
),
126142
(
127143
aer.laplace,
128144
[

0 commit comments

Comments
 (0)