Skip to content

Commit 4b26667

Browse files
Add a Numba implementation for Generator.gumbel
1 parent 8da6847 commit 4b26667

File tree

2 files changed

+39
-4
lines changed

2 files changed

+39
-4
lines changed

aesara/link/numba/dispatch/random.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from copy import copy
2+
from math import log
23
from textwrap import dedent, indent
34
from typing import Callable, Optional
45

@@ -8,6 +9,7 @@
89
from numba import types
910
from numba.extending import overload, overload_method, register_jitable
1011
from numba.np.random.distributions import random_beta, random_standard_gamma
12+
from numba.np.random.generator_core import next_double
1113
from numba.np.random.generator_methods import check_size, check_types, is_nonelike
1214

1315
import aesara.tensor.random.basic as aer
@@ -373,3 +375,39 @@ def impl(inst, alphas, size=None):
373375
return random_dirichlet(inst.bit_generator, alphas, size)
374376

375377
return impl
378+
379+
380+
@register_jitable
381+
def random_gumbel(bitgen, loc, scale):
382+
"""
383+
This implementation is adapted from ``numpy/random/src/distributions/distributions.c``.
384+
"""
385+
while True:
386+
u = 1.0 - next_double(bitgen)
387+
if u < 1.0:
388+
return loc - scale * log(-log(u))
389+
390+
391+
@overload_method(types.NumPyRandomGeneratorType, "gumbel")
392+
def NumPyRandomGeneratorType_gumbel(inst, loc=0.0, scale=1.0, size=None):
393+
check_types(loc, [types.Float, types.Integer, int, float], "loc")
394+
check_types(scale, [types.Float, types.Integer, int, float], "scale")
395+
396+
if isinstance(size, types.Omitted):
397+
size = size.value
398+
399+
if is_nonelike(size):
400+
401+
def impl(inst, loc=0.0, scale=1.0, size=None):
402+
return random_gumbel(inst.bit_generator, loc, scale)
403+
404+
else:
405+
check_size(size)
406+
407+
def impl(inst, loc=0.0, scale=1.0, size=None):
408+
out = np.empty(size)
409+
for i in np.ndindex(size):
410+
out[i] = random_gumbel(inst.bit_generator, loc, scale)
411+
return out
412+
413+
return impl

tests/link/numba/test_random.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
362362
"chi2",
363363
lambda *args: args,
364364
),
365-
pytest.param(
365+
(
366366
aer.gumbel,
367367
[
368368
set_test_value(
@@ -377,9 +377,6 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
377377
(2,),
378378
"gumbel_r",
379379
lambda *args: args,
380-
marks=pytest.mark.skip(
381-
reason="Not yet supported in Numba via `Generator`s"
382-
),
383380
),
384381
(
385382
aer.negative_binomial,

0 commit comments

Comments
 (0)