Description
The bug can be reproduced using the following code:
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
a = jr.normal(jr.key(0), (1500000, 20, 20))
d = jnp.linalg.det(a)
plt.plot(d)
The determinant values toward the end of the batch are clearly incorrect.
I have tested different batch and matrix sizes. The bug only happens when the total number of elements in the input (in the example, 1500000*20*20) exceeds 2**29. As also shown in the figure, the determinant values are wrong for batch index > 2**29 / (20*20).
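For reference, the threshold described above can be worked out with plain integer arithmetic. This is only a sketch of the numbers in the report; the variable names are illustrative and not part of JAX.

```python
# Sizes taken from the reproduction above.
batch, n = 1_500_000, 20
elements_per_matrix = n * n                    # 20 * 20 = 400
total_elements = batch * elements_per_matrix   # 600_000_000
threshold = 2**29                              # 536_870_912

# Integer part of 2**29 / (20*20); the report describes batch
# indices above this value as giving wrong determinants.
boundary_index = threshold // elements_per_matrix

print(total_elements > threshold)  # True
print(boundary_index)              # 1342177
```

So roughly the last 158,000 of the 1,500,000 determinants fall in the affected range.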
I also tested on different devices. The bug occurs on several types of GPU but not on the CPU.
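One way to pin down where the GPU result starts to diverge is to compare it element by element against the CPU result. The helper below is a hypothetical sketch, not part of JAX: `first_mismatch` and its tolerances are assumptions, and the two short lists stand in for the determinant arrays computed on each device.

```python
import math

def first_mismatch(reference, candidate, rel_tol=1e-3, abs_tol=1e-6):
    """Return the first index where the two sequences disagree, or None.

    The tolerances are assumed values; float32 determinants may need
    looser ones. NaN entries never compare close, so they are flagged.
    """
    for i, (r, c) in enumerate(zip(reference, candidate)):
        if not math.isclose(r, c, rel_tol=rel_tol, abs_tol=abs_tol):
            return i
    return None

# Stand-in data: in practice these would be the determinants computed
# on the CPU (reference) and on the GPU (candidate), converted to lists.
cpu_dets = [1.0, -2.5, 3.0, 4.0]
gpu_dets = [1.0, -2.5, 0.0, 0.0]

print(first_mismatch(cpu_dets, gpu_dets))  # 2
```

If the observation above holds, the index returned for the real arrays should sit close to 2**29 / (20*20).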
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.35
jaxlib: 0.4.35
numpy: 2.0.2
python: 3.12.4 | packaged by Anaconda, Inc. | (main, Jun 18 2024, 15:12:24) [GCC 11.2.0]
device info: NVIDIA A100 80GB PCIe-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='alcc145', release='5.15.0-94-generic', version='#104-Ubuntu SMP Tue Jan 9 15:25:40 UTC 2024', machine='x86_64')