Skip to content

Commit caaabf2

Browse files
committed
Add StandardNormalRV JAX implementation
1 parent f18e7ae commit caaabf2

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

aesara/link/jax/dispatch/random.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def sample_fn(rng, size, dtype, *parameters):
114114
@jax_sample_fn.register(aer.LaplaceRV)
115115
@jax_sample_fn.register(aer.LogisticRV)
116116
@jax_sample_fn.register(aer.NormalRV)
117+
@jax_sample_fn.register(aer.StandardNormalRV)
117118
def jax_sample_fn_loc_scale(op):
118119
"""JAX implementation of random variables in the loc-scale families.
119120

tests/link/jax/test_random.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,13 @@ def test_random_updates(rng_ctor):
221221
"randint",
222222
lambda *args: args,
223223
),
224+
(
225+
aer.standard_normal,
226+
[],
227+
(2,),
228+
"norm",
229+
lambda *args: args,
230+
),
224231
(
225232
aer.t,
226233
[

0 commit comments

Comments
 (0)