
Wrong determinant results for large batch #24843

Closed
@ChenAo-Phys

Description

The bug can be reproduced using the following code:

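```python
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

# 1,500,000 random 20x20 matrices (600,000,000 float32 elements in total)
a = jr.normal(jr.key(0), (1500000, 20, 20))
d = jnp.linalg.det(a)

# Plot each determinant against its batch index
plt.plot(d)
plt.show()
```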

[Figure: the determinants `d` plotted against batch index]

The values at the end of the batch are clearly incorrect.

I have tested different batch and matrix sizes. The bug only occurs when the total number of elements in the batched array (1500000*20*20 = 600,000,000 in the example) exceeds 2**29 = 536,870,912. As also shown in the figure, the determinant values are wrong for batch indices above 2**29 / (20*20), i.e. about 1,342,177.
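
The threshold arithmetic can be checked in plain Python. This sketch only restates the numbers from this report and assumes float32 entries (the default dtype of `jax.random.normal`). It also shows that 2**29 float32 values occupy 2**31 bytes, one more than the largest signed 32-bit integer, which hints (but does not confirm) that a 32-bit byte offset overflows somewhere in the GPU code path.

```python
# Restates the numbers from this report.
# Assumption: float32 entries, i.e. 4 bytes per element.
ELEMENT_LIMIT = 2**29      # element count above which wrong results appear
BYTES_PER_ELEMENT = 4      # float32
INT32_MAX = 2**31 - 1      # largest signed 32-bit integer

batch, n = 1_500_000, 20
total_elements = batch * n * n

# Largest batch size whose total element count stays at or below the limit
max_safe_batch = ELEMENT_LIMIT // (n * n)

# Byte size of a buffer holding exactly ELEMENT_LIMIT float32 values
limit_bytes = ELEMENT_LIMIT * BYTES_PER_ELEMENT

print(total_elements)              # 600000000
print(total_elements > ELEMENT_LIMIT)  # True
print(max_safe_batch)              # 1342177
print(limit_bytes > INT32_MAX)     # True: would not fit a signed 32-bit offset
```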

I also tested on other devices: the bug occurs on several types of GPU, but not on the CPU.
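
Since wrong values appear only above 2**29 elements and only on GPU, one possible workaround is to split the batch so that each determinant call sees at most 2**29 elements. The helper below is a hypothetical sketch (`chunked_det` is not part of JAX) and has not been tested on the affected GPUs. It defaults to NumPy so it runs anywhere; the determinant and concatenation functions are parameters so the JAX versions can be passed in.

```python
import math

import numpy as np


def chunked_det(a, det_fn=np.linalg.det, concatenate=np.concatenate,
                max_elements=2**29):
    """Hypothetical workaround: determinants over the leading batch axis,
    computed in chunks so that each det_fn call sees at most max_elements
    array entries (the threshold reported in this issue)."""
    batch = a.shape[0]
    per_item = math.prod(a.shape[1:])           # elements per batch entry
    chunk = max(1, max_elements // per_item)    # batch entries per call
    pieces = [det_fn(a[i:i + chunk]) for i in range(0, batch, chunk)]
    return concatenate(pieces)
```

On GPU the call would be `chunked_det(a, det_fn=jnp.linalg.det, concatenate=jnp.concatenate)`. For the example above this issues two calls (1,342,177 and 157,823 matrices), assuming the element count is the only trigger.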

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.35
jaxlib: 0.4.35
numpy:  2.0.2
python: 3.12.4 | packaged by Anaconda, Inc. | (main, Jun 18 2024, 15:12:24) [GCC 11.2.0]
device info: NVIDIA A100 80GB PCIe-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='alcc145', release='5.15.0-94-generic', version='#104-Ubuntu SMP Tue Jan 9 15:25:40 UTC 2024', machine='x86_64')

Labels: bug (Something isn't working)
