diff --git a/.gitignore b/.gitignore
index 4e62578a6..7a4e033c2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -154,3 +154,4 @@ cython_debug/
jumanji_env/
**/outputs/
*.xml
+.sokoban_cache/
diff --git a/README.md b/README.md
index e70ba830a..4b2fb4cdf 100644
--- a/README.md
+++ b/README.md
@@ -47,6 +47,7 @@

+
@@ -112,6 +113,7 @@ problems.
| 📬 TSP (Travelling Salesman Problem) | Routing | `TSP-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/tsp/) | [doc](https://instadeepai.github.io/jumanji/environments/tsp/) |
| Multi Minimum Spanning Tree Problem | Routing | `MMST-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/mmst) | [doc](https://instadeepai.github.io/jumanji/environments/mmst/) |
| ᗧ•••ᗣ•• PacMan | Routing | `PacMan-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/pacman/) | [doc](https://instadeepai.github.io/jumanji/environments/pacman/)
+| 👾 Sokoban | Routing | `Sokoban-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/sokoban/) | [doc](https://instadeepai.github.io/jumanji/environments/sokoban/) |
Installation 🎬
diff --git a/docs/api/environments/sokoban.md b/docs/api/environments/sokoban.md
new file mode 100644
index 000000000..1554480fc
--- /dev/null
+++ b/docs/api/environments/sokoban.md
@@ -0,0 +1,9 @@
+::: jumanji.environments.routing.sokoban.env.Sokoban
+ selection:
+ members:
+ - __init__
+ - observation_spec
+ - action_spec
+ - reset
+ - step
+ - render
diff --git a/docs/env_anim/sokoban.gif b/docs/env_anim/sokoban.gif
new file mode 100644
index 000000000..fffcb85ff
Binary files /dev/null and b/docs/env_anim/sokoban.gif differ
diff --git a/docs/env_img/sokoban.png b/docs/env_img/sokoban.png
new file mode 100644
index 000000000..fbc25a80e
Binary files /dev/null and b/docs/env_img/sokoban.png differ
diff --git a/docs/environments/sokoban.md b/docs/environments/sokoban.md
new file mode 100644
index 000000000..41ca8be4b
--- /dev/null
+++ b/docs/environments/sokoban.md
@@ -0,0 +1,76 @@
+# Sokoban Environment 👾
+
+
+
+
+
+This is a JAX implementation of the _Sokoban_ puzzle, a box-pushing environment in which the agent's goal is to push every box onto a target. This version follows the rules from the DeepMind paper on [Imagination Augmented Agents](https://arxiv.org/abs/1707.06203), with levels from the Boxoban dataset of [Guez et al., 2018](https://github.com/deepmind/boxoban-levels)[[1]](#ref1). The graphical assets were taken from [gym-sokoban](https://github.com/mpSchrader/gym-sokoban) by Schrader, a library implementing many variants of the game in the OpenAI Gym framework [[2]](#ref2).
+
+## Observation
+
+- `grid`: An Array (uint8) of shape `(10, 10, 2)`. The first channel is the variable grid (containing the movable objects: the agent and boxes) and the second is the fixed grid (containing the immovable objects: walls and targets).
+- `step_count`: An Array (int32) of shape `()`, representing the current number of steps in the episode.
+
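+A minimal sketch (assuming the registered `Sokoban-v0` id, listed under Registered Versions below) of splitting the observation into its two channels:
+
+```python
+import jax
+import jumanji
+
+env = jumanji.make("Sokoban-v0")
+key = jax.random.PRNGKey(0)
+state, timestep = jax.jit(env.reset)(key)
+variable_grid = timestep.observation.grid[..., 0]  # agent and boxes
+fixed_grid = timestep.observation.grid[..., 1]  # walls and targets
+```
+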
+## Object Encodings
+
+| Object | Encoding |
+|--------------|----------|
+| Empty Space | 0 |
+| Wall | 1 |
+| Target | 2 |
+| Agent | 3 |
+| Box | 4 |
+
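+Encodings `5` (agent on target) and `6` (box on target) never appear in the observation; they are used only in the combined grid built for rendering.
+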
+## Actions
+
+The agent's action space is an Array (int32) with potential values of `[0,1,2,3]` (corresponding to `[Up, Right, Down, Left]`). If the agent attempts to move into a wall, off the grid, or push a box into a wall, another box, or off the grid, the grid state remains unchanged; however, the step count is still incremented by one. Chained box pushes (pushing one box into another) are not allowed and result in no movement.
+
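+The actions map onto (row, column) deltas, as defined in the environment's constants:
+
+```python
+import jax.numpy as jnp
+
+# (row, col) deltas for [Up, Right, Down, Left]
+MOVES = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
+NOOP = -1  # moves with no effect are internally masked to NOOP
+```
+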
+## Reward
+
+The reward function comprises:
+- `-0.1` for each step taken in the environment.
+- `+1` for each box moved onto a target location and `-1` for each box moved off a target location.
+- `+10` upon successful placement of all four boxes on their targets.
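+
+For example, a step that pushes a box onto a target yields `1 - 0.1 = +0.9`, and the final step that places the fourth box yields `1 + 10 - 0.1 = +10.9`.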
+
+## Episode Termination
+
+The episode concludes when:
+- The step limit of 120 is reached.
+- All 4 boxes are placed on targets (i.e., the problem is solved).
+
+## Dataset
+
+The Boxoban dataset offers a collection of puzzle levels. Each level features four boxes and four targets. The dataset has three levels of difficulty: 'unfiltered', 'medium', and 'hard'.
+
+| Dataset Split | Number of Levels |
+|---------------|------------------|
+| Unfiltered (Training) | 900,000 |
+| Unfiltered (Validation) | 100,000 |
+| Unfiltered (Test) | 1,000 |
+| Medium (Training) | 450,000 |
+| Medium (Validation) | 50,000 |
+| Hard | 3,332 |
+
+
+The dataset generation procedure and more details can be found in Guez et al., 2018 [1].
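+
+A split can be selected through the `HuggingFaceDeepMindGenerator`, as in this sketch (dataset names follow the `<difficulty>-<split>` convention):
+
+```python
+from jumanji.environments import Sokoban
+from jumanji.environments.routing.sokoban.generator import (
+    HuggingFaceDeepMindGenerator,
+)
+
+env = Sokoban(
+    generator=HuggingFaceDeepMindGenerator(
+        dataset_name="medium-train",
+        proportion_of_files=1,
+    )
+)
+```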
+
+## Graphics
+
+| Type | Graphic |
+|------------------|----------------------------------------------------------------------------------------|
+| Wall |  |
+| Floor |  |
+| Target |  |
+| Box on Target |  |
+| Box Off Target |  |
+| Agent Off Target |  |
+| Agent On Target |  |
+
+## Registered Versions 📖
+
+- `Sokoban-v0`: Sokoban game with levels from the DeepMind Boxoban dataset (unfiltered, train split).
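+
+A registered version can be instantiated with `jumanji.make`:
+
+```python
+import jumanji
+
+env = jumanji.make("Sokoban-v0")
+```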
+
+## References
+[1] Guez, A., Mirza, M., Gregor, K., Kabra, R., Racaniere, S., Weber, T., Raposo, D., Santoro, A., Orseau, L., Eccles, T., Wayne, G., Silver, D., Lillicrap, T., Valdes, V. (2018). An Investigation of Model-Free Planning: Boxoban Levels. Available at [https://github.com/deepmind/boxoban-levels](https://github.com/deepmind/boxoban-levels)
+
+[2] Schrader, M. (2018). Gym-sokoban. Available at [https://github.com/mpSchrader/gym-sokoban](https://github.com/mpSchrader/gym-sokoban)
diff --git a/jumanji/__init__.py b/jumanji/__init__.py
index 11f15cb08..49b3fdcfd 100644
--- a/jumanji/__init__.py
+++ b/jumanji/__init__.py
@@ -73,7 +73,6 @@
kwargs={"generator": very_easy_sudoku_generator},
)
-
###
# Packing Environments
###
@@ -93,7 +92,6 @@
# Tetris - the game of tetris with a grid size of 10x10 and a time limit of 400.
register(id="Tetris-v0", entry_point="jumanji.environments:Tetris")
-
###
# Routing Environments
###
@@ -128,5 +126,7 @@
# TSP with 20 randomly generated cities and a dense reward function.
register(id="TSP-v1", entry_point="jumanji.environments:TSP")
+# Sokoban with the DeepMind Boxoban dataset generator
+register(id="Sokoban-v0", entry_point="jumanji.environments:Sokoban")
# Pacman - minimal version of Atarti Pacman game
register(id="PacMan-v0", entry_point="jumanji.environments:PacMan")
diff --git a/jumanji/environments/__init__.py b/jumanji/environments/__init__.py
index 444eaa616..239ef8f51 100644
--- a/jumanji/environments/__init__.py
+++ b/jumanji/environments/__init__.py
@@ -35,6 +35,7 @@
pac_man,
robot_warehouse,
snake,
+ sokoban,
tsp,
)
from jumanji.environments.routing.cleaner.env import Cleaner
@@ -46,6 +47,7 @@
from jumanji.environments.routing.pac_man.env import PacMan
from jumanji.environments.routing.robot_warehouse.env import RobotWarehouse
from jumanji.environments.routing.snake.env import Snake
+from jumanji.environments.routing.sokoban.env import Sokoban
from jumanji.environments.routing.tsp.env import TSP
diff --git a/jumanji/environments/logic/game_2048/env.py b/jumanji/environments/logic/game_2048/env.py
index ba8115c16..1f0a91a6f 100644
--- a/jumanji/environments/logic/game_2048/env.py
+++ b/jumanji/environments/logic/game_2048/env.py
@@ -66,7 +66,7 @@ class Game2048(Environment[State]):
```python
from jumanji.environments import Game2048
env = Game2048()
- key = jax.random.key(0)
+ key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec().generate_value()
diff --git a/jumanji/environments/logic/graph_coloring/env.py b/jumanji/environments/logic/graph_coloring/env.py
index 36970d7da..f20f05a88 100644
--- a/jumanji/environments/logic/graph_coloring/env.py
+++ b/jumanji/environments/logic/graph_coloring/env.py
@@ -73,7 +73,7 @@ class GraphColoring(Environment[State]):
```python
from jumanji.environments import GraphColoring
env = GraphColoring()
- key = jax.random.key(0)
+ key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec().generate_value()
diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py
index a5d9e5f01..641ee48f9 100644
--- a/jumanji/environments/logic/minesweeper/env.py
+++ b/jumanji/environments/logic/minesweeper/env.py
@@ -78,7 +78,7 @@ class Minesweeper(Environment[State]):
```python
from jumanji.environments import Minesweeper
env = Minesweeper()
- key = jax.random.key(0)
+ key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec().generate_value()
diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py
index bd01f9809..84a2dff44 100644
--- a/jumanji/environments/logic/rubiks_cube/env.py
+++ b/jumanji/environments/logic/rubiks_cube/env.py
@@ -72,7 +72,7 @@ class RubiksCube(Environment[State]):
```python
from jumanji.environments import RubiksCube
env = RubiksCube()
- key = jax.random.key(0)
+ key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec().generate_value()
diff --git a/jumanji/environments/logic/sudoku/env.py b/jumanji/environments/logic/sudoku/env.py
index 629da1d1a..8e8f5ca12 100644
--- a/jumanji/environments/logic/sudoku/env.py
+++ b/jumanji/environments/logic/sudoku/env.py
@@ -63,7 +63,7 @@ class Sudoku(Environment[State]):
```python
from jumanji.environments import Sudoku
env = Sudoku()
- key = jax.random.key(0)
+ key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec().generate_value()
diff --git a/jumanji/environments/packing/bin_pack/env.py b/jumanji/environments/packing/bin_pack/env.py
index 8cd6aa259..c3127c07e 100644
--- a/jumanji/environments/packing/bin_pack/env.py
+++ b/jumanji/environments/packing/bin_pack/env.py
@@ -103,7 +103,7 @@ class BinPack(Environment[State]):
```python
from jumanji.environments import BinPack
env = BinPack()
- key = jax.random.key(0)
+ key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec().generate_value()
diff --git a/jumanji/environments/packing/job_shop/env.py b/jumanji/environments/packing/job_shop/env.py
index 1af25af87..0e2421524 100644
--- a/jumanji/environments/packing/job_shop/env.py
+++ b/jumanji/environments/packing/job_shop/env.py
@@ -80,7 +80,7 @@ class JobShop(Environment[State]):
```python
from jumanji.environments import JobShop
env = JobShop()
- key = jax.random.key(0)
+ key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec().generate_value()
diff --git a/jumanji/environments/packing/knapsack/env.py b/jumanji/environments/packing/knapsack/env.py
index ce6c3838c..573b83b32 100644
--- a/jumanji/environments/packing/knapsack/env.py
+++ b/jumanji/environments/packing/knapsack/env.py
@@ -73,7 +73,7 @@ class Knapsack(Environment[State]):
```python
from jumanji.environments import Knapsack
env = Knapsack()
- key = jax.random.key(0)
+ key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec().generate_value()
diff --git a/jumanji/environments/packing/tetris/env.py b/jumanji/environments/packing/tetris/env.py
index 37e608663..8223225f1 100644
--- a/jumanji/environments/packing/tetris/env.py
+++ b/jumanji/environments/packing/tetris/env.py
@@ -66,7 +66,7 @@ class Tetris(Environment[State]):
```python
from jumanji.environments import Tetris
env = Tetris()
- key = jax.random.key(0)
+ key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec().generate_value()
diff --git a/jumanji/environments/routing/cleaner/env.py b/jumanji/environments/routing/cleaner/env.py
index f1e1410c1..6377bcb8f 100644
--- a/jumanji/environments/routing/cleaner/env.py
+++ b/jumanji/environments/routing/cleaner/env.py
@@ -71,7 +71,7 @@ class Cleaner(Environment[State]):
```python
from jumanji.environments import Cleaner
env = Cleaner()
- key = jax.random.key(0)
+ key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec().generate_value()
diff --git a/jumanji/environments/routing/connector/env.py b/jumanji/environments/routing/connector/env.py
index 7a139f31b..e76ba9da7 100644
--- a/jumanji/environments/routing/connector/env.py
+++ b/jumanji/environments/routing/connector/env.py
@@ -85,7 +85,7 @@ class Connector(Environment[State]):
```python
from jumanji.environments import Connector
env = Connector()
- key = jax.random.key(0)
+ key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec().generate_value()
diff --git a/jumanji/environments/routing/cvrp/env.py b/jumanji/environments/routing/cvrp/env.py
index 6af507519..ca6f76920 100644
--- a/jumanji/environments/routing/cvrp/env.py
+++ b/jumanji/environments/routing/cvrp/env.py
@@ -86,7 +86,7 @@ class CVRP(Environment[State]):
```python
from jumanji.environments import CVRP
env = CVRP()
- key = jax.random.key(0)
+ key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec().generate_value()
diff --git a/jumanji/environments/routing/maze/env.py b/jumanji/environments/routing/maze/env.py
index 4c56eaedf..d0045144c 100644
--- a/jumanji/environments/routing/maze/env.py
+++ b/jumanji/environments/routing/maze/env.py
@@ -68,7 +68,7 @@ class Maze(Environment[State]):
```python
from jumanji.environments import Maze
env = Maze()
- key = jax.random.key(0)
+ key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec().generate_value()
diff --git a/jumanji/environments/routing/mmst/env.py b/jumanji/environments/routing/mmst/env.py
index cac71ffe7..485fa87ad 100644
--- a/jumanji/environments/routing/mmst/env.py
+++ b/jumanji/environments/routing/mmst/env.py
@@ -117,6 +117,17 @@ class MMST(Environment[State]):
- INVALID_CHOICE = -1
- INVALID_TIE_BREAK = -2
- INVALID_ALREADY_TRAVERSED = -3
+
+ ```python
+ from jumanji.environments import MMST
+ env = MMST()
+ key = jax.random.PRNGKey(0)
+ state, timestep = jax.jit(env.reset)(key)
+ env.render(state)
+ action = env.action_spec().generate_value()
+ state, timestep = jax.jit(env.step)(state, action)
+ env.render(state)
+ ```
"""
def __init__(
diff --git a/jumanji/environments/routing/multi_cvrp/env.py b/jumanji/environments/routing/multi_cvrp/env.py
index 23688829e..01aa933ce 100644
--- a/jumanji/environments/routing/multi_cvrp/env.py
+++ b/jumanji/environments/routing/multi_cvrp/env.py
@@ -64,6 +64,17 @@ class MultiCVRP(Environment[State]):
[1] Zhang et al. (2020). "Multi-Vehicle Routing Problems with Soft Time Windows: A
Multi-Agent Reinforcement Learning Approach".
+
+ ```python
+ from jumanji.environments import MultiCVRP
+ env = MultiCVRP()
+ key = jax.random.PRNGKey(0)
+ state, timestep = jax.jit(env.reset)(key)
+ env.render(state)
+ action = env.action_spec().generate_value()
+ state, timestep = jax.jit(env.step)(state, action)
+ env.render(state)
+ ```
"""
def __init__(
diff --git a/jumanji/environments/routing/snake/env.py b/jumanji/environments/routing/snake/env.py
index 0231c3a94..fafd0cd20 100644
--- a/jumanji/environments/routing/snake/env.py
+++ b/jumanji/environments/routing/snake/env.py
@@ -81,7 +81,7 @@ class Snake(Environment[State]):
```python
from jumanji.environments import Snake
env = Snake()
- key = jax.random.key(0)
+ key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec().generate_value()
diff --git a/jumanji/environments/routing/sokoban/__init__.py b/jumanji/environments/routing/sokoban/__init__.py
new file mode 100644
index 000000000..ff8aa3a7c
--- /dev/null
+++ b/jumanji/environments/routing/sokoban/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from jumanji.environments.routing.sokoban.env import Sokoban
+from jumanji.environments.routing.sokoban.types import Observation, State
diff --git a/jumanji/environments/routing/sokoban/constants.py b/jumanji/environments/routing/sokoban/constants.py
new file mode 100644
index 000000000..782395e31
--- /dev/null
+++ b/jumanji/environments/routing/sokoban/constants.py
@@ -0,0 +1,37 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import jax.numpy as jnp
+
+# Translating actions to coordinate changes
+MOVES = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
+NOOP = -1
+
+# Object encodings
+EMPTY = 0
+WALL = 1
+TARGET = 2
+AGENT = 3
+BOX = 4
+TARGET_AGENT = 5
+TARGET_BOX = 6
+
+# Environment Variables
+N_BOXES = 4
+GRID_SIZE = 10
+
+# Reward Function
+LEVEL_COMPLETE_BONUS = 10
+SINGLE_BOX_BONUS = 1
+STEP_BONUS = -0.1
diff --git a/jumanji/environments/routing/sokoban/env.py b/jumanji/environments/routing/sokoban/env.py
new file mode 100644
index 000000000..c56fcbf89
--- /dev/null
+++ b/jumanji/environments/routing/sokoban/env.py
@@ -0,0 +1,575 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Optional, Sequence, Tuple
+
+import chex
+import jax
+import jax.numpy as jnp
+import matplotlib.animation
+
+from jumanji import specs
+from jumanji.env import Environment
+from jumanji.environments.routing.sokoban.constants import (
+ AGENT,
+ BOX,
+ EMPTY,
+ GRID_SIZE,
+ MOVES,
+ N_BOXES,
+ NOOP,
+ TARGET,
+ TARGET_AGENT,
+ TARGET_BOX,
+ WALL,
+)
+from jumanji.environments.routing.sokoban.generator import (
+ Generator,
+ HuggingFaceDeepMindGenerator,
+)
+from jumanji.environments.routing.sokoban.reward import DenseReward, RewardFn
+from jumanji.environments.routing.sokoban.types import Observation, State
+from jumanji.environments.routing.sokoban.viewer import BoxViewer
+from jumanji.types import TimeStep, restart, termination, transition
+from jumanji.viewer import Viewer
+
+
+class Sokoban(Environment[State]):
+ """A JAX implementation of the 'Sokoban' game from deepmind.
+
+ - observation: `Observation`
+ - grid: jax array (uint8) of shape (num_rows, num_cols, 2)
+ Array that includes information about the agent, boxes, and
+ targets in the game.
+ - step_count: jax array (int32) of shape ()
+ current number of steps in the episode.
+
+ - action: jax array (int32) of shape ()
+ [0,1,2,3] -> [Up, Right, Down, Left].
+
+ - reward: jax array (float) of shape ()
+ A reward of +1.0 is given for each box pushed onto a target,
+ -1.0 for each box pushed off a target, and -0.1 per timestep.
+ An additional +10.0 is awarded when all boxes are on targets.
+
+ - episode termination:
+ - if the time limit is reached.
+ - if all boxes are on targets.
+
+ - state: `State`
+ - key: jax array (uint32) of shape (2,) used for auto-reset
+ - fixed_grid: jax array (uint8) of shape (num_rows, num_cols)
+ array indicating the walls and targets in the level.
+ - variable_grid: jax array (uint8) of shape (num_rows, num_cols)
+ array indicating the current location of the agent and boxes.
+ - agent_location: jax array (int32) of shape (2,)
+ the agent's current location.
+ - step_count: jax array (int32) of shape ()
+ current number of steps in the episode.
+
+ ```python
+ from jumanji.environments import Sokoban
+ from jumanji.environments.routing.sokoban.generator import (
+ HuggingFaceDeepMindGenerator,
+ )
+
+ env_train = Sokoban(
+ generator=HuggingFaceDeepMindGenerator(
+ dataset_name="unfiltered-train",
+ proportion_of_files=1,
+ )
+ )
+
+ env_test = Sokoban(
+ generator=HuggingFaceDeepMindGenerator(
+ dataset_name="unfiltered-test",
+ proportion_of_files=1,
+ )
+ )
+
+ # Train...
+
+ key_train = jax.random.PRNGKey(0)
+ state, timestep = jax.jit(env_train.reset)(key_train)
+ env_train.render(state)
+ action = env_train.action_spec().generate_value()
+ state, timestep = jax.jit(env_train.step)(state, action)
+ env_train.render(state)
+ ```
+ """
+
+ def __init__(
+ self,
+ generator: Optional[Generator] = None,
+ reward_fn: Optional[RewardFn] = None,
+ viewer: Optional[Viewer] = None,
+ time_limit: int = 120,
+ ) -> None:
+ """
+ Instantiates a `Sokoban` environment with a specific generator,
+ time limit, and viewer.
+
+ Args:
+ generator: `Generator` whose `__call__` instantiates an environment
+ instance (an initial state). Implemented options are [`ToyGenerator`,
+ `DeepMindGenerator`, and `HuggingFaceDeepMindGenerator`].
+ Defaults to `HuggingFaceDeepMindGenerator` with
+ `dataset_name="unfiltered-train", proportion_of_files=1`.
+ reward_fn: `RewardFn` used to compute the reward.
+ Defaults to `DenseReward`.
+ viewer: `Viewer` object, used to render the environment.
+ If not provided, defaults to `BoxViewer`.
+ time_limit: int, max steps for the environment, defaults to 120.
+ """
+
+ self.num_rows = GRID_SIZE
+ self.num_cols = GRID_SIZE
+ self.shape = (self.num_rows, self.num_cols)
+ self.time_limit = time_limit
+
+ self.generator = generator or HuggingFaceDeepMindGenerator(
+ "unfiltered-train",
+ proportion_of_files=1,
+ )
+
+ self._viewer = viewer or BoxViewer(
+ name="Sokoban",
+ grid_combine=self.grid_combine,
+ )
+ self.reward_fn = reward_fn or DenseReward()
+
+ def __repr__(self) -> str:
+ """
+ Returns a printable representation of the Sokoban environment.
+
+ Returns:
+ str: A string representation of the Sokoban environment.
+ """
+ return "\n".join(
+ [
+ "Bokoban environment:",
+ f" - num_rows: {self.num_rows}",
+ f" - num_cols: {self.num_cols}",
+ f" - time_limit: {self.time_limit}",
+ f" - generator: {self.generator}",
+ ]
+ )
+
+ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
+ """
+ Resets the environment by calling the instance generator for a
+ new instance.
+
+ Args:
+ key: random key used to sample new Sokoban problem.
+
+ Returns:
+ state: `State` object corresponding to the new state of the
+ environment after a reset.
+ timestep: `TimeStep` object corresponding to the first timestep
+ returned by the environment after a reset.
+ """
+
+ generator_key, key = jax.random.split(key)
+
+ state = self.generator(generator_key)
+
+ timestep = restart(
+ self._state_to_observation(state),
+ extras=self._get_extras(state),
+ )
+
+ return state, timestep
+
+ def step(
+ self, state: State, action: chex.Array
+ ) -> Tuple[State, TimeStep[Observation]]:
+ """
+ Executes one timestep of the environment's dynamics.
+
+ Args:
+ state: 'State' object representing the current state of the
+ environment.
+ action: Array (int32) of shape ().
+ - 0: move up.
+ - 1: move right.
+ - 2: move down.
+ - 3: move left.
+
+ Returns:
+ state, timestep: next state of the environment and timestep to be
+ observed.
+ """
+
+ # switch to noop if action will have no impact on variable grid
+ action = self.detect_noop_action(
+ state.variable_grid, state.fixed_grid, action, state.agent_location
+ )
+
+ next_variable_grid, next_agent_location = jax.lax.cond(
+ jnp.all(action == NOOP),
+ lambda: (state.variable_grid, state.agent_location),
+ lambda: self.move_agent(state.variable_grid, action, state.agent_location),
+ )
+
+ next_state = State(
+ key=state.key,
+ fixed_grid=state.fixed_grid,
+ variable_grid=next_variable_grid,
+ agent_location=next_agent_location,
+ step_count=state.step_count + 1,
+ )
+
+ target_reached = self.level_complete(next_state)
+ time_limit_exceeded = next_state.step_count >= self.time_limit
+
+ done = jnp.logical_or(target_reached, time_limit_exceeded)
+
+ reward = jnp.asarray(self.reward_fn(state, action, next_state), float)
+
+ observation = self._state_to_observation(next_state)
+
+ extras = self._get_extras(next_state)
+
+ timestep = jax.lax.cond(
+ done,
+ lambda: termination(
+ reward=reward,
+ observation=observation,
+ extras=extras,
+ ),
+ lambda: transition(
+ reward=reward,
+ observation=observation,
+ extras=extras,
+ ),
+ )
+
+ return next_state, timestep
+
+ def observation_spec(self) -> specs.Spec[Observation]:
+ """
+ Returns the specifications of the observation of the `Sokoban`
+ environment.
+
+ Returns:
+ specs.Spec[Observation]: The specifications of the observations.
+ """
+ grid = specs.BoundedArray(
+ shape=(self.num_rows, self.num_cols, 2),
+ dtype=jnp.uint8,
+ minimum=0,
+ maximum=4,
+ name="grid",
+ )
+ step_count = specs.Array((), jnp.int32, "step_count")
+ return specs.Spec(
+ Observation,
+ "ObservationSpec",
+ grid=grid,
+ step_count=step_count,
+ )
+
+ def action_spec(self) -> specs.DiscreteArray:
+ """
+ Returns the action specification for the Sokoban environment.
+ There are 4 actions: [0,1,2,3] -> [Up, Right, Down, Left].
+
+ Returns:
+ specs.DiscreteArray: Discrete action specifications.
+ """
+ return specs.DiscreteArray(4, name="action", dtype=jnp.int32)
+
+ def _state_to_observation(self, state: State) -> Observation:
+ """Maps an environment state to an observation.
+
+ Args:
+ state: `State` object containing the dynamics of the environment.
+
+ Returns:
+ The observation derived from the state.
+ """
+
+ total_grid = jnp.stack([state.variable_grid, state.fixed_grid], axis=-1)
+
+ return Observation(
+ grid=total_grid,
+ step_count=state.step_count,
+ )
+
+ def _get_extras(self, state: State) -> Dict:
+ """
+ Computes extras metrics to be returned within the timestep.
+
+ Args:
+ state: 'State' object representing the current state of the
+ environment.
+
+ Returns:
+ extras: Dict object containing current proportion of boxes on
+ targets and whether the problem is solved.
+ """
+ num_boxes_on_targets = self.reward_fn.count_targets(state)
+ total_num_boxes = N_BOXES
+ extras = {
+ "prop_correct_boxes": num_boxes_on_targets / total_num_boxes,
+ "solved": num_boxes_on_targets == 4,
+ }
+ return extras
+
+ def grid_combine(
+ self, variable_grid: chex.Array, fixed_grid: chex.Array
+ ) -> chex.Array:
+ """
+ Combines the variable and fixed grids into a single grid used for
+ rendering the current Sokoban state. Handles the two possible
+ overlaps of fixed and variable entries (an agent on a target or a
+ box on a target) by introducing two additional encodings.
+
+ Args:
+ variable_grid: Array (uint8) of shape (num_rows, num_cols).
+ fixed_grid: Array (uint8) of shape (num_rows, num_cols).
+
+ Returns:
+ single_grid: Array (uint8) of shape (num_rows, num_cols).
+ """
+
+ mask_target_agent = jnp.logical_and(
+ fixed_grid == TARGET,
+ variable_grid == AGENT,
+ )
+
+ mask_target_box = jnp.logical_and(
+ fixed_grid == TARGET,
+ variable_grid == BOX,
+ )
+
+ single_grid = jnp.where(
+ mask_target_agent,
+ TARGET_AGENT,
+ jnp.where(
+ mask_target_box,
+ TARGET_BOX,
+ jnp.maximum(variable_grid, fixed_grid),
+ ),
+ ).astype(jnp.uint8)
+
+ return single_grid
+
+ def level_complete(self, state: State) -> chex.Array:
+ """
+ Checks if the sokoban level is complete.
+
+ Args:
+ state: `State` object representing the current state of the environment.
+
+ Returns:
+ complete: Boolean indicating whether the level is complete
+ or not.
+ """
+ return self.reward_fn.count_targets(state) == N_BOXES
+
+ def check_space(
+ self,
+ grid: chex.Array,
+ location: chex.Array,
+ value: int,
+ ) -> chex.Array:
+ """
+ Checks if a specific location in the grid contains a given value.
+
+ Args:
+ grid: Array (uint8) shape (num_rows, num_cols) The grid to check.
+ location: Tuple size 2 of Array (int32) shape () containing the row
+ and column coordinates of the location to check in the grid.
+ value: int The value to look for.
+
+ Returns:
+ present: Array (bool) shape () indicating whether the location
+ in the grid contains the given value or not.
+ """
+
+ return grid[tuple(location)] == value
+
+ def in_grid(self, coordinates: chex.Array) -> chex.Array:
+ """
+ Checks if given coordinates are within the grid size.
+
+ Args:
+ coordinates: Array (int32) of shape (2,), the coordinates to
+ check.
+ Returns:
+ in_grid: Array (bool) shape () Boolean indicating whether the
+ coordinates are within the grid.
+ """
+ return jnp.all((0 <= coordinates) & (coordinates < GRID_SIZE))
+
+ def detect_noop_action(
+ self,
+ variable_grid: chex.Array,
+ fixed_grid: chex.Array,
+ action: chex.Array,
+ agent_location: chex.Array,
+ ) -> chex.Array:
+ """
+ Masks the action to NOOP (-1) if it would have no effect on the
+ variable grid: either the agent's destination square is blocked,
+ or it contains a box whose own destination square is invalid.
+
+ Args:
+ variable_grid: Array (uint8) shape (num_rows, num_cols).
+ fixed_grid: Array (uint8) shape (num_rows, num_cols).
+ action: Array (int32) shape () The action to check.
+ agent_location: Array (int32) shape (2,) The agent's location.
+
+ Returns:
+ updated_action: Array (int32) shape () The updated action after
+ detecting noop action.
+ """
+
+ new_location = agent_location + MOVES[action].squeeze()
+
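+ # the move is invalid if the agent's destination is a wall or lies
+ # outside the grid; if the destination holds a box, additionally
+ # check whether the box can be pushed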
+ invalid_destination = self.check_space(
+ fixed_grid, new_location, WALL
+ ) | ~self.in_grid(new_location)
+
+ updated_action = jax.lax.select(
+ invalid_destination,
+ jnp.full(shape=(), fill_value=NOOP, dtype=jnp.int32),
+ jax.lax.select(
+ self.check_space(variable_grid, new_location, BOX),
+ self.update_box_push_action(
+ fixed_grid,
+ variable_grid,
+ new_location,
+ action,
+ ),
+ action,
+ ),
+ )
+
+ return updated_action
+
+ def update_box_push_action(
+ self,
+ fixed_grid: chex.Array,
+ variable_grid: chex.Array,
+ new_location: chex.Array,
+ action: chex.Array,
+ ) -> chex.Array:
+ """
+ Masks the action to NOOP (-1) if pushing the box is not a valid
+ move: if the box would be pushed off the grid, or into a wall or
+ another box.
+
+ Args:
+ fixed_grid: Array (uint8) shape (num_rows, num_cols) The fixed grid.
+ variable_grid: Array (uint8) shape (num_rows, num_cols) The
+ variable grid.
+ new_location: Array (int32) shape (2,) The new location of the agent.
+ action: Array (int32) shape () The action to be executed.
+
+ Returns:
+ updated_action: Array (int32) shape () The updated action after
+ checking if pushing the box is a valid move.
+ """
+
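+ # the box's destination square lies two moves from the agent, i.e.
+ # new_location + MOVES[action]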
+ return jax.lax.select(
+ self.check_space(
+ variable_grid,
+ new_location + MOVES[action].squeeze(),
+ BOX,
+ )
+ | ~self.in_grid(new_location + MOVES[action].squeeze()),
+ jnp.full(shape=(), fill_value=NOOP, dtype=jnp.int32),
+ jax.lax.select(
+ self.check_space(
+ fixed_grid,
+ new_location + MOVES[action].squeeze(),
+ WALL,
+ ),
+ jnp.full(shape=(), fill_value=NOOP, dtype=jnp.int32),
+ action,
+ ),
+ )
+
+ def move_agent(
+ self,
+ variable_grid: chex.Array,
+ action: chex.Array,
+ current_location: chex.Array,
+ ) -> Tuple[chex.Array, chex.Array]:
+ """
+ Executes the movement of the agent specified by the action and
+ executes the movement of a box if present at the destination.
+
+ Args:
+ variable_grid: Array (uint8) shape (num_rows, num_cols)
+ action: Array (int32) shape () The action to take.
+ current_location: Array (int32) shape (2,)
+
+ Returns:
+ next_variable_grid: Array (uint8) shape (num_rows, num_cols)
+ next_location: Array (int32) shape (2,)
+ """
+
+ next_location = current_location + MOVES[action]
+ box_location = next_location + MOVES[action]
+
+ # remove agent from current location
+ next_variable_grid = variable_grid.at[tuple(current_location)].set(EMPTY)
+
+ # either move agent or move agent and box
+
+ next_variable_grid = jax.lax.select(
+ self.check_space(variable_grid, next_location, BOX),
+ next_variable_grid.at[tuple(next_location)]
+ .set(AGENT)
+ .at[tuple(box_location)]
+ .set(BOX),
+ next_variable_grid.at[tuple(next_location)].set(AGENT),
+ )
+
+ return next_variable_grid, next_location
+
+ def render(self, state: State) -> None:
+ """
+ Renders the current state of Sokoban.
+
+ Args:
+ state: 'State' object, the current state to be rendered.
+ """
+
+ self._viewer.render(state=state)
+
+ def animate(
+ self,
+ states: Sequence[State],
+ interval: int = 200,
+ save_path: Optional[str] = None,
+ ) -> matplotlib.animation.FuncAnimation:
+ """
+ Creates an animated gif of the Sokoban environment based on the
+ sequence of states.
+
+ Args:
+ states: Sequence of 'State' objects to animate.
+ interval: int, The interval between frames in the animation.
+ Defaults to 200.
+ save_path: str, the path where the animation should be saved. If
+ not provided, the animation is not saved.
+
+ Returns:
+ animation: 'matplotlib.animation.FuncAnimation'.
+ """
+ return self._viewer.animate(states, interval, save_path)
diff --git a/jumanji/environments/routing/sokoban/env_test.py b/jumanji/environments/routing/sokoban/env_test.py
new file mode 100644
index 000000000..8c3d8da93
--- /dev/null
+++ b/jumanji/environments/routing/sokoban/env_test.py
@@ -0,0 +1,217 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+
+import chex
+import jax
+import jax.numpy as jnp
+import pytest
+
+from jumanji.environments.routing.sokoban.constants import AGENT, BOX, TARGET, WALL
+from jumanji.environments.routing.sokoban.env import Sokoban
+from jumanji.environments.routing.sokoban.generator import (
+ DeepMindGenerator,
+ SimpleSolveGenerator,
+)
+from jumanji.environments.routing.sokoban.types import State
+from jumanji.testing.env_not_smoke import check_env_does_not_smoke
+from jumanji.types import TimeStep
+
+
+@pytest.fixture(scope="session")
+def sokoban() -> Sokoban:
+ env = Sokoban(
+ generator=DeepMindGenerator(
+ difficulty="unfiltered",
+ split="train",
+ proportion_of_files=0.005,
+ )
+ )
+ return env
+
+
+@pytest.fixture(scope="session")
+def sokoban_simple() -> Sokoban:
+ env = Sokoban(generator=SimpleSolveGenerator())
+ return env
+
+
+def test_sokoban__reset(sokoban: Sokoban) -> None:
+ chex.clear_trace_counter()
+ reset_fn = jax.jit(chex.assert_max_traces(sokoban.reset, n=1))
+ key = jax.random.PRNGKey(0)
+ state, timestep = reset_fn(key)
+ assert isinstance(timestep, TimeStep)
+ assert isinstance(state, State)
+ assert state.step_count == 0
+ assert timestep.observation.step_count == 0
+ key2 = jax.random.PRNGKey(1)
+ state2, timestep2 = reset_fn(key2)
+ assert not jnp.array_equal(state2.fixed_grid, state.fixed_grid)
+ assert not jnp.array_equal(state2.variable_grid, state.variable_grid)
+
+
+def test_sokoban__multi_step(sokoban: Sokoban) -> None:
+ """Validates the jitted step of the sokoban environment."""
+ chex.clear_trace_counter()
+ step_fn = jax.jit(chex.assert_max_traces(sokoban.step, n=1))
+
+ # Repeat test for 5 different state initializations
+ for j in range(5):
+ step_count = 0
+ key = jax.random.PRNGKey(j)
+ reset_key, step_key = jax.random.split(key)
+ state, timestep = sokoban.reset(reset_key)
+ initial_fixed_grid = state.fixed_grid
+
+ # Repeating random step 120 times
+ for _ in range(120):
+ action = jnp.array(random.randint(0, 4), jnp.int32)
+ state, timestep = step_fn(state, action)
+
+ # Check step_count increases after each step
+ step_count += 1
+ assert state.step_count == step_count
+ assert timestep.observation.step_count == step_count
+
+ # Check that the fixed part of the state has not changed
+ assert jnp.array_equal(state.fixed_grid, initial_fixed_grid)
+
+ # Check that there are always four boxes in the variable grid and 0 elsewhere
+ num_boxes = jnp.sum(state.variable_grid == BOX)
+ assert num_boxes == jnp.array(4, jnp.int32)
+
+ num_boxes = jnp.sum(state.fixed_grid == BOX)
+ assert num_boxes == jnp.array(0, jnp.int32)
+
+ # Check that there are always 4 targets in the fixed grid and 0 elsewhere
+ num_targets = jnp.sum(state.variable_grid == TARGET)
+ assert num_targets == jnp.array(0, jnp.int32)
+
+ num_targets = jnp.sum(state.fixed_grid == TARGET)
+ assert num_targets == jnp.array(4, jnp.int32)
+
+ # Check that there is one agent in variable grid and 0 elsewhere
+ num_agents = jnp.sum(state.variable_grid == AGENT)
+ assert num_agents == jnp.array(1, jnp.int32)
+
+ num_agents = jnp.sum(state.fixed_grid == AGENT)
+ assert num_agents == jnp.array(0, jnp.int32)
+
+ # Check that the grid size remains constant
+ assert state.fixed_grid.shape == (10, 10)
+
+ # Check the agent is never in the same location as a wall
+ mask_agent = state.variable_grid == AGENT
+ mask_wall = state.fixed_grid == WALL
+ num_agents_on_wall = jnp.sum(mask_agent & mask_wall)
+ assert num_agents_on_wall == jnp.array(0, jnp.int32)
+
+ # Check the boxes are never on a wall
+ mask_boxes = state.variable_grid == BOX
+ mask_wall = state.fixed_grid == WALL
+ num_boxes_on_wall = jnp.sum(mask_boxes & mask_wall)
+ assert num_boxes_on_wall == jnp.array(0, jnp.int32)
+
+
+def test_sokoban__termination_timelimit(sokoban: Sokoban) -> None:
+ """Check that with random actions the environment terminates after
+ 120 steps"""
+
+ chex.clear_trace_counter()
+ step_fn = jax.jit(chex.assert_max_traces(sokoban.step, n=1))
+
+ key = jax.random.PRNGKey(0)
+ reset_key, step_key = jax.random.split(key)
+ state, timestep = sokoban.reset(reset_key)
+
+ for _ in range(119):
+ action = jnp.array(random.randint(0, 4), jnp.int32)
+ state, timestep = step_fn(state, action)
+
+ assert not timestep.last()
+
+ action = jnp.array(random.randint(0, 4), jnp.int32)
+ state, timestep = step_fn(state, action)
+
+ assert timestep.last()
+
+
+def test_sokoban__termination_solved(sokoban_simple: Sokoban) -> None:
+ """Check that with correct sequence of actions to solve a trivial problem,
+ the environment terminates"""
+
+ correct_actions = [0, 2, 1] * 3 + [0]
+ wrong_actions = [0, 2, 1] * 3 + [2]
+
+ chex.clear_trace_counter()
+ step_fn = jax.jit(chex.assert_max_traces(sokoban_simple.step, n=1))
+
+ # Check that environment does terminate with right series of actions
+ key = jax.random.PRNGKey(0)
+ reset_key, step_key = jax.random.split(key)
+ state, timestep = sokoban_simple.reset(reset_key)
+
+ for action in correct_actions:
+ assert not timestep.last()
+
+ action = jnp.array(action, jnp.int32)
+ state, timestep = step_fn(state, action)
+
+ assert timestep.last()
+
+ # Check that environment does not terminate with wrong series of actions
+ key = jax.random.PRNGKey(0)
+ reset_key, step_key = jax.random.split(key)
+ state, timestep = sokoban_simple.reset(reset_key)
+
+ for action in wrong_actions:
+ assert not timestep.last()
+
+ action = jnp.array(action, jnp.int32)
+ state, timestep = step_fn(state, action)
+
+ assert not timestep.last()
+
+
+def test_sokoban__reward_function_solved(sokoban_simple: Sokoban) -> None:
+ """Check the reward function is correct when solving the trivial problem.
+ Every step should give -0.1, each box added to a target adds 1 and
+ solving adds an additional 10"""
+
+ # Correct actions that lead to placing a box every 3 actions
+ correct_actions = [0, 2, 1] * 3 + [0]
+
+ chex.clear_trace_counter()
+ step_fn = jax.jit(chex.assert_max_traces(sokoban_simple.step, n=1))
+
+ key = jax.random.PRNGKey(0)
+ reset_key, step_key = jax.random.split(key)
+ state, timestep = sokoban_simple.reset(reset_key)
+
+ for i, action in enumerate(correct_actions):
+ action = jnp.array(action, jnp.int32)
+ state, timestep = step_fn(state, action)
+
+ if i % 3 == 0 and i != 9:
+ assert timestep.reward == jnp.array(0.9, jnp.float32)
+ elif i != 9:
+ assert timestep.reward == jnp.array(-0.1, jnp.float32)
+ else:
+ assert timestep.reward == jnp.array(10.9, jnp.float32)
+
+
+def test_sokoban__does_not_smoke(sokoban: Sokoban) -> None:
+ """Test that we can run an episode without any errors."""
+ check_env_does_not_smoke(sokoban)
diff --git a/jumanji/environments/routing/sokoban/generator.py b/jumanji/environments/routing/sokoban/generator.py
new file mode 100644
index 000000000..e3ace0a4a
--- /dev/null
+++ b/jumanji/environments/routing/sokoban/generator.py
@@ -0,0 +1,448 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+import os
+import zipfile
+from os import listdir
+from os.path import isfile, join
+from typing import List, Tuple
+
+import chex
+import jax
+import jax.numpy as jnp
+import numpy as np
+import requests
+from huggingface_hub import hf_hub_download
+from tqdm import tqdm
+
+from jumanji.environments.routing.sokoban.constants import AGENT
+from jumanji.environments.routing.sokoban.types import State
+
+
+class Generator(abc.ABC):
+ """Defines the abstract `Generator` base class. A `Generator` is responsible
+ for generating a problem instance when the environment is reset.
+ """
+
+ def __init__(
+ self,
+ ) -> None:
+ """ """
+
+ self._fixed_grids: chex.Array
+ self._variable_grids: chex.Array
+
+ @abc.abstractmethod
+ def __call__(self, rng_key: chex.PRNGKey) -> State:
+ """Generate a problem instance.
+
+ Args:
+ key: the Jax random number generation key.
+
+ Returns:
+ state: the generated problem instance.
+ """
+
+ def get_agent_coordinates(self, grid: chex.Array) -> chex.Array:
+ """Extracts the coordinates of the agent from a given grid with the
+ assumption there is only one agent in the grid.
+
+ Args:
+ grid: Array (uint8) of shape (num_rows, num_cols)
+
+ Returns:
+ location: (int32) of shape (2,)
+ """
+
+ coordinates = jnp.where(grid == AGENT, size=1)
+
+ x_coord = jnp.squeeze(coordinates[0])
+ y_coord = jnp.squeeze(coordinates[1])
+
+ return jnp.array([x_coord, y_coord])
+
+
+class DeepMindGenerator(Generator):
+ def __init__(
+ self,
+ difficulty: str,
+ split: str,
+ proportion_of_files: float = 1.0,
+ verbose: bool = False,
+ ) -> None:
+ self.difficulty = difficulty
+ self.verbose = verbose
+ self.proportion_of_files = proportion_of_files
+
+ # Set the cache path to user's home directory's .cache sub-directory
+ self.cache_path = os.path.join(
+ os.path.expanduser("~"), ".cache", "sokoban_dataset"
+ )
+
+ # Downloads data if not already downloaded
+ self._download_data()
+
+ self.train_data_dir = os.path.join(
+ self.cache_path, "boxoban-levels-master", self.difficulty
+ )
+
+ if self.difficulty in ["unfiltered", "medium"]:
+ if self.difficulty == "medium" and split == "test":
+ raise Exception(
+ "not a valid DeepMind Boxoban difficulty/split combination"
+ )
+ self.train_data_dir = os.path.join(
+ self.train_data_dir,
+ split,
+ )
+
+ # Generates the dataset of sokoban levels
+ self._fixed_grids, self._variable_grids = self._generate_dataset()
+
+ def __call__(self, rng_key: chex.PRNGKey) -> State:
+ """Generate a random Boxoban problem from the Deepmind dataset.
+
+ Args:
+ rng_key: the Jax random number generation key.
+
+ Returns:
+ state: `State` object containing the sampled problem's fixed and
+ variable grids and the initial agent location.
+ """
+
+ key, idx_key = jax.random.split(rng_key)
+ idx = jax.random.randint(
+ idx_key, shape=(), minval=0, maxval=self._fixed_grids.shape[0]
+ )
+ fixed_grid = self._fixed_grids.take(idx, axis=0)
+ variable_grid = self._variable_grids.take(idx, axis=0)
+
+ initial_agent_location = self.get_agent_coordinates(variable_grid)
+
+ state = State(
+ key=key,
+ fixed_grid=fixed_grid,
+ variable_grid=variable_grid,
+ agent_location=initial_agent_location,
+ step_count=jnp.array(0, jnp.int32),
+ )
+
+ return state
+
+ def _generate_dataset(
+ self,
+ ) -> Tuple[chex.Array, chex.Array]:
+ """Parses the text files to generate a jax arrays (fixed and variable
+ grids representing the Boxoban dataset
+
+ Returns:
+ fixed_grid: Array (uint8) shape (dataset_size, num_rows, num_cols)
+ the fixed components of the problem.
+ variable_grid: Array (uint8) shape (dataset_size, num_rows,
+ num_cols) the variable components of the problem.
+ """
+
+ all_files = [
+ f
+ for f in listdir(self.train_data_dir)
+ if isfile(join(self.train_data_dir, f))
+ ]
+ # Only keep a few files if specified
+ all_files = all_files[: int(self.proportion_of_files * len(all_files))]
+
+ fixed_grids_list: List[chex.Array] = []
+ variable_grids_list: List[chex.Array] = []
+
+ for file in all_files:
+
+ source_file = join(self.train_data_dir, file)
+ current_map: List[str] = []
+ # parses a game file containing multiple games
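+ # levels within a file are separated by comment lines containing ";"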
+ with open(source_file, "r") as sf:
+ for line in sf.readlines():
+ if ";" in line and current_map:
+ fixed_grid, variable_grid = convert_level_to_array(current_map)
+
+ fixed_grids_list.append(jnp.array(fixed_grid, dtype=jnp.uint8))
+ variable_grids_list.append(
+ jnp.array(variable_grid, dtype=jnp.uint8)
+ )
+
+ current_map = []
+ if "#" == line[0]:
+ current_map.append(line.strip())
+
+ # convert the final level in the file, which has no trailing ";" line
+ if current_map:
+ fixed_grid, variable_grid = convert_level_to_array(current_map)
+ fixed_grids_list.append(jnp.array(fixed_grid, dtype=jnp.uint8))
+ variable_grids_list.append(jnp.array(variable_grid, dtype=jnp.uint8))
+
+ fixed_grids = jnp.asarray(fixed_grids_list, jnp.uint8)
+ variable_grids = jnp.asarray(variable_grids_list, jnp.uint8)
+
+ return fixed_grids, variable_grids
+
+ def _download_data(self) -> None:
+ """Downloads the deepmind boxoban dataset from github into text files"""
+
+ # Check if the cache directory exists, if not, create it
+ if not os.path.exists(self.cache_path):
+ os.makedirs(self.cache_path)
+
+ # Check if the dataset is already downloaded in the cache
+ dataset_path = os.path.join(self.cache_path, "boxoban-levels-master")
+ if not os.path.exists(dataset_path):
+ url = "https://github.com/deepmind/boxoban-levels/archive/master.zip"
+ if self.verbose:
+ print("Boxoban: Pregenerated levels not downloaded.")
+ print('Starting download from "{}"'.format(url))
+
+ response = requests.get(url, stream=True)
+
+ if response.status_code != 200:
+ raise Exception("Could not download levels")
+
+ path_to_zip_file = os.path.join(
+ self.cache_path, "boxoban_levels-master.zip"
+ )
+ with open(path_to_zip_file, "wb") as handle:
+ for data in tqdm(response.iter_content()):
+ handle.write(data)
+
+ with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref:
+ zip_ref.extractall(self.cache_path)
+
+
+class HuggingFaceDeepMindGenerator(Generator):
+ """Instance generator that generates a random problem from the DeepMind
+ Boxoban dataset a popular dataset for comparing Reinforcement Learning
+ algorithms and Planning Algorithms. The dataset has unfiltered, medium and
+ hard versions. The unfiltered dataset contain train, test and valid
+ splits. The Medium has train and valid splits available. And the hard set
+ contains just a small number of problems. The problems are all guaranteed
+ to be solvable.
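+
+ A minimal usage sketch:
+
+ ```python
+ generator = HuggingFaceDeepMindGenerator(
+ dataset_name="unfiltered-train",
+ proportion_of_files=1,
+ )
+ env = Sokoban(generator=generator)
+ ```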
+ """
+
+ def __init__(
+ self,
+ dataset_name: str,
+ proportion_of_files: float = 1.0,
+ ) -> None:
+ """Instantiates a `DeepMindGenerator`.
+
+ Args:
+ dataset_name: the name of the dataset to use. Choices are:
+ - unfiltered-train,
+ - unfiltered-valid,
+ - unfiltered-test,
+ - medium-train,
+ - medium-valid,
+ - hard.
+ proportion_of_files: float in (0, 1], the proportion of the
+ dataset's files to use.
+ """
+
+ self.dataset_name = dataset_name
+ self.proportion_of_files = proportion_of_files
+
+ dataset_file = hf_hub_download(
+ repo_id="InstaDeepAI/boxoban-levels",
+ filename=f"{dataset_name}.npy",
+ )
+
+ with open(dataset_file, "rb") as f:
+ dataset = np.load(f)
+
+ # Convert to jax arrays and resize using proportion_of_files
+ length = int(proportion_of_files * dataset.shape[0])
+ self._fixed_grids = jnp.asarray(dataset[:length, ..., 0], jnp.uint8)
+ self._variable_grids = jnp.asarray(dataset[:length, ..., 1], jnp.uint8)
+
+ def __call__(self, rng_key: chex.PRNGKey) -> State:
+ """Generate a random Boxoban problem from the Deepmind dataset.
+
+ Args:
+ rng_key: the Jax random number generation key.
+
+ Returns:
+ state: `State` object containing the sampled problem's fixed and
+ variable grids and the initial agent location.
+ """
+
+ key, idx_key = jax.random.split(rng_key)
+ idx = jax.random.randint(
+ idx_key, shape=(), minval=0, maxval=self._fixed_grids.shape[0]
+ )
+ fixed_grid = self._fixed_grids.take(idx, axis=0)
+ variable_grid = self._variable_grids.take(idx, axis=0)
+
+ initial_agent_location = self.get_agent_coordinates(variable_grid)
+
+ state = State(
+ key=key,
+ fixed_grid=fixed_grid,
+ variable_grid=variable_grid,
+ agent_location=initial_agent_location,
+ step_count=jnp.array(0, jnp.int32),
+ )
+
+ return state
+
+
+class ToyGenerator(Generator):
+ def __call__(
+ self,
+ rng_key: chex.PRNGKey,
+ ) -> State:
+ """Generate a random Boxoban problem from the toy 2 problem dataset.
+
+ Args:
+ rng_key: the Jax random number generation key.
+
+ Returns:
+ state: `State` object containing one of the two toy problems' fixed
+ and variable grids and the initial agent location.
+ """
+
+ key, idx_key = jax.random.split(rng_key)
+
+ level1 = [
+ "##########",
+ "# @ #",
+ "# $ . #",
+ "# $# . #",
+ "# .#$ # ",
+ "# . # $ # ",
+ "# #",
+ "##########",
+ "##########",
+ "##########",
+ ]
+
+ level2 = [
+ "##########",
+ "# #",
+ "#$ # . #",
+ "# # $ # .#",
+ "# .# $ #",
+ "# @ # . $#",
+ "# #",
+ "##########",
+ "##########",
+ "##########",
+ ]
+
+ game1_fixed, game1_variable = convert_level_to_array(level1)
+ game2_fixed, game2_variable = convert_level_to_array(level2)
+
+ games_fixed = jnp.stack([game1_fixed, game2_fixed])
+ games_variable = jnp.stack([game1_variable, game2_variable])
+
+ game_index = jax.random.randint(
+ key=idx_key,
+ shape=(),
+ minval=0,
+ maxval=games_fixed.shape[0],
+ )
+
+ initial_agent_location = self.get_agent_coordinates(games_variable[game_index])
+
+ state = State(
+ key=key,
+ fixed_grid=games_fixed[game_index],
+ variable_grid=games_variable[game_index],
+ agent_location=initial_agent_location,
+ step_count=jnp.array(0, jnp.int32),
+ )
+
+ return state
+
+
+class SimpleSolveGenerator(Generator):
+ def __call__(
+ self,
+ key: chex.PRNGKey,
+ ) -> State:
+ """Generate a trivial Boxoban problem.
+
+ Args:
+ key: the Jax random number generation key.
+
+ Returns:
+ state: `State` object containing the trivial problem's fixed and
+ variable grids and the initial agent location.
+ """
+ level1 = [
+ "##########",
+ "# ##",
+ "# .... #",
+ "# $$$$ ##",
+ "# @ # #",
+ "# # # ",
+ "# #",
+ "##########",
+ "##########",
+ "##########",
+ ]
+
+ game_fixed, game_variable = convert_level_to_array(level1)
+
+ initial_agent_location = self.get_agent_coordinates(game_variable)
+
+ state = State(
+ key=key,
+ fixed_grid=game_fixed,
+ variable_grid=game_variable,
+ agent_location=initial_agent_location,
+ step_count=jnp.array(0, jnp.int32),
+ )
+
+ return state
+
+
+def convert_level_to_array(level: List[str]) -> Tuple[chex.Array, chex.Array]:
+ """Converts text representation of levels to a tuple of Jax arrays
+ representing the fixed elements of the Boxoban problem and the variable
+ elements
+
+ Args:
+ level: List of str representing a boxoban level.
+
+ Returns:
+ fixed_grid: Array (uint8) shape (num_rows, num_cols)
+ the fixed components of the problem.
+ variable_grid: Array (uint8) shape (num_rows,
+ num_cols) the variable components of the problem.
+ """
+
+ # Define the mappings
+ mapping = {
+ "#": (1, 0),
+ ".": (2, 0),
+ "@": (0, 3),
+ "$": (0, 4),
+ " ": (0, 0), # empty cell
+ }
+
+ fixed = [[mapping[cell][0] for cell in row] for row in level]
+ variable = [[mapping[cell][1] for cell in row] for row in level]
+
+ return jnp.array(fixed, jnp.uint8), jnp.array(variable, jnp.uint8)
diff --git a/jumanji/environments/routing/sokoban/generator_test.py b/jumanji/environments/routing/sokoban/generator_test.py
new file mode 100644
index 000000000..0c766b632
--- /dev/null
+++ b/jumanji/environments/routing/sokoban/generator_test.py
@@ -0,0 +1,196 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+from typing import List
+
+import chex
+import jax
+import jax.numpy as jnp
+import pytest
+
+from jumanji.environments.routing.sokoban.env import Sokoban
+from jumanji.environments.routing.sokoban.generator import (
+ DeepMindGenerator,
+ HuggingFaceDeepMindGenerator,
+)
+
+
+def test_sokoban__hugging_generator_creation() -> None:
+ """checks we can create datasets for all valid boxoban datasets and
+ perform a jitted step"""
+
+ datasets = [
+ "unfiltered-train",
+ "unfiltered-test",
+ "unfiltered-valid",
+ "medium-train",
+ "medium-valid",
+ "hard",
+ ]
+
+ for dataset in datasets:
+
+ chex.clear_trace_counter()
+
+ env = Sokoban(
+ generator=HuggingFaceDeepMindGenerator(
+ dataset_name=dataset,
+ proportion_of_files=1,
+ )
+ )
+
+ step_fn = jax.jit(chex.assert_max_traces(env.step, n=1))
+
+ key = jax.random.PRNGKey(0)
+ state, timestep = env.reset(key)
+
+ step_count = 0
+ for _ in range(120):
+ action = jnp.array(random.randint(0, 4), jnp.int32)
+ state, timestep = step_fn(state, action)
+
+ # Check step_count increases after each step
+ step_count += 1
+ assert state.step_count == step_count
+ assert timestep.observation.step_count == step_count
+
+
+def test_sokoban__hugging_generator_different_problems() -> None:
+ """checks that resetting with different keys leads to different problems"""
+
+ chex.clear_trace_counter()
+
+ env = Sokoban(
+ generator=HuggingFaceDeepMindGenerator(
+ dataset_name="unfiltered-train",
+ proportion_of_files=1,
+ )
+ )
+
+ key1 = jax.random.PRNGKey(0)
+ state1, timestep1 = env.reset(key1)
+
+ key2 = jax.random.PRNGKey(1)
+ state2, timestep2 = env.reset(key2)
+
+ # Check that different resets lead to different problems
+ assert not jnp.array_equal(state2.fixed_grid, state1.fixed_grid)
+ assert not jnp.array_equal(state2.variable_grid, state1.variable_grid)
+
+
+def test_sokoban__hugging_generator_same_problems() -> None:
+ """checks that resettting with the same key leads to the same problems"""
+
+ chex.clear_trace_counter()
+
+ env = Sokoban(
+ generator=HuggingFaceDeepMindGenerator(
+ dataset_name="unfiltered-train",
+ proportion_of_files=1,
+ )
+ )
+
+ key1 = jax.random.PRNGKey(0)
+ state1, timestep1 = env.reset(key1)
+
+ key2 = jax.random.PRNGKey(0)
+ state2, timestep2 = env.reset(key2)
+
+ assert jnp.array_equal(state2.fixed_grid, state1.fixed_grid)
+ assert jnp.array_equal(state2.variable_grid, state1.variable_grid)
+
+
+def test_sokoban__hugging_generator_proportion_of_problems() -> None:
+ """checks that generator initialises correct number of problems"""
+
+ chex.clear_trace_counter()
+
+ unfiltered_dataset_size = 900000
+
+ generator_full = HuggingFaceDeepMindGenerator(
+ dataset_name="unfiltered-train",
+ proportion_of_files=1,
+ )
+
+ assert jnp.array_equal(
+ generator_full._fixed_grids.shape,
+ (unfiltered_dataset_size, 10, 10),
+ )
+
+ generator_half_full = HuggingFaceDeepMindGenerator(
+ dataset_name="unfiltered-train",
+ proportion_of_files=0.5,
+ )
+
+ assert jnp.array_equal(
+ generator_half_full._fixed_grids.shape,
+        (unfiltered_dataset_size // 2, 10, 10),
+ )
+
+
+def test_sokoban__deepmind_generator_creation() -> None:
+ """checks we can create datasets for all valid boxoban datasets"""
+
+    # Different datasets with varying proportions of files to keep runtime low
+ valid_datasets: List[List] = [
+ ["unfiltered", "train", 0.01],
+ ["unfiltered", "test", 1],
+ ["unfiltered", "valid", 0.02],
+ ["medium", "train", 0.01],
+ ["medium", "valid", 0.02],
+ ["hard", None, 1],
+ ]
+
+ for dataset in valid_datasets:
+
+ chex.clear_trace_counter()
+
+ env = Sokoban(
+ generator=DeepMindGenerator(
+ difficulty=dataset[0],
+ split=dataset[1],
+ proportion_of_files=dataset[2],
+ )
+ )
+
+ assert env.generator._fixed_grids.shape[0] > 0
+
+
+def test_sokoban__deepmind_invalid_creation() -> None:
+ """checks that asking for invalid difficulty, split, proportion leads to
+ exception"""
+
+ # Different datasets with varying proportion of files to keep rutime low
+ valid_datasets: List[List] = [
+ ["medium", "test", 0.01],
+ ["mediumy", "train", 0.01],
+ ["hardy", "train", 0.01],
+ ["unfiltered", None, 0.01],
+ ]
+
+    for dataset in invalid_datasets:
+
+ chex.clear_trace_counter()
+
+ with pytest.raises(Exception):
+ _ = Sokoban(
+ generator=DeepMindGenerator(
+ difficulty=dataset[0],
+ split=dataset[1],
+ proportion_of_files=dataset[2],
+ )
+ )
diff --git a/jumanji/environments/routing/sokoban/imgs/agent.png b/jumanji/environments/routing/sokoban/imgs/agent.png
new file mode 100644
index 000000000..00298ce9b
Binary files /dev/null and b/jumanji/environments/routing/sokoban/imgs/agent.png differ
diff --git a/jumanji/environments/routing/sokoban/imgs/agent_on_target.png b/jumanji/environments/routing/sokoban/imgs/agent_on_target.png
new file mode 100644
index 000000000..3e8310d25
Binary files /dev/null and b/jumanji/environments/routing/sokoban/imgs/agent_on_target.png differ
diff --git a/jumanji/environments/routing/sokoban/imgs/box.png b/jumanji/environments/routing/sokoban/imgs/box.png
new file mode 100644
index 000000000..9a2497df0
Binary files /dev/null and b/jumanji/environments/routing/sokoban/imgs/box.png differ
diff --git a/jumanji/environments/routing/sokoban/imgs/box_on_target.png b/jumanji/environments/routing/sokoban/imgs/box_on_target.png
new file mode 100644
index 000000000..74629af03
Binary files /dev/null and b/jumanji/environments/routing/sokoban/imgs/box_on_target.png differ
diff --git a/jumanji/environments/routing/sokoban/imgs/box_target.png b/jumanji/environments/routing/sokoban/imgs/box_target.png
new file mode 100644
index 000000000..41cbc8f4e
Binary files /dev/null and b/jumanji/environments/routing/sokoban/imgs/box_target.png differ
diff --git a/jumanji/environments/routing/sokoban/imgs/floor.png b/jumanji/environments/routing/sokoban/imgs/floor.png
new file mode 100644
index 000000000..79f053277
Binary files /dev/null and b/jumanji/environments/routing/sokoban/imgs/floor.png differ
diff --git a/jumanji/environments/routing/sokoban/imgs/wall.png b/jumanji/environments/routing/sokoban/imgs/wall.png
new file mode 100644
index 000000000..8da06c8d6
Binary files /dev/null and b/jumanji/environments/routing/sokoban/imgs/wall.png differ
diff --git a/jumanji/environments/routing/sokoban/reward.py b/jumanji/environments/routing/sokoban/reward.py
new file mode 100644
index 000000000..775122293
--- /dev/null
+++ b/jumanji/environments/routing/sokoban/reward.py
@@ -0,0 +1,120 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+
+import chex
+import jax.numpy as jnp
+
+from jumanji.environments.routing.sokoban.constants import (
+ BOX,
+ LEVEL_COMPLETE_BONUS,
+ N_BOXES,
+ SINGLE_BOX_BONUS,
+ STEP_BONUS,
+ TARGET,
+)
+from jumanji.environments.routing.sokoban.types import State
+
+
+class RewardFn(abc.ABC):
+ @abc.abstractmethod
+ def __call__(
+ self,
+ state: State,
+ action: chex.Numeric,
+ next_state: State,
+ ) -> chex.Numeric:
+ """Compute the reward based on the current state,
+ the chosen action, the next state.
+ """
+
+ def count_targets(self, state: State) -> chex.Array:
+ """
+ Calculates the number of boxes on targets.
+
+ Args:
+ state: `State` object representing the current state of the
+ environment.
+
+ Returns:
+            num_boxes_on_targets: Array (int32) of shape () specifying the
+                number of boxes on targets.
+ """
+
+ mask_box = state.variable_grid == BOX
+ mask_target = state.fixed_grid == TARGET
+
+ num_boxes_on_targets = jnp.sum(mask_box & mask_target)
+
+ return num_boxes_on_targets
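+
+    # Illustrative sketch (not part of the original diff): count_targets is
+    # the size of the intersection of the two masks. For a 2x2 corner with
+    #   variable_grid = [[4, 0], [3, 4]]   # two boxes and the agent
+    #   fixed_grid    = [[2, 2], [0, 2]]   # three targets
+    # the intersection is [[True, False], [False, True]], giving a count of 2.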
+
+
+class SparseReward(RewardFn):
+ def __call__(
+ self,
+ state: State,
+ action: chex.Array,
+ next_state: State,
+ ) -> chex.Array:
+ """
+ Implements the sparse reward function in the Sokoban environment.
+
+        Args:
+            state: `State` object representing the current state of the
+                environment.
+            action: Array (int32) of shape () representing the action taken.
+            next_state: `State` object representing the next state of the
+                environment.
+
+        Returns:
+            reward: Array (float32) of shape () specifying the reward received
+                at the transition.
+        """
+
+ next_num_box_target = self.count_targets(next_state)
+
+ level_completed = next_num_box_target == N_BOXES
+
+ return LEVEL_COMPLETE_BONUS * level_completed
+
+
+class DenseReward(RewardFn):
+ def __call__(
+ self,
+ state: State,
+ action: chex.Array,
+ next_state: State,
+ ) -> chex.Array:
+ """
+ Implements the dense reward function in the Sokoban environment.
+
+        Args:
+            state: `State` object representing the current state of the
+                environment.
+            action: Array (int32) of shape () representing the action taken.
+            next_state: `State` object representing the next state of the
+                environment.
+
+        Returns:
+            reward: Array (float32) of shape () specifying the reward received
+                at the transition.
+        """
+
+ num_box_target = self.count_targets(state)
+ next_num_box_target = self.count_targets(next_state)
+
+ level_completed = next_num_box_target == N_BOXES
+
+ return (
+ SINGLE_BOX_BONUS * (next_num_box_target - num_box_target)
+ + LEVEL_COMPLETE_BONUS * level_completed
+ + STEP_BONUS
+ )
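+
+
+# Illustrative arithmetic (not part of the original diff), assuming the
+# constant values implied by reward_test.py (SINGLE_BOX_BONUS = 1,
+# LEVEL_COMPLETE_BONUS = 10, STEP_BONUS = -0.1):
+#   pushing a box onto a target mid-episode: 1 * 1 + 10 * 0 - 0.1 = 0.9
+#   pushing a box off a target:              1 * (-1) + 10 * 0 - 0.1 = -1.1
+#   pushing the final (4th) box on:          1 * 1 + 10 * 1 - 0.1 = 10.9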
diff --git a/jumanji/environments/routing/sokoban/reward_test.py b/jumanji/environments/routing/sokoban/reward_test.py
new file mode 100644
index 000000000..c8ae5b08b
--- /dev/null
+++ b/jumanji/environments/routing/sokoban/reward_test.py
@@ -0,0 +1,74 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+
+import chex
+import jax
+import jax.numpy as jnp
+import pytest
+
+from jumanji.environments.routing.sokoban.env import Sokoban
+from jumanji.environments.routing.sokoban.generator import SimpleSolveGenerator
+from jumanji.types import TimeStep
+
+
+@pytest.fixture(scope="session")
+def sokoban_simple() -> Sokoban:
+ env = Sokoban(generator=SimpleSolveGenerator())
+ return env
+
+
+def test_sokoban__reward_function_random(sokoban_simple: Sokoban) -> None:
+ """Check the reward function is correct when randomly acting in the
+ trivial problem, where accidently pushing boxes onto targets is likely.
+ Every step should give -0.1, each box pushed on adds 1 , each box removed
+ on takes away 1 ,solving adds an additional 10"""
+
+ def check_correct_reward(
+ timestep: TimeStep,
+ num_boxes_on_targets_new: chex.Array,
+ num_boxes_on_targets: chex.Array,
+ ) -> None:
+
+ if num_boxes_on_targets_new == jnp.array(4, jnp.int32):
+ assert timestep.reward == jnp.array(10.9, jnp.float32)
+ elif num_boxes_on_targets_new - num_boxes_on_targets > jnp.array(0, jnp.int32):
+ assert timestep.reward == jnp.array(0.9, jnp.float32)
+ elif num_boxes_on_targets_new - num_boxes_on_targets < jnp.array(0, jnp.int32):
+ assert timestep.reward == jnp.array(-1.1, jnp.float32)
+ else:
+ assert timestep.reward == jnp.array(-0.1, jnp.float32)
+
+ for i in range(5):
+ chex.clear_trace_counter()
+ step_fn = jax.jit(chex.assert_max_traces(sokoban_simple.step, n=1))
+
+ key = jax.random.PRNGKey(i)
+ reset_key, step_key = jax.random.split(key)
+ state, timestep = sokoban_simple.reset(reset_key)
+
+ num_boxes_on_targets = sokoban_simple.reward_fn.count_targets(state)
+
+ for _ in range(120):
+            # Sample a random valid action in [0, 3] (randint is inclusive)
+            action = jnp.array(random.randint(0, 3), jnp.int32)
+ state, timestep = step_fn(state, action)
+
+ num_boxes_on_targets_new = sokoban_simple.reward_fn.count_targets(state)
+
+ check_correct_reward(
+ timestep, num_boxes_on_targets_new, num_boxes_on_targets
+ )
+
+ num_boxes_on_targets = num_boxes_on_targets_new
diff --git a/jumanji/environments/routing/sokoban/types.py b/jumanji/environments/routing/sokoban/types.py
new file mode 100644
index 000000000..eeb561e76
--- /dev/null
+++ b/jumanji/environments/routing/sokoban/types.py
@@ -0,0 +1,53 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, NamedTuple
+
+import chex
+
+if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239
+ from dataclasses import dataclass
+else:
+ from chex import dataclass
+
+
+@dataclass
+class State:
+ """
+ key: random key used for auto-reset.
+ fixed_grid: Array (uint8) shape (n_rows, n_cols) array representing the
+ fixed elements of a sokoban problem.
+ variable_grid: Array (uint8) shape (n_rows, n_cols) array representing the
+ variable elements of a sokoban problem.
+ agent_location: Array (int32) shape (2,)
+ step_count: Array (int32) shape ()
+ """
+
+ key: chex.PRNGKey
+ fixed_grid: chex.Array
+ variable_grid: chex.Array
+ agent_location: chex.Array
+ step_count: chex.Array
+
+
+class Observation(NamedTuple):
+ """
+    The observation returned by the `Sokoban` environment.
+ grid: Array (uint8) shape (n_rows, n_cols, 2) array representing the
+ variable and fixed grids.
+ step_count: Array (int32) shape () the index of the current step.
+ """
+
+ grid: chex.Array
+ step_count: chex.Array
diff --git a/jumanji/environments/routing/sokoban/viewer.py b/jumanji/environments/routing/sokoban/viewer.py
new file mode 100644
index 000000000..db5716704
--- /dev/null
+++ b/jumanji/environments/routing/sokoban/viewer.py
@@ -0,0 +1,219 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Sequence, Tuple
+
+import chex
+import matplotlib.animation
+import matplotlib.cm
+import matplotlib.pyplot as plt
+import numpy as np
+import pkg_resources
+from numpy.typing import NDArray
+from PIL import Image
+
+import jumanji.environments
+from jumanji.environments.routing.sokoban.types import State
+from jumanji.viewer import Viewer
+
+
+class BoxViewer(Viewer):
+ FIGURE_SIZE = (10.0, 10.0)
+
+ def __init__(
+ self,
+ name: str,
+ grid_combine: Callable,
+ ) -> None:
+ """
+ Viewer for a `Sokoban` environment using images from
+ https://github.com/mpSchrader/gym-sokoban.
+
+ Args:
+ name: the window name to be used when initialising the window.
+            grid_combine: function that combines the variable and fixed grids
+                into a single grid for rendering.
+ """
+ self._name = name
+ self.NUM_COLORS = 10
+ self.grid_combine = grid_combine
+ self._display = self._display_rgb_array
+ self._animation: Optional[matplotlib.animation.Animation] = None
+
+ image_names = [
+ "floor",
+ "wall",
+ "box_target",
+ "agent",
+ "box",
+ "agent_on_target",
+ "box_on_target",
+ ]
+
+ def get_image(image_name: str) -> Image.Image:
+ img_path = pkg_resources.resource_filename(
+ "jumanji", f"environments/routing/sokoban/imgs/{image_name}.png"
+ )
+ return Image.open(img_path)
+
+ self.images = [get_image(image_name) for image_name in image_names]
+
+    def render(self, state: State) -> Optional[NDArray]:
+ """Render the given state of the `Sokoban` environment.
+
+ Args:
+ state: the environment state to render.
+ """
+
+ self._clear_display()
+ fig, ax = self._get_fig_ax()
+ ax.clear()
+ self._add_grid_image(state, ax)
+ return self._display(fig)
+
+ def animate(
+ self,
+        states: Sequence[State],
+ interval: int = 200,
+ save_path: Optional[str] = None,
+ ) -> matplotlib.animation.FuncAnimation:
+ """Create an animation from a sequence of environment states.
+
+ Args:
+ states: sequence of environment states corresponding to
+ consecutive timesteps.
+ interval: delay between frames in milliseconds, default to 200.
+ save_path: the path where the animation file should be saved. If
+ it is None, the plot will not be saved.
+
+ Returns:
+ Animation that can be saved as a GIF, MP4, or rendered with HTML.
+ """
+ fig, ax = plt.subplots(
+ num=f"{self._name}Animation", figsize=BoxViewer.FIGURE_SIZE
+ )
+ plt.close(fig)
+
+ def make_frame(state_index: int) -> None:
+ ax.clear()
+ state = states[state_index]
+ self._add_grid_image(state, ax)
+
+ # Create the animation object.
+ self._animation = matplotlib.animation.FuncAnimation(
+ fig,
+ make_frame,
+ frames=len(states),
+ interval=interval,
+ )
+
+ # Save the animation as a gif.
+ if save_path:
+ self._animation.save(save_path)
+
+ return self._animation
+
+ def close(self) -> None:
+ """Perform any necessary cleanup.
+
+ Environments will automatically :meth:`close()` themselves when
+ garbage collected or when the program exits.
+ """
+ plt.close(self._name)
+
+ def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]:
+ """
+ Fetch or create a matplotlib figure and its associated axes.
+
+ Returns:
+ fig: (plt.Figure) A matplotlib figure object
+ axes: (plt.Axes) The axes associated with the figure.
+ """
+ recreate = not plt.fignum_exists(self._name)
+ fig = plt.figure(self._name, BoxViewer.FIGURE_SIZE)
+ if recreate:
+ if not plt.isinteractive():
+ fig.show()
+ ax = fig.add_subplot()
+ else:
+ ax = fig.get_axes()[0]
+ return fig, ax
+
+    def _add_grid_image(self, state: State, ax: plt.Axes) -> None:
+ """
+ Add a grid image to the provided axes.
+
+ Args:
+            state: `State` object representing a state of Sokoban.
+ ax: (plt.Axes) object where the state image will be added.
+ """
+ grid = self.grid_combine(state.variable_grid, state.fixed_grid)
+
+ self._draw_grid(grid, ax)
+ ax.set_axis_off()
+ ax.set_aspect(1)
+ ax.relim()
+ ax.autoscale_view()
+
+ def _draw_grid(self, grid: chex.Array, ax: plt.Axes) -> None:
+ """
+ Draw a grid onto provided axes.
+
+ Args:
+            grid: Array (uint8) of shape (n_rows, n_cols) representing the
+                combined grid to draw.
+ ax: (plt.Axes) The axes on which to draw the grid.
+ """
+
+        rows, cols = grid.shape
+
+        # Draw each cell, flipping vertically so row 0 is rendered at the top
+        for row in range(rows):
+            for col in range(cols):
+                self._draw_grid_cell(grid[row, col], rows - 1 - row, col, ax)
+
+ def _draw_grid_cell(
+ self, cell_value: int, row: int, col: int, ax: plt.Axes
+ ) -> None:
+ """
+ Draw a single cell of the grid.
+
+ Args:
+ cell_value: int representing the cell's value determining its image.
+ row: int representing the cell's row index.
+ col: int representing the cell's col index.
+ ax: (plt.Axes) The axes on which to draw the cell.
+ """
+ cell_value = int(cell_value)
+ image = self.images[cell_value]
+ ax.imshow(image, extent=(col, col + 1, row, row + 1))
+
+ def _clear_display(self) -> None:
+ """
+        Clear the current output display when running inside a notebook.
+ """
+
+ if jumanji.environments.is_notebook():
+ import IPython.display
+
+ IPython.display.clear_output(True)
+
+ def _display_rgb_array(self, fig: plt.Figure) -> NDArray:
+ """
+ Convert the given figure to an RGB array.
+
+ Args:
+ fig: (plt.Figure) The figure to be converted.
+
+ Returns:
+ NDArray: The RGB array representation of the figure.
+ """
+ fig.canvas.draw()
+ return np.asarray(fig.canvas.buffer_rgba())
diff --git a/jumanji/environments/routing/tsp/env.py b/jumanji/environments/routing/tsp/env.py
index 83236f872..0428e646c 100644
--- a/jumanji/environments/routing/tsp/env.py
+++ b/jumanji/environments/routing/tsp/env.py
@@ -80,7 +80,7 @@ class TSP(Environment[State]):
```python
from jumanji.environments import TSP
env = TSP()
- key = jax.random.key(0)
+ key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec().generate_value()
diff --git a/jumanji/training/configs/config.yaml b/jumanji/training/configs/config.yaml
index a110c89b8..458d60b07 100644
--- a/jumanji/training/configs/config.yaml
+++ b/jumanji/training/configs/config.yaml
@@ -1,6 +1,6 @@
defaults:
- _self_
- - env: snake # [bin_pack, cleaner, connector, cvrp, game_2048, graph_coloring, job_shop, knapsack, maze, minesweeper, mmst, multi_cvrp, pac_man, robot_warehouse, rubiks_cube, snake, sudoku, tetris, tsp]
+ - env: snake # [bin_pack, cleaner, connector, cvrp, game_2048, graph_coloring, job_shop, knapsack, maze, minesweeper, mmst, multi_cvrp, pac_man, robot_warehouse, rubiks_cube, snake, sokoban, sudoku, tetris, tsp]
agent: random # [random, a2c]
diff --git a/jumanji/training/configs/env/sokoban.yaml b/jumanji/training/configs/env/sokoban.yaml
new file mode 100644
index 000000000..6baf4f2d0
--- /dev/null
+++ b/jumanji/training/configs/env/sokoban.yaml
@@ -0,0 +1,26 @@
+name: sokoban
+registered_version: Sokoban-v0
+
+network:
+ channels: [256,256,512,512]
+ policy_layers: [64, 64]
+ value_layers: [128, 128]
+
+training:
+ num_epochs: 1000
+ num_learner_steps_per_epoch: 500
+ n_steps: 20
+ total_batch_size: 128
+
+evaluation:
+ eval_total_batch_size: 1024
+ greedy_eval_total_batch_size: 1024
+
+a2c:
+ normalize_advantage: True
+ discount_factor: 0.97
+ bootstrapping_factor: 0.95
+ l_pg: 1.0
+ l_td: 1.0
+ l_en: 0.01
+ learning_rate: 3e-4
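+
+# Example launch (illustrative; assumes the standard Hydra-based Jumanji
+# training entry point, whose exact module path is an assumption):
+#   python -m jumanji.training.train env=sokoban agent=a2c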
diff --git a/jumanji/training/networks/__init__.py b/jumanji/training/networks/__init__.py
index 1fb1df18d..82ad0ae65 100644
--- a/jumanji/training/networks/__init__.py
+++ b/jumanji/training/networks/__init__.py
@@ -78,6 +78,10 @@
make_actor_critic_networks_snake,
)
from jumanji.training.networks.snake.random import make_random_policy_snake
+from jumanji.training.networks.sokoban.actor_critic import (
+ make_actor_critic_networks_sokoban,
+)
+from jumanji.training.networks.sokoban.random import make_random_policy_sokoban
from jumanji.training.networks.sudoku.actor_critic import (
make_cnn_actor_critic_networks_sudoku,
make_equivariant_actor_critic_networks_sudoku,
diff --git a/jumanji/training/networks/sokoban/__init__.py b/jumanji/training/networks/sokoban/__init__.py
new file mode 100644
index 000000000..21db9ec1c
--- /dev/null
+++ b/jumanji/training/networks/sokoban/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/jumanji/training/networks/sokoban/actor_critic.py b/jumanji/training/networks/sokoban/actor_critic.py
new file mode 100644
index 000000000..968180c60
--- /dev/null
+++ b/jumanji/training/networks/sokoban/actor_critic.py
@@ -0,0 +1,115 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Sequence
+
+import chex
+import haiku as hk
+import jax
+import jax.numpy as jnp
+
+from jumanji.environments.routing.sokoban import Observation, Sokoban
+from jumanji.training.networks.actor_critic import (
+ ActorCriticNetworks,
+ FeedForwardNetwork,
+)
+from jumanji.training.networks.parametric_distribution import (
+ CategoricalParametricDistribution,
+)
+
+
+def make_actor_critic_networks_sokoban(
+ sokoban: Sokoban,
+ channels: Sequence[int],
+ policy_layers: Sequence[int],
+ value_layers: Sequence[int],
+) -> ActorCriticNetworks:
+ """Make actor-critic networks for the `Sokoban` environment."""
+ num_actions = sokoban.action_spec().num_values
+ parametric_action_distribution = CategoricalParametricDistribution(
+ num_actions=num_actions
+ )
+
+ policy_network = make_sokoban_cnn(
+ num_outputs=num_actions,
+ mlp_units=policy_layers,
+ channels=channels,
+ time_limit=sokoban.time_limit,
+ )
+ value_network = make_sokoban_cnn(
+ num_outputs=1,
+ mlp_units=value_layers,
+ channels=channels,
+ time_limit=sokoban.time_limit,
+ )
+
+ return ActorCriticNetworks(
+ policy_network=policy_network,
+ value_network=value_network,
+ parametric_action_distribution=parametric_action_distribution,
+ )
+
+
+def make_sokoban_cnn(
+ num_outputs: int,
+ mlp_units: Sequence[int],
+ channels: Sequence[int],
+ time_limit: int,
+) -> FeedForwardNetwork:
+ def network_fn(observation: Observation) -> chex.Array:
+
+ # Iterate over the channels sequence to create convolutional layers
+ layers = []
+ for i, conv_n_channels in enumerate(channels):
+ layers.append(hk.Conv2D(conv_n_channels, (3, 3), stride=2 if i == 0 else 1))
+ layers.append(jax.nn.relu)
+
+ layers.append(hk.Flatten())
+
+ torso = hk.Sequential(layers)
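+
+        # With the configured channels [256, 256, 512, 512] and a 10x10 input,
+        # the stride-2 first conv halves the spatial size to 5x5 (SAME padding,
+        # haiku's default) and the stride-1 convs preserve it, so Flatten
+        # yields a 5 * 5 * 512 = 12800-dimensional embedding.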
+
+ x_processed = preprocess_input(observation.grid)
+
+ embedding = torso(x_processed)
+
+ norm_step_count = jnp.expand_dims(observation.step_count / time_limit, axis=-1)
+ embedding = jnp.concatenate([embedding, norm_step_count], axis=-1)
+ head = hk.nets.MLP((*mlp_units, num_outputs), activate_final=False)
+ if num_outputs == 1:
+ value = jnp.squeeze(head(embedding), axis=-1)
+ return value
+ else:
+ logits = head(embedding)
+
+ return logits
+
+ init, apply = hk.without_apply_rng(hk.transform(network_fn))
+ return FeedForwardNetwork(init=init, apply=apply)
+
+
+def preprocess_input(
+ input_array: chex.Array,
+) -> chex.Array:
+
+    # Channel 0 holds the variable grid: one-hot encode agent (3) and box (4)
+    one_hot_variable = jnp.equal(input_array[..., 0:1], jnp.array([3, 4])).astype(
+        jnp.float32
+    )
+
+    # Channel 1 holds the fixed grid: one-hot encode wall (1) and target (2)
+    one_hot_fixed = jnp.equal(input_array[..., 1:2], jnp.array([1, 2])).astype(
+        jnp.float32
+    )
+
+    total = jnp.concatenate((one_hot_variable, one_hot_fixed), axis=-1)
+
+ return total
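+
+
+# Illustrative note (not part of the original diff): for a batched observation
+# grid of shape (B, 10, 10, 2), preprocess_input returns a float32 array of
+# shape (B, 10, 10, 4): two binary channels marking agent and box positions,
+# and two marking walls and targets, which the CNN torso above convolves over.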
diff --git a/jumanji/training/networks/sokoban/random.py b/jumanji/training/networks/sokoban/random.py
new file mode 100644
index 000000000..8b428174f
--- /dev/null
+++ b/jumanji/training/networks/sokoban/random.py
@@ -0,0 +1,35 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import chex
+import jax
+import jax.numpy as jnp
+
+from jumanji.environments.routing.sokoban import Observation
+from jumanji.training.networks.protocols import RandomPolicy
+
+
+def categorical_random(
+ observation: Observation,
+ key: chex.PRNGKey,
+) -> chex.Array:
+ logits = jnp.zeros(shape=(observation.grid.shape[0], 4))
+
+ action = jax.random.categorical(key, logits)
+ return action
+
+
+def make_random_policy_sokoban() -> RandomPolicy:
+ """Make random policy for the `Sokoban` environment."""
+ return categorical_random
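+
+
+# Illustrative sketch (not part of the original diff): the policy ignores the
+# observation contents and samples uniformly over the 4 actions, e.g.
+#   policy = make_random_policy_sokoban()
+#   actions = policy(observation, jax.random.PRNGKey(0))  # shape (B,)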
diff --git a/jumanji/training/setup_train.py b/jumanji/training/setup_train.py
index dee956345..e2d2b9890 100644
--- a/jumanji/training/setup_train.py
+++ b/jumanji/training/setup_train.py
@@ -40,6 +40,7 @@
RobotWarehouse,
RubiksCube,
Snake,
+ Sokoban,
Sudoku,
Tetris,
)
@@ -178,6 +179,9 @@ def _setup_random_policy( # noqa: CCR001
elif cfg.env.name == "maze":
assert isinstance(env.unwrapped, Maze)
random_policy = networks.make_random_policy_maze()
+ elif cfg.env.name == "sokoban":
+ assert isinstance(env.unwrapped, Sokoban)
+ random_policy = networks.make_random_policy_sokoban()
elif cfg.env.name == "connector":
assert isinstance(env.unwrapped, Connector)
random_policy = networks.make_random_policy_connector()
@@ -326,6 +330,14 @@ def _setup_actor_critic_neworks( # noqa: CCR001
policy_layers=cfg.env.network.policy_layers,
value_layers=cfg.env.network.value_layers,
)
+ elif cfg.env.name == "sokoban":
+ assert isinstance(env.unwrapped, Sokoban)
+ actor_critic_networks = networks.make_actor_critic_networks_sokoban(
+ sokoban=env.unwrapped,
+ channels=cfg.env.network.channels,
+ policy_layers=cfg.env.network.policy_layers,
+ value_layers=cfg.env.network.value_layers,
+ )
elif cfg.env.name == "cleaner":
assert isinstance(env.unwrapped, Cleaner)
actor_critic_networks = networks.make_actor_critic_networks_cleaner(
diff --git a/mkdocs.yml b/mkdocs.yml
index 33c290c91..39dbca1cd 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -38,6 +38,7 @@ nav:
- RobotWarehouse: environments/robot_warehouse.md
- Snake: environments/snake.md
- TSP: environments/tsp.md
+ - Sokoban: environments/sokoban.md
- PacMan: environments/pac_man.md
- User Guides:
- Advanced Usage: guides/advanced_usage.md
@@ -68,6 +69,7 @@ nav:
- RobotWarehouse: api/environments/robot_warehouse.md
- Snake: api/environments/snake.md
- TSP: api/environments/tsp.md
+ - Sokoban: api/environments/sokoban.md
- PacMan: api/environments/pac_man.md
- Wrappers: api/wrappers.md
- Types: api/types.md
diff --git a/pyproject.toml b/pyproject.toml
index a6321cfd9..72558473a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,5 +41,6 @@ module = [
"haiku.*",
"hydra.*",
"omegaconf.*",
+ "huggingface_hub.*",
]
ignore_missing_imports = true
diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt
index df33e2e2b..ea0b437c2 100644
--- a/requirements/requirements-dev.txt
+++ b/requirements/requirements-dev.txt
@@ -24,3 +24,6 @@ pytest-xdist
pytype
scipy>=1.7.3
testfixtures
+types-Pillow
+types-requests
+types-setuptools