|
1 | 1 | from copy import copy
|
| 2 | +from math import log |
2 | 3 | from textwrap import dedent, indent
|
3 | 4 | from typing import Callable, Optional
|
4 | 5 |
|
|
8 | 9 | from numba import types
|
9 | 10 | from numba.extending import overload, overload_method, register_jitable
|
10 | 11 | from numba.np.random.distributions import random_beta, random_standard_gamma
|
| 12 | +from numba.np.random.generator_core import next_double |
11 | 13 | from numba.np.random.generator_methods import check_size, check_types, is_nonelike
|
12 | 14 |
|
13 | 15 | import aesara.tensor.random.basic as aer
|
@@ -373,3 +375,39 @@ def impl(inst, alphas, size=None):
|
373 | 375 | return random_dirichlet(inst.bit_generator, alphas, size)
|
374 | 376 |
|
375 | 377 | return impl
|
| 378 | + |
| 379 | + |
| 380 | +@register_jitable |
| 381 | +def random_gumbel(bitgen, loc, scale): |
| 382 | + """ |
| 383 | + This implementation is adapted from ``numpy/random/src/distributions/distributions.c``. |
| 384 | + """ |
| 385 | + while True: |
| 386 | + u = 1.0 - next_double(bitgen) |
| 387 | + if u < 1.0: |
| 388 | + return loc - scale * log(-log(u)) |
| 389 | + |
| 390 | + |
| 391 | +@overload_method(types.NumPyRandomGeneratorType, "gumbel") |
| 392 | +def NumPyRandomGeneratorType_gumbel(inst, loc=0.0, scale=1.0, size=None): |
| 393 | + check_types(loc, [types.Float, types.Integer, int, float], "loc") |
| 394 | + check_types(scale, [types.Float, types.Integer, int, float], "scale") |
| 395 | + |
| 396 | + if isinstance(size, types.Omitted): |
| 397 | + size = size.value |
| 398 | + |
| 399 | + if is_nonelike(size): |
| 400 | + |
| 401 | + def impl(inst, loc=0.0, scale=1.0, size=None): |
| 402 | + return random_gumbel(inst.bit_generator, loc, scale) |
| 403 | + |
| 404 | + else: |
| 405 | + check_size(size) |
| 406 | + |
| 407 | + def impl(inst, loc=0.0, scale=1.0, size=None): |
| 408 | + out = np.empty(size) |
| 409 | + for i in np.ndindex(size): |
| 410 | + out[i] = random_gumbel(inst.bit_generator, loc, scale) |
| 411 | + return out |
| 412 | + |
| 413 | + return impl |
0 commit comments