Description
Calling a `jax.jit`-wrapped function with a keyword argument named `f` raises a `TypeError`. For example:
```python
import jax

def myfun(x, f):
    return x * f

jax.jit(myfun)(1.0, f=2.0)
```
This raises:
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[6], line 4
      1 def myfun(x, f):
      2     return x*f
----> 4 jax.jit(myfun)(1.0, f=2.0)

    [... skipping hidden 11 frame]

File ~/miniconda3/envs/desc/lib/python3.10/site-packages/jax/_src/api_util.py:74, in flatten_fun(f, store, in_tree, *args_flat)
     71 @lu.transformation_with_aux2
     72 def flatten_fun(f, store, in_tree, *args_flat):
     73   py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
---> 74   ans = f(*py_args, **py_kwargs)
     75   ans, out_tree = tree_flatten(ans)
     76   store.store(out_tree)

TypeError: result_paths() got multiple values for argument 'f'
```
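The error message suggests a name collision: line 74 forwards the user's keyword arguments via `**py_kwargs` into `result_paths`, and if `result_paths` itself declares a parameter named `f` (for the wrapped function), a user keyword also named `f` gets bound to it twice. The pure-Python sketch below reproduces that failure mode with two hypothetical wrappers, `call_wrapped_buggy` and `call_wrapped_fixed`; they are illustrative assumptions, not JAX's actual internals. Marking the wrapper's own parameter as positional-only with `/` (Python 3.8+) is one common way to avoid this kind of clash.

```python
# Hypothetical stand-ins: neither function below is JAX's real code.

def call_wrapped_buggy(f, *args, **kwargs):
    # `f` is a positional-or-keyword parameter, so a user kwarg also named
    # `f` tries to bind to it a second time and Python raises TypeError.
    return f(*args, **kwargs)


def call_wrapped_fixed(f, /, *args, **kwargs):
    # The `/` makes `f` positional-only, so a kwarg named `f` is collected
    # into **kwargs and forwarded to the wrapped function instead of clashing.
    return f(*args, **kwargs)


def myfun(x, f):
    return x * f


try:
    call_wrapped_buggy(myfun, 1.0, f=2.0)
except TypeError as e:
    print(e)  # call_wrapped_buggy() got multiple values for argument 'f'

print(call_wrapped_fixed(myfun, 1.0, f=2.0))  # 2.0
```

Assuming this is the cause, passing the argument positionally (`jax.jit(myfun)(1.0, 2.0)`) or renaming the parameter should sidestep the error for now, since no keyword named `f` then reaches the wrapper.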
This only started happening in 0.4.36; I think the offending commit is 1c9b23c.
System info (python version, jaxlib version, accelerator, etc.)
```
jax:    0.4.36
jaxlib: 0.4.36
numpy:  1.24.4
python: 3.10.11 (main, May 16 2023, 00:28:57) [GCC 11.2.0]
device info: cpu-8, 8 local devices
process_count: 1
platform: uname_result(system='Linux', node='Discovery', release='5.15.0-125-generic', version='#135~20.04.1-Ubuntu SMP Mon Oct 7 13:56:22 UTC 2024', machine='x86_64')
```