
Jitted functions can't take keyword arguments named f in v0.4.36 #25329

Closed
@f0uriest

Description

For example:

```python
import jax

def myfun(x, f):
    return x*f

jax.jit(myfun)(1.0, f=2.0)
```

This raises:

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[6], line 4
      1 def myfun(x, f):
      2     return x*f
----> 4 jax.jit(myfun)(1.0, f=2.0)

    [... skipping hidden 11 frame]

File ~/miniconda3/envs/desc/lib/python3.10/site-packages/jax/_src/api_util.py:74, in flatten_fun(f, store, in_tree, *args_flat)
     71 @lu.transformation_with_aux2
     72 def flatten_fun(f, store, in_tree, *args_flat):
     73   py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
---> 74   ans = f(*py_args, **py_kwargs)
     75   ans, out_tree = tree_flatten(ans)
     76   store.store(out_tree)

TypeError: result_paths() got multiple values for argument 'f'
```

This only started happening in 0.4.36; I think the offending commit is 1c9b23c.
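The last line of the traceback hints at the mechanism: `flatten_fun` forwards the user's keyword arguments with `f(*py_args, **py_kwargs)`, and the wrapper named in the error, `result_paths`, apparently has its own parameter called `f` that is already bound positionally, so the user's `f=2.0` collides with it. The pure-Python sketch below reproduces that kind of collision and shows one way to avoid it (a positional-only parameter via `/`, Python 3.8+). Both `wrap_colliding` and `wrap_fixed` are hypothetical stand-ins, not JAX's actual code, and this is only an assumption about how the fix could look.

```python
# Hypothetical stand-ins for an internal wrapper; NOT JAX's actual code.


def wrap_colliding(f, *args, **kwargs):
    # `f` is an ordinary parameter, so a caller keyword named `f`
    # tries to bind to it a second time and raises TypeError.
    return f(*args, **kwargs)


def wrap_fixed(f, /, *args, **kwargs):
    # `/` makes `f` positional-only, so a caller keyword named `f`
    # is collected into **kwargs and forwarded untouched.
    return f(*args, **kwargs)


def myfun(x, f):
    return x * f


try:
    wrap_colliding(myfun, 1.0, f=2.0)
except TypeError as e:
    print("colliding:", e)

print("fixed:", wrap_fixed(myfun, 1.0, f=2.0))
```

With the positional-only form, the wrapper's own parameter name no longer matters to callers, so any keyword name the user picks passes through.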

System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.4.36
jaxlib: 0.4.36
numpy:  1.24.4
python: 3.10.11 (main, May 16 2023, 00:28:57) [GCC 11.2.0]
device info: cpu-8, 8 local devices"
process_count: 1
platform: uname_result(system='Linux', node='Discovery', release='5.15.0-125-generic', version='#135~20.04.1-Ubuntu SMP Mon Oct 7 13:56:22 UTC 2024', machine='x86_64')
```
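Until this is fixed, one possible workaround (an assumption based on the traceback, not an official recommendation) is to pass `f` positionally. Positional arguments are forwarded through `*args`, so they never try to bind by name to the wrapper's own `f` parameter. A minimal sketch:

```python
import jax


def myfun(x, f):
    return x * f


jitted = jax.jit(myfun)

# Workaround sketch: pass `f` positionally instead of as the keyword `f=2.0`,
# so it never reaches the internal wrapper as a keyword argument.
print(jitted(1.0, 2.0))
```

Renaming the parameter is another option, though which names are safe depends on the internal wrapper's signature, which the traceback does not show in full.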

Labels: bug (Something isn't working)