Hello,
I tried importing stories after installing it with pip install stories-jax in a new conda environment (I tried several times on several machines). The import fails because KeyArray cannot be found in jax._src.random. This looks like a JAX version issue.
import stories
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
Cell In[3], line 1
----> 1 import stories
File /opt/conda/lib/python3.12/site-packages/stories/__init__.py:1
----> 1 from .spacetime import SpaceTime
File /opt/conda/lib/python3.12/site-packages/stories/spacetime.py:13
11 from flax.training.early_stopping import EarlyStopping
12 from jax.random import PRNGKey
---> 13 from jax._src.random import KeyArray
14 import uuid
15 from orbax.checkpoint.args import StandardRestore
ImportError: cannot import name 'KeyArray' from 'jax._src.random' (/opt/conda/lib/python3.12/site-packages/jax/_src/random.py)
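To confirm on a given machine whether the name is really missing, a small check like the one below may help. It is only a diagnostic sketch: has_name is a hypothetical helper written for this report, not part of stories or JAX, and it merely tests whether a module exposes an attribute.

```python
import importlib
from importlib.metadata import PackageNotFoundError, version


def has_name(module_name: str, attr: str) -> bool:
    """Return True if `module_name` can be imported and exposes `attr`.

    Hypothetical helper for this report; not part of stories or JAX.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # The module itself is missing or fails to import.
        return False
    return hasattr(module, attr)


if __name__ == "__main__":
    try:
        print("installed jax version:", version("jax"))
    except PackageNotFoundError:
        print("jax is not installed in this environment")
    print("jax._src.random.KeyArray present:",
          has_name("jax._src.random", "KeyArray"))
```

If the last line prints False, the installed JAX no longer provides the name that stories/spacetime.py imports.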
According to the JAX changelog (https://github.com/jax-ml/jax/blob/main/CHANGELOG.md#jax-0424-feb-6-2024), KeyArray was removed in JAX 0.4.24 (Feb 6, 2024):
A number of previously deprecated functions have been removed, following a standard 3+ month deprecation cycle (see api-compatibility). This includes:
[...]
- From jax.random: PRNGKeyArray, KeyArray, default_prng_impl, threefry_2x32, threefry2x32_key, threefry2x32_p, rbg_key, and unsafe_rbg_key.
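One possible fix, sketched here under the assumption that stories only uses KeyArray as a type annotation, would be to alias it to jax.Array, which the JAX deprecation notes recommend for annotating PRNG keys. The split_key function is a hypothetical stand-in for code in spacetime.py, which I have not checked in detail.

```python
import jax
from jax import Array as KeyArray  # assumed replacement for jax._src.random.KeyArray


def split_key(key: KeyArray, num: int = 2) -> KeyArray:
    """Hypothetical example of a function annotated with the alias."""
    return jax.random.split(key, num)


key = jax.random.PRNGKey(0)
subkeys = split_key(key, 3)  # three independent subkeys derived from `key`
```

Alternatively, pinning JAX to a release older than 0.4.24 should avoid the error, although I have not verified that the other dependencies (flax, orbax) accept such a pin.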
Best Regards,
Marco Uderzo