
Basic ImportError: cannot import name 'KeyArray' from 'jax._src.random' #1

@marcouderzo


Hello,
I installed stories with `pip install stories-jax` in a fresh conda environment (tried multiple times on multiple machines), but `import stories` fails. It seems `KeyArray` can no longer be found in `jax._src.random`. Maybe this is caused by a JAX version issue?

```python
import stories
```

```
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[3], line 1
----> 1 import stories

File /opt/conda/lib/python3.12/site-packages/stories/__init__.py:1
----> 1 from .spacetime import SpaceTime

File /opt/conda/lib/python3.12/site-packages/stories/spacetime.py:13
     11 from flax.training.early_stopping import EarlyStopping
     12 from jax.random import PRNGKey
---> 13 from jax._src.random import KeyArray
     14 import uuid
     15 from orbax.checkpoint.args import StandardRestore

ImportError: cannot import name 'KeyArray' from 'jax._src.random' (/opt/conda/lib/python3.12/site-packages/jax/_src/random.py)
```
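A quick way to confirm that this is a JAX version mismatch is to print the installed JAX version and check whether the private module still defines `KeyArray`. This is only a diagnostic sketch; `has_private_keyarray` is a hypothetical helper name, and it inspects `jax._src.random`, which is private and may change between releases.

```python
import importlib

import jax


def has_private_keyarray() -> bool:
    """Return True if jax._src.random still defines KeyArray.

    Hypothetical helper: jax._src is private JAX internals, and KeyArray
    is absent there in newer releases (as the traceback above shows).
    """
    module = importlib.import_module("jax._src.random")
    return hasattr(module, "KeyArray")


print("jax version:", jax.__version__)
print("KeyArray available:", has_private_keyarray())
```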

From the JAX changelog for jax 0.4.24, Feb 6, 2024 (https://github.com/jax-ml/jax/blob/main/CHANGELOG.md#jax-0424-feb-6-2024):
> A number of previously deprecated functions have been removed, following a standard 3+ month deprecation cycle (see API compatibility). This includes:
> [...]
>
> - From `jax.random`: `PRNGKeyArray`, `KeyArray`, `default_prng_impl`, `threefry_2x32`, `threefry2x32_key`, `threefry2x32_p`, `rbg_key`, and `unsafe_rbg_key`.
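A possible workaround until the package is updated is to make the import in `spacetime.py` tolerant of both old and new JAX, falling back to `jax.Array`, which covers both raw `uint32` keys from `jax.random.PRNGKey` and typed keys from `jax.random.key`. This is a sketch, not the stories maintainers' fix, and it assumes the package only uses `KeyArray` as a type annotation. The helpers `is_typed_key` and `split_key` are hypothetical names for illustration; `is_typed_key` uses `jax.dtypes.issubdtype(..., jax.dtypes.prng_key)` to tell typed keys apart from raw ones at runtime.

```python
import jax
from jax.random import PRNGKey

try:
    # Older JAX releases still define KeyArray in this private module.
    from jax._src.random import KeyArray
except ImportError:
    # Newer releases removed it (see the changelog entry above).
    # Assumption: KeyArray is only used for annotations, so jax.Array,
    # which covers raw uint32 keys and typed keys, is a safe stand-in.
    KeyArray = jax.Array


def is_typed_key(x) -> bool:
    """Hypothetical helper: True for typed keys from jax.random.key,
    False for raw uint32 keys from jax.random.PRNGKey."""
    return jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key)


def split_key(key: KeyArray) -> tuple[KeyArray, KeyArray]:
    """Hypothetical helper: split one key into two; works for raw and typed keys."""
    k1, k2 = jax.random.split(key)
    return k1, k2


raw_key = PRNGKey(0)           # raw key: uint32 array of shape (2,)
typed_key = jax.random.key(0)  # typed key: scalar with a key<...> dtype
print(is_typed_key(raw_key), is_typed_key(typed_key))
```

Annotating with `jax.Array` rather than importing from `jax._src` also avoids depending on private JAX internals, which can disappear without a deprecation cycle.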

Best Regards,
Marco Uderzo
