diff --git a/.gitignore b/.gitignore index 4e62578a6..7a4e033c2 100644 --- a/.gitignore +++ b/.gitignore @@ -154,3 +154,4 @@ cython_debug/ jumanji_env/ **/outputs/ *.xml +.sokoban_cache/ diff --git a/README.md b/README.md index e70ba830a..4b2fb4cdf 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,7 @@
RobotWarehouse + RobotWarehouse
@@ -112,6 +113,7 @@ problems. | 📬 TSP (Travelling Salesman Problem) | Routing | `TSP-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/tsp/) | [doc](https://instadeepai.github.io/jumanji/environments/tsp/) | | Multi Minimum Spanning Tree Problem | Routing | `MMST-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/mmst) | [doc](https://instadeepai.github.io/jumanji/environments/mmst/) | | ᗧ•••ᗣ•• PacMan | Routing | `PacMan-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/pacman/) | [doc](https://instadeepai.github.io/jumanji/environments/pacman/) +| 👾 Sokoban | Routing | `Sokoban-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/sokoban/) | [doc](https://instadeepai.github.io/jumanji/environments/sokoban/) |

Installation 🎬

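The gameplay rules documented in the new `docs/environments/sokoban.md` in this PR (four moves, wall blocking, single-box pushes, no chained pushes) can be illustrated independently of the JAX implementation. The following is a minimal pure-Python sketch of the step rule using the doc's object encodings; `apply_move` is a hypothetical helper name for illustration, not part of the Jumanji API:

```python
# Hedged sketch of the Sokoban step rule described in docs/environments/sokoban.md.
# Encodings follow the doc: 0 empty, 1 wall, 2 target, 3 agent, 4 box.
EMPTY, WALL, TARGET, AGENT, BOX = range(5)
MOVES = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}  # Up, Right, Down, Left

def apply_move(fixed, variable, agent, action):
    """Return (new_variable_grid, new_agent_location); invalid moves are no-ops."""
    n = len(fixed)
    dr, dc = MOVES[action]
    r, c = agent[0] + dr, agent[1] + dc
    if not (0 <= r < n and 0 <= c < n) or fixed[r][c] == WALL:
        return variable, agent  # blocked by the grid edge or a wall
    if variable[r][c] == BOX:
        # try to push the box one square further in the same direction
        br, bc = r + dr, c + dc
        if (
            not (0 <= br < n and 0 <= bc < n)
            or fixed[br][bc] == WALL
            or variable[br][bc] == BOX  # chained pushes are forbidden
        ):
            return variable, agent
        variable = [row[:] for row in variable]
        variable[br][bc] = BOX
    else:
        variable = [row[:] for row in variable]
    variable[agent[0]][agent[1]] = EMPTY
    variable[r][c] = AGENT
    return variable, (r, c)
```

Note that the environment's actual JAX code expresses the same branching with `jax.lax.select` so it stays jittable; this sketch only mirrors the documented semantics.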
diff --git a/docs/api/environments/sokoban.md b/docs/api/environments/sokoban.md new file mode 100644 index 000000000..1554480fc --- /dev/null +++ b/docs/api/environments/sokoban.md @@ -0,0 +1,9 @@ +::: jumanji.environments.routing.sokoban.env.Sokoban + selection: + members: + - __init__ + - observation_spec + - action_spec + - reset + - step + - render diff --git a/docs/env_anim/sokoban.gif b/docs/env_anim/sokoban.gif new file mode 100644 index 000000000..fffcb85ff Binary files /dev/null and b/docs/env_anim/sokoban.gif differ diff --git a/docs/env_img/sokoban.png b/docs/env_img/sokoban.png new file mode 100644 index 000000000..fbc25a80e Binary files /dev/null and b/docs/env_img/sokoban.png differ diff --git a/docs/environments/sokoban.md b/docs/environments/sokoban.md new file mode 100644 index 000000000..41ca8be4b --- /dev/null +++ b/docs/environments/sokoban.md @@ -0,0 +1,76 @@ +# Sokoban Environment 👾 + +

+ +

+ +This is a JAX implementation of the _Sokoban_ puzzle, a dynamic box-pushing environment where the agent's goal is to place all boxes on their targets. This version follows the rules from the DeepMind paper on [Imagination Augmented Agents](https://arxiv.org/abs/1707.06203), with levels based on the Boxoban dataset from [Guez et al., 2018](https://github.com/deepmind/boxoban-levels)[[1]](#ref1). The graphical assets were taken from [gym-sokoban](https://github.com/mpSchrader/gym-sokoban) by Schrader, a diverse Sokoban library implementing many versions of the game in the OpenAI Gym framework [[2]](#ref2). + +## Observation + +- `grid`: An Array (uint8) of shape `(10, 10, 2)`. It represents the variable grid (containing the movable objects: boxes and the agent) and the fixed grid (containing the fixed objects: walls and targets). +- `step_count`: An Array (int32) of shape `()`, representing the current number of steps in the episode. + +## Object Encodings + +| Object | Encoding | +|--------------|----------| +| Empty Space | 0 | +| Wall | 1 | +| Target | 2 | +| Agent | 3 | +| Box | 4 | + +## Actions + +The agent's action space is an Array (int32) with potential values of `[0,1,2,3]` (corresponding to `[Up, Right, Down, Left]`). If the agent attempts to move into a wall, off the grid, or push a box into a wall or off the grid, the grid state remains unchanged; however, the step count is still incremented by one. Chained box pushes (pushing one box into another) are not allowed and result in no action. + +## Reward + +The reward function comprises: +- `-0.1` for each step taken in the environment. +- `+1` for each box moved onto a target location and `-1` for each box moved off a target location. +- `+10` upon successful placement of all four boxes on their targets. + +## Episode Termination + +The episode concludes when: +- The step limit of 120 is reached. +- All four boxes are placed on targets (i.e., the problem is solved). + +## Dataset + +The Boxoban dataset offers a collection of puzzle levels.
Each level features four boxes and four targets. The dataset has three levels of difficulty: 'unfiltered', 'medium', and 'hard'. + +| Dataset Split | Number of Levels | +|---------------|------------------| +| Unfiltered (Training) | 900,000 | +| Unfiltered (Validation) | 100,000 | +| Unfiltered (Test) | 1000 | +| Medium (Training) | 450,000 | +| Medium (Validation) | 50,000 | +| Hard | 3332 | + + +The dataset generation procedure and more details can be found in Guez et al., 2018 [1]. + +## Graphics + +| Type | Graphic | +|------------------|----------------------------------------------------------------------------------------| +| Wall | ![Wall](../../jumanji/environments/routing/sokoban/imgs/wall.png) | +| Floor | ![Floor](../../jumanji/environments/routing/sokoban/imgs/floor.png) | +| Target | ![BoxTarget](../../jumanji/environments/routing/sokoban/imgs/box_target.png) | +| Box on Target | ![BoxTarget](../../jumanji/environments/routing/sokoban/imgs/box_on_target.png) | +| Box Off Target | ![BoxOffTarget](../../jumanji/environments/routing/sokoban/imgs/box.png) | +| Agent Off Target | ![PlayerOffTarget](../../jumanji/environments/routing/sokoban/imgs/agent.png) | +| Agent On Target | ![PlayerOnTarget](../../jumanji/environments/routing/sokoban/imgs/agent_on_target.png) | + +## Registered Versions 📖 + +- `Sokoban-v0`: Sokoban game with levels generated using DeepMind Boxoban dataset (unfiltered train). + +## References +[1] Guez, A., Mirza, M., Gregor, K., Kabra, R., Racaniere, S., Weber, T., Raposo, D., Santoro, A., Orseau, L., Eccles, T., Wayne, G., Silver, D., Lillicrap, T., Valdes, V. (2018). An investigation of Model-free planning: boxoban levels. Available at [https://github.com/deepmind/boxoban-levels](https://github.com/deepmind/boxoban-levels) + +[2] Schrader, M. (2018). Gym-sokoban. 
Available at [https://github.com/mpSchrader/gym-sokoban](https://github.com/mpSchrader/gym-sokoban) diff --git a/jumanji/__init__.py b/jumanji/__init__.py index 11f15cb08..49b3fdcfd 100644 --- a/jumanji/__init__.py +++ b/jumanji/__init__.py @@ -73,7 +73,6 @@ kwargs={"generator": very_easy_sudoku_generator}, ) - ### # Packing Environments ### @@ -93,7 +92,6 @@ # Tetris - the game of tetris with a grid size of 10x10 and a time limit of 400. register(id="Tetris-v0", entry_point="jumanji.environments:Tetris") - ### # Routing Environments ### @@ -128,5 +126,7 @@ # TSP with 20 randomly generated cities and a dense reward function. register(id="TSP-v1", entry_point="jumanji.environments:TSP") +# Sokoban with deepmind dataset generator +register(id="Sokoban-v0", entry_point="jumanji.environments:Sokoban") # Pacman - minimal version of Atarti Pacman game register(id="PacMan-v0", entry_point="jumanji.environments:PacMan") diff --git a/jumanji/environments/__init__.py b/jumanji/environments/__init__.py index 444eaa616..239ef8f51 100644 --- a/jumanji/environments/__init__.py +++ b/jumanji/environments/__init__.py @@ -35,6 +35,7 @@ pac_man, robot_warehouse, snake, + sokoban, tsp, ) from jumanji.environments.routing.cleaner.env import Cleaner @@ -46,6 +47,7 @@ from jumanji.environments.routing.pac_man.env import PacMan from jumanji.environments.routing.robot_warehouse.env import RobotWarehouse from jumanji.environments.routing.snake.env import Snake +from jumanji.environments.routing.sokoban.env import Sokoban from jumanji.environments.routing.tsp.env import TSP diff --git a/jumanji/environments/logic/game_2048/env.py b/jumanji/environments/logic/game_2048/env.py index ba8115c16..1f0a91a6f 100644 --- a/jumanji/environments/logic/game_2048/env.py +++ b/jumanji/environments/logic/game_2048/env.py @@ -66,7 +66,7 @@ class Game2048(Environment[State]): ```python from jumanji.environments import Game2048 env = Game2048() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, 
timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec().generate_value() diff --git a/jumanji/environments/logic/graph_coloring/env.py b/jumanji/environments/logic/graph_coloring/env.py index 36970d7da..f20f05a88 100644 --- a/jumanji/environments/logic/graph_coloring/env.py +++ b/jumanji/environments/logic/graph_coloring/env.py @@ -73,7 +73,7 @@ class GraphColoring(Environment[State]): ```python from jumanji.environments import GraphColoring env = GraphColoring() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec().generate_value() diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index a5d9e5f01..641ee48f9 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -78,7 +78,7 @@ class Minesweeper(Environment[State]): ```python from jumanji.environments import Minesweeper env = Minesweeper() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec().generate_value() diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index bd01f9809..84a2dff44 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -72,7 +72,7 @@ class RubiksCube(Environment[State]): ```python from jumanji.environments import RubiksCube env = RubiksCube() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec().generate_value() diff --git a/jumanji/environments/logic/sudoku/env.py b/jumanji/environments/logic/sudoku/env.py index 629da1d1a..8e8f5ca12 100644 --- a/jumanji/environments/logic/sudoku/env.py +++ b/jumanji/environments/logic/sudoku/env.py @@ -63,7 +63,7 @@ class Sudoku(Environment[State]): 
```python from jumanji.environments import Sudoku env = Sudoku() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec().generate_value() diff --git a/jumanji/environments/packing/bin_pack/env.py b/jumanji/environments/packing/bin_pack/env.py index 8cd6aa259..c3127c07e 100644 --- a/jumanji/environments/packing/bin_pack/env.py +++ b/jumanji/environments/packing/bin_pack/env.py @@ -103,7 +103,7 @@ class BinPack(Environment[State]): ```python from jumanji.environments import BinPack env = BinPack() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec().generate_value() diff --git a/jumanji/environments/packing/job_shop/env.py b/jumanji/environments/packing/job_shop/env.py index 1af25af87..0e2421524 100644 --- a/jumanji/environments/packing/job_shop/env.py +++ b/jumanji/environments/packing/job_shop/env.py @@ -80,7 +80,7 @@ class JobShop(Environment[State]): ```python from jumanji.environments import JobShop env = JobShop() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec().generate_value() diff --git a/jumanji/environments/packing/knapsack/env.py b/jumanji/environments/packing/knapsack/env.py index ce6c3838c..573b83b32 100644 --- a/jumanji/environments/packing/knapsack/env.py +++ b/jumanji/environments/packing/knapsack/env.py @@ -73,7 +73,7 @@ class Knapsack(Environment[State]): ```python from jumanji.environments import Knapsack env = Knapsack() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec().generate_value() diff --git a/jumanji/environments/packing/tetris/env.py b/jumanji/environments/packing/tetris/env.py index 37e608663..8223225f1 100644 --- a/jumanji/environments/packing/tetris/env.py +++ 
b/jumanji/environments/packing/tetris/env.py @@ -66,7 +66,7 @@ class Tetris(Environment[State]): ```python from jumanji.environments import Tetris env = Tetris() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec().generate_value() diff --git a/jumanji/environments/routing/cleaner/env.py b/jumanji/environments/routing/cleaner/env.py index f1e1410c1..6377bcb8f 100644 --- a/jumanji/environments/routing/cleaner/env.py +++ b/jumanji/environments/routing/cleaner/env.py @@ -71,7 +71,7 @@ class Cleaner(Environment[State]): ```python from jumanji.environments import Cleaner env = Cleaner() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec().generate_value() diff --git a/jumanji/environments/routing/connector/env.py b/jumanji/environments/routing/connector/env.py index 7a139f31b..e76ba9da7 100644 --- a/jumanji/environments/routing/connector/env.py +++ b/jumanji/environments/routing/connector/env.py @@ -85,7 +85,7 @@ class Connector(Environment[State]): ```python from jumanji.environments import Connector env = Connector() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec().generate_value() diff --git a/jumanji/environments/routing/cvrp/env.py b/jumanji/environments/routing/cvrp/env.py index 6af507519..ca6f76920 100644 --- a/jumanji/environments/routing/cvrp/env.py +++ b/jumanji/environments/routing/cvrp/env.py @@ -86,7 +86,7 @@ class CVRP(Environment[State]): ```python from jumanji.environments import CVRP env = CVRP() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec().generate_value() diff --git a/jumanji/environments/routing/maze/env.py b/jumanji/environments/routing/maze/env.py index 4c56eaedf..d0045144c 100644 
--- a/jumanji/environments/routing/maze/env.py +++ b/jumanji/environments/routing/maze/env.py @@ -68,7 +68,7 @@ class Maze(Environment[State]): ```python from jumanji.environments import Maze env = Maze() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec().generate_value() diff --git a/jumanji/environments/routing/mmst/env.py b/jumanji/environments/routing/mmst/env.py index cac71ffe7..485fa87ad 100644 --- a/jumanji/environments/routing/mmst/env.py +++ b/jumanji/environments/routing/mmst/env.py @@ -117,6 +117,17 @@ class MMST(Environment[State]): - INVALID_CHOICE = -1 - INVALID_TIE_BREAK = -2 - INVALID_ALREADY_TRAVERSED = -3 + + ```python + from jumanji.environments import MMST + env = MMST() + key = jax.random.PRNGKey(0) + state, timestep = jax.jit(env.reset)(key) + env.render(state) + action = env.action_spec().generate_value() + state, timestep = jax.jit(env.step)(state, action) + env.render(state) + ``` """ def __init__( diff --git a/jumanji/environments/routing/multi_cvrp/env.py b/jumanji/environments/routing/multi_cvrp/env.py index 23688829e..01aa933ce 100644 --- a/jumanji/environments/routing/multi_cvrp/env.py +++ b/jumanji/environments/routing/multi_cvrp/env.py @@ -64,6 +64,17 @@ class MultiCVRP(Environment[State]): [1] Zhang et al. (2020). "Multi-Vehicle Routing Problems with Soft Time Windows: A Multi-Agent Reinforcement Learning Approach". 
+ + ```python + from jumanji.environments import MultiCVRP + env = MultiCVRP() + key = jax.random.PRNGKey(0) + state, timestep = jax.jit(env.reset)(key) + env.render(state) + action = env.action_spec().generate_value() + state, timestep = jax.jit(env.step)(state, action) + env.render(state) + ``` """ def __init__( diff --git a/jumanji/environments/routing/snake/env.py b/jumanji/environments/routing/snake/env.py index 0231c3a94..fafd0cd20 100644 --- a/jumanji/environments/routing/snake/env.py +++ b/jumanji/environments/routing/snake/env.py @@ -81,7 +81,7 @@ class Snake(Environment[State]): ```python from jumanji.environments import Snake env = Snake() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec().generate_value() diff --git a/jumanji/environments/routing/sokoban/__init__.py b/jumanji/environments/routing/sokoban/__init__.py new file mode 100644 index 000000000..ff8aa3a7c --- /dev/null +++ b/jumanji/environments/routing/sokoban/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from jumanji.environments.routing.sokoban.env import Sokoban +from jumanji.environments.routing.sokoban.types import Observation, State diff --git a/jumanji/environments/routing/sokoban/constants.py b/jumanji/environments/routing/sokoban/constants.py new file mode 100644 index 000000000..782395e31 --- /dev/null +++ b/jumanji/environments/routing/sokoban/constants.py @@ -0,0 +1,37 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax.numpy as jnp + +# Translating actions to coordinate changes +MOVES = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]]) +NOOP = -1 + +# Object encodings +EMPTY = 0 +WALL = 1 +TARGET = 2 +AGENT = 3 +BOX = 4 +TARGET_AGENT = 5 +TARGET_BOX = 6 + +# Environment Variables +N_BOXES = 4 +GRID_SIZE = 10 + +# Reward Function +LEVEL_COMPLETE_BONUS = 10 +SINGLE_BOX_BONUS = 1 +STEP_BONUS = -0.1 diff --git a/jumanji/environments/routing/sokoban/env.py b/jumanji/environments/routing/sokoban/env.py new file mode 100644 index 000000000..c56fcbf89 --- /dev/null +++ b/jumanji/environments/routing/sokoban/env.py @@ -0,0 +1,575 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional, Sequence, Tuple + +import chex +import jax +import jax.numpy as jnp +import matplotlib.animation + +from jumanji import specs +from jumanji.env import Environment +from jumanji.environments.routing.sokoban.constants import ( + AGENT, + BOX, + EMPTY, + GRID_SIZE, + MOVES, + N_BOXES, + NOOP, + TARGET, + TARGET_AGENT, + TARGET_BOX, + WALL, +) +from jumanji.environments.routing.sokoban.generator import ( + Generator, + HuggingFaceDeepMindGenerator, +) +from jumanji.environments.routing.sokoban.reward import DenseReward, RewardFn +from jumanji.environments.routing.sokoban.types import Observation, State +from jumanji.environments.routing.sokoban.viewer import BoxViewer +from jumanji.types import TimeStep, restart, termination, transition +from jumanji.viewer import Viewer + + +class Sokoban(Environment[State]): + """A JAX implementation of the 'Sokoban' game from DeepMind. + + - observation: `Observation` + - grid: jax array (uint8) of shape (num_rows, num_cols, 2) + Array that includes information about the agent, boxes, and + targets in the game. + - step_count: jax array (int32) of shape () + current number of steps in the episode. + + - action: jax array (int32) of shape () + [0,1,2,3] -> [Up, Right, Down, Left]. + + - reward: jax array (float) of shape () + A reward of +1.0 is given for each box pushed onto a target, -1.0 + for each box pushed off a target, and -0.1 for every timestep. + An additional +10.0 is awarded when all boxes are on targets. + + - episode termination: + - if the time limit is reached. + - if all boxes are on targets.
+ + - state: `State` + - key: jax array (uint32) of shape (2,) used for auto-reset + - fixed_grid: jax array (uint8) of shape (num_rows, num_cols) + array indicating the walls and targets in the level. + - variable_grid: jax array (uint8) of shape (num_rows, num_cols) + array indicating the current location of the agent and boxes. + - agent_location: jax array (int32) of shape (2,) + the agent's current location. + - step_count: jax array (int32) of shape () + current number of steps in the episode. + + ```python + import jax + + from jumanji.environments import Sokoban + from jumanji.environments.routing.sokoban.generator import ( + HuggingFaceDeepMindGenerator, + ) + + env_train = Sokoban( + generator=HuggingFaceDeepMindGenerator( + dataset_name="unfiltered-train", + proportion_of_files=1, + ) + ) + + env_test = Sokoban( + generator=HuggingFaceDeepMindGenerator( + dataset_name="unfiltered-test", + proportion_of_files=1, + ) + ) + + # Train... + key_train = jax.random.PRNGKey(0) + state, timestep = jax.jit(env_train.reset)(key_train) + env_train.render(state) + action = env_train.action_spec().generate_value() + state, timestep = jax.jit(env_train.step)(state, action) + env_train.render(state) + ``` + """ + + def __init__( + self, + generator: Optional[Generator] = None, + reward_fn: Optional[RewardFn] = None, + viewer: Optional[Viewer] = None, + time_limit: int = 120, + ) -> None: + """ + Instantiates a `Sokoban` environment with a specific generator, + reward function, viewer, and time limit. + + Args: + generator: `Generator` whose `__call__` instantiates an environment + instance (an initial state). Implemented options are [`ToyGenerator`, + `DeepMindGenerator`, and `HuggingFaceDeepMindGenerator`]. + Defaults to `HuggingFaceDeepMindGenerator` with + `dataset_name="unfiltered-train", proportion_of_files=1`. + reward_fn: `RewardFn` used to compute the reward at every step. + Defaults to `DenseReward`. + viewer: 'Viewer' object used to render the environment. + If not provided, defaults to `BoxViewer`. + time_limit: int, max steps for the environment, defaults to 120.
+ """ + + self.num_rows = GRID_SIZE + self.num_cols = GRID_SIZE + self.shape = (self.num_rows, self.num_cols) + self.time_limit = time_limit + + self.generator = generator or HuggingFaceDeepMindGenerator( + "unfiltered-train", + proportion_of_files=1, + ) + + self._viewer = viewer or BoxViewer( + name="Sokoban", + grid_combine=self.grid_combine, + ) + self.reward_fn = reward_fn or DenseReward() + + def __repr__(self) -> str: + """ + Returns a printable representation of the Sokoban environment. + + Returns: + str: A string representation of the Sokoban environment. + """ + return "\n".join( + [ + "Bokoban environment:", + f" - num_rows: {self.num_rows}", + f" - num_cols: {self.num_cols}", + f" - time_limit: {self.time_limit}", + f" - generator: {self.generator}", + ] + ) + + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: + """ + Resets the environment by calling the instance generator for a + new instance. + + Args: + key: random key used to sample new Sokoban problem. + + Returns: + state: `State` object corresponding to the new state of the + environment after a reset. + timestep: `TimeStep` object corresponding the first timestep + returned by the environment after a reset. + """ + + generator_key, key = jax.random.split(key) + + state = self.generator(generator_key) + + timestep = restart( + self._state_to_observation(state), + extras=self._get_extras(state), + ) + + return state, timestep + + def step( + self, state: State, action: chex.Array + ) -> Tuple[State, TimeStep[Observation]]: + """ + Executes one timestep of the environment's dynamics. + + Args: + state: 'State' object representing the current state of the + environment. + action: Array (int32) of shape (). + - 0: move up. + - 1: move down. + - 2: move left. + - 3: move right. + + Returns: + state, timestep: next state of the environment and timestep to be + observed. 
+ """ + + # switch to noop if action will have no impact on variable grid + action = self.detect_noop_action( + state.variable_grid, state.fixed_grid, action, state.agent_location + ) + + next_variable_grid, next_agent_location = jax.lax.cond( + jnp.all(action == NOOP), + lambda: (state.variable_grid, state.agent_location), + lambda: self.move_agent(state.variable_grid, action, state.agent_location), + ) + + next_state = State( + key=state.key, + fixed_grid=state.fixed_grid, + variable_grid=next_variable_grid, + agent_location=next_agent_location, + step_count=state.step_count + 1, + ) + + target_reached = self.level_complete(next_state) + time_limit_exceeded = next_state.step_count >= self.time_limit + + done = jnp.logical_or(target_reached, time_limit_exceeded) + + reward = jnp.asarray(self.reward_fn(state, action, next_state), float) + + observation = self._state_to_observation(next_state) + + extras = self._get_extras(next_state) + + timestep = jax.lax.cond( + done, + lambda: termination( + reward=reward, + observation=observation, + extras=extras, + ), + lambda: transition( + reward=reward, + observation=observation, + extras=extras, + ), + ) + + return next_state, timestep + + def observation_spec(self) -> specs.Spec[Observation]: + """ + Returns the specifications of the observation of the `Sokoban` + environment. + + Returns: + specs.Spec[Observation]: The specifications of the observations. + """ + grid = specs.BoundedArray( + shape=(self.num_rows, self.num_cols, 2), + dtype=jnp.uint8, + minimum=0, + maximum=4, + name="grid", + ) + step_count = specs.Array((), jnp.int32, "step_count") + return specs.Spec( + Observation, + "ObservationSpec", + grid=grid, + step_count=step_count, + ) + + def action_spec(self) -> specs.DiscreteArray: + """ + Returns the action specification for the Sokoban environment. + There are 4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. + + Returns: + specs.DiscreteArray: Discrete action specifications. 
+ """ + return specs.DiscreteArray(4, name="action", dtype=jnp.int32) + + def _state_to_observation(self, state: State) -> Observation: + """Maps an environment state to an observation. + + Args: + state: `State` object containing the dynamics of the environment. + + Returns: + The observation derived from the state. + """ + + total_grid = jnp.stack([state.variable_grid, state.fixed_grid], axis=-1) + + return Observation( + grid=total_grid, + step_count=state.step_count, + ) + + def _get_extras(self, state: State) -> Dict: + """ + Computes extras metrics to be returned within the timestep. + + Args: + state: 'State' object representing the current state of the + environment. + + Returns: + extras: Dict object containing current proportion of boxes on + targets and whether the problem is solved. + """ + num_boxes_on_targets = self.reward_fn.count_targets(state) + total_num_boxes = N_BOXES + extras = { + "prop_correct_boxes": num_boxes_on_targets / total_num_boxes, + "solved": num_boxes_on_targets == 4, + } + return extras + + def grid_combine( + self, variable_grid: chex.Array, fixed_grid: chex.Array + ) -> chex.Array: + """ + Combines the variable grid and fixed grid into one single grid + representation of the current Sokoban state required for visual + representation of the Sokoban state. Takes care of two possible + overlaps of fixed and variable entries (an agent on a target or a box + on a target), introducing two additional encodings. + + Args: + variable_grid: Array (uint8) of shape (num_rows, num_cols). + fixed_grid: Array (uint8) of shape (num_rows, num_cols). + + Returns: + full_grid: Array (uint8) of shape (num_rows, num_cols, 2). 
+ """ + + mask_target_agent = jnp.logical_and( + fixed_grid == TARGET, + variable_grid == AGENT, + ) + + mask_target_box = jnp.logical_and( + fixed_grid == TARGET, + variable_grid == BOX, + ) + + single_grid = jnp.where( + mask_target_agent, + TARGET_AGENT, + jnp.where( + mask_target_box, + TARGET_BOX, + jnp.maximum(variable_grid, fixed_grid), + ), + ).astype(jnp.uint8) + + return single_grid + + def level_complete(self, state: State) -> chex.Array: + """ + Checks if the sokoban level is complete. + + Args: + state: `State` object representing the current state of the environment. + + Returns: + complete: Boolean indicating whether the level is complete + or not. + """ + return self.reward_fn.count_targets(state) == N_BOXES + + def check_space( + self, + grid: chex.Array, + location: chex.Array, + value: int, + ) -> chex.Array: + """ + Checks if a specific location in the grid contains a given value. + + Args: + grid: Array (uint8) shape (num_rows, num_cols) The grid to check. + location: Tuple size 2 of Array (int32) shape () containing the x + and y coodinate of the location to check in the grid. + value: int The value to look for. + + Returns: + present: Array (bool) shape () indicating whether the location + in the grid contains the given value or not. + """ + + return grid[tuple(location)] == value + + def in_grid(self, coordinates: chex.Array) -> chex.Array: + """ + Checks if given coordinates are within the grid size. + + Args: + coordinates: Array (uint8) shape (num_rows, num_cols) The + coordinates to check. + Returns: + in_grid: Array (bool) shape () Boolean indicating whether the + coordinates are within the grid. + """ + return jnp.all((0 <= coordinates) & (coordinates < GRID_SIZE)) + + def detect_noop_action( + self, + variable_grid: chex.Array, + fixed_grid: chex.Array, + action: chex.Array, + agent_location: chex.Array, + ) -> chex.Array: + """ + Masks actions to -1 that have no effect on the variable grid. 
+ Checks whether the destination square is free and, if it contains a + box, whether the square the box would be pushed to is valid. + + Args: + variable_grid: Array (uint8) shape (num_rows, num_cols). + fixed_grid: Array (uint8) shape (num_rows, num_cols). + action: Array (int32) shape () The action to check. + agent_location: Array (int32) shape (2,) The agent's current location. + + Returns: + updated_action: Array (int32) shape () The updated action after + detecting a noop action. + """ + + new_location = agent_location + MOVES[action].squeeze() + + invalid_destination = self.check_space( + fixed_grid, new_location, WALL + ) | ~self.in_grid(new_location) + + updated_action = jax.lax.select( + invalid_destination, + jnp.full(shape=(), fill_value=NOOP, dtype=jnp.int32), + jax.lax.select( + self.check_space(variable_grid, new_location, BOX), + self.update_box_push_action( + fixed_grid, + variable_grid, + new_location, + action, + ), + action, + ), + ) + + return updated_action + + def update_box_push_action( + self, + fixed_grid: chex.Array, + variable_grid: chex.Array, + new_location: chex.Array, + action: chex.Array, + ) -> chex.Array: + """ + Masks actions to NOOP (-1) if pushing the box is not a valid move, + i.e. if the box would be pushed out of the grid or onto a square + containing a wall or another box. + + Args: + fixed_grid: Array (uint8) shape (num_rows, num_cols) The fixed grid. + variable_grid: Array (uint8) shape (num_rows, num_cols) The + variable grid. + new_location: Array (int32) shape (2,) The new location of the agent. + action: Array (int32) shape () The action to be executed. + + Returns: + updated_action: Array (int32) shape () The updated action after + checking if pushing the box is a valid move.
+ """ + + return jax.lax.select( + self.check_space( + variable_grid, + new_location + MOVES[action].squeeze(), + BOX, + ) + | ~self.in_grid(new_location + MOVES[action].squeeze()), + jnp.full(shape=(), fill_value=NOOP, dtype=jnp.int32), + jax.lax.select( + self.check_space( + fixed_grid, + new_location + MOVES[action].squeeze(), + WALL, + ), + jnp.full(shape=(), fill_value=NOOP, dtype=jnp.int32), + action, + ), + ) + + def move_agent( + self, + variable_grid: chex.Array, + action: chex.Array, + current_location: chex.Array, + ) -> Tuple[chex.Array, chex.Array]: + """ + Executes the movement of the agent specified by the action and + executes the movement of a box if present at the destination. + + Args: + variable_grid: Array (uint8) shape (num_rows, num_cols) + action: Array (int32) shape () The action to take. + current_location: Array (int32) shape (2,) + + Returns: + next_variable_grid: Array (uint8) shape (num_rows, num_cols) + next_location: Array (int32) shape (2,) + """ + + next_location = current_location + MOVES[action] + box_location = next_location + MOVES[action] + + # remove agent from current location + next_variable_grid = variable_grid.at[tuple(current_location)].set(EMPTY) + + # either move agent or move agent and box + + next_variable_grid = jax.lax.select( + self.check_space(variable_grid, next_location, BOX), + next_variable_grid.at[tuple(next_location)] + .set(AGENT) + .at[tuple(box_location)] + .set(BOX), + next_variable_grid.at[tuple(next_location)].set(AGENT), + ) + + return next_variable_grid, next_location + + def render(self, state: State) -> None: + """ + Renders the current state of Sokoban. + + Args: + state: 'State' object , the current state to be rendered. 
+ """ + + self._viewer.render(state=state) + + def animate( + self, + states: Sequence[State], + interval: int = 200, + save_path: Optional[str] = None, + ) -> matplotlib.animation.FuncAnimation: + """ + Creates an animated gif of the Sokoban environment based on the + sequence of states. + + Args: + states: Sequence of 'State' object + interval: int, The interval between frames in the animation. + Defaults to 200. + save_path: str The path where to save the animation. If not + provided, the animation is not saved. + + Returns: + animation: 'matplotlib.animation.FuncAnimation'. + """ + return self._viewer.animate(states, interval, save_path) diff --git a/jumanji/environments/routing/sokoban/env_test.py b/jumanji/environments/routing/sokoban/env_test.py new file mode 100644 index 000000000..8c3d8da93 --- /dev/null +++ b/jumanji/environments/routing/sokoban/env_test.py @@ -0,0 +1,217 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import random
+
+import chex
+import jax
+import jax.numpy as jnp
+import pytest
+
+from jumanji.environments.routing.sokoban.constants import AGENT, BOX, TARGET, WALL
+from jumanji.environments.routing.sokoban.env import Sokoban
+from jumanji.environments.routing.sokoban.generator import (
+    DeepMindGenerator,
+    SimpleSolveGenerator,
+)
+from jumanji.environments.routing.sokoban.types import State
+from jumanji.testing.env_not_smoke import check_env_does_not_smoke
+from jumanji.types import TimeStep
+
+
+@pytest.fixture(scope="session")
+def sokoban() -> Sokoban:
+    env = Sokoban(
+        generator=DeepMindGenerator(
+            difficulty="unfiltered",
+            split="train",
+            proportion_of_files=0.005,
+        )
+    )
+    return env
+
+
+@pytest.fixture(scope="session")
+def sokoban_simple() -> Sokoban:
+    env = Sokoban(generator=SimpleSolveGenerator())
+    return env
+
+
+def test_sokoban__reset(sokoban: Sokoban) -> None:
+    chex.clear_trace_counter()
+    reset_fn = jax.jit(chex.assert_max_traces(sokoban.reset, n=1))
+    key = jax.random.PRNGKey(0)
+    state, timestep = reset_fn(key)
+    assert isinstance(timestep, TimeStep)
+    assert isinstance(state, State)
+    assert state.step_count == 0
+    assert timestep.observation.step_count == 0
+    key2 = jax.random.PRNGKey(1)
+    state2, timestep2 = reset_fn(key2)
+    assert not jnp.array_equal(state2.fixed_grid, state.fixed_grid)
+    assert not jnp.array_equal(state2.variable_grid, state.variable_grid)
+
+
+def test_sokoban__multi_step(sokoban: Sokoban) -> None:
+    """Validates the jitted step of the sokoban environment."""
+    chex.clear_trace_counter()
+    step_fn = jax.jit(chex.assert_max_traces(sokoban.step, n=1))
+
+    # Repeat test for 5 different state initializations
+    for j in range(5):
+        step_count = 0
+        key = jax.random.PRNGKey(j)
+        reset_key, step_key = jax.random.split(key)
+        state, timestep = sokoban.reset(reset_key)
+        initial_fixed_grid = state.fixed_grid
+
+        # Repeating random step 120 times
+        for _ in range(120):
+            action = jnp.array(random.randint(0, 4), jnp.int32)
+            state, timestep = step_fn(state, action)
+
+            # Check step_count increases after each step
+            step_count += 1
+            assert state.step_count == step_count
+            assert timestep.observation.step_count == step_count
+
+            # Check that the fixed part of the state has not changed
+            assert jnp.array_equal(state.fixed_grid, initial_fixed_grid)
+
+            # Check that there are always four boxes in the variable grid and 0 elsewhere
+            num_boxes = jnp.sum(state.variable_grid == BOX)
+            assert num_boxes == jnp.array(4, jnp.int32)
+
+            num_boxes = jnp.sum(state.fixed_grid == BOX)
+            assert num_boxes == jnp.array(0, jnp.int32)
+
+            # Check that there are always 4 targets in the fixed grid and 0 elsewhere
+            num_targets = jnp.sum(state.variable_grid == TARGET)
+            assert num_targets == jnp.array(0, jnp.int32)
+
+            num_targets = jnp.sum(state.fixed_grid == TARGET)
+            assert num_targets == jnp.array(4, jnp.int32)
+
+            # Check that there is one agent in the variable grid and 0 elsewhere
+            num_agents = jnp.sum(state.variable_grid == AGENT)
+            assert num_agents == jnp.array(1, jnp.int32)
+
+            num_agents = jnp.sum(state.fixed_grid == AGENT)
+            assert num_agents == jnp.array(0, jnp.int32)
+
+            # Check that the grid size remains constant
+            assert state.fixed_grid.shape == (10, 10)
+
+            # Check the agent is never in the same location as a wall
+            mask_agent = state.variable_grid == AGENT
+            mask_wall = state.fixed_grid == WALL
+            num_agents_on_wall = jnp.sum(mask_agent & mask_wall)
+            assert num_agents_on_wall == jnp.array(0, jnp.int32)
+
+            # Check the boxes are never on a wall
+            mask_boxes = state.variable_grid == BOX
+            mask_wall = state.fixed_grid == WALL
+            num_boxes_on_wall = jnp.sum(mask_boxes & mask_wall)
+            assert num_boxes_on_wall == jnp.array(0, jnp.int32)
+
+
+def test_sokoban__termination_timelimit(sokoban: Sokoban) -> None:
+    """Check that with random actions the environment terminates after
+    120 steps"""
+
+    chex.clear_trace_counter()
+    step_fn = jax.jit(chex.assert_max_traces(sokoban.step, n=1))
+
+    key = jax.random.PRNGKey(0)
+    
reset_key, step_key = jax.random.split(key) + state, timestep = sokoban.reset(reset_key) + + for _ in range(119): + action = jnp.array(random.randint(0, 4), jnp.int32) + state, timestep = step_fn(state, action) + + assert not timestep.last() + + action = jnp.array(random.randint(0, 4), jnp.int32) + state, timestep = step_fn(state, action) + + assert timestep.last() + + +def test_sokoban__termination_solved(sokoban_simple: Sokoban) -> None: + """Check that with correct sequence of actions to solve a trivial problem, + the environment terminates""" + + correct_actions = [0, 2, 1] * 3 + [0] + wrong_actions = [0, 2, 1] * 3 + [2] + + chex.clear_trace_counter() + step_fn = jax.jit(chex.assert_max_traces(sokoban_simple.step, n=1)) + + # Check that environment does terminate with right series of actions + key = jax.random.PRNGKey(0) + reset_key, step_key = jax.random.split(key) + state, timestep = sokoban_simple.reset(reset_key) + + for action in correct_actions: + assert not timestep.last() + + action = jnp.array(action, jnp.int32) + state, timestep = step_fn(state, action) + + assert timestep.last() + + # Check that environment does not terminate with wrong series of actions + key = jax.random.PRNGKey(0) + reset_key, step_key = jax.random.split(key) + state, timestep = sokoban_simple.reset(reset_key) + + for action in wrong_actions: + assert not timestep.last() + + action = jnp.array(action, jnp.int32) + state, timestep = step_fn(state, action) + + assert not timestep.last() + + +def test_sokoban__reward_function_solved(sokoban_simple: Sokoban) -> None: + """Check the reward function is correct when solving the trivial problem. 
+ Every step should give -0.1, each box added to a target adds 1 and + solving adds an additional 10""" + + # Correct actions that lead to placing a box every 3 actions + correct_actions = [0, 2, 1] * 3 + [0] + + chex.clear_trace_counter() + step_fn = jax.jit(chex.assert_max_traces(sokoban_simple.step, n=1)) + + key = jax.random.PRNGKey(0) + reset_key, step_key = jax.random.split(key) + state, timestep = sokoban_simple.reset(reset_key) + + for i, action in enumerate(correct_actions): + action = jnp.array(action, jnp.int32) + state, timestep = step_fn(state, action) + + if i % 3 == 0 and i != 9: + assert timestep.reward == jnp.array(0.9, jnp.float32) + elif i != 9: + assert timestep.reward == jnp.array(-0.1, jnp.float32) + else: + assert timestep.reward == jnp.array(10.9, jnp.float32) + + +def test_sokoban__does_not_smoke(sokoban: Sokoban) -> None: + """Test that we can run an episode without any errors.""" + check_env_does_not_smoke(sokoban) diff --git a/jumanji/environments/routing/sokoban/generator.py b/jumanji/environments/routing/sokoban/generator.py new file mode 100644 index 000000000..e3ace0a4a --- /dev/null +++ b/jumanji/environments/routing/sokoban/generator.py @@ -0,0 +1,448 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
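The generators in this file parse Boxoban levels from a simple text format, one character per cell, split across a fixed grid (walls, targets) and a variable grid (agent, boxes). As a rough stand-alone illustration of the character encoding used by `convert_level_to_array` further down (the numeric codes are assumed to match that function's mapping):

```python
# Each character maps to a (fixed, variable) pair of cell codes. These
# values mirror the mapping in convert_level_to_array: wall=1 and target=2
# live on the fixed grid; agent=3 and box=4 live on the variable grid.
MAPPING = {
    "#": (1, 0),  # wall (fixed grid)
    ".": (2, 0),  # target (fixed grid)
    "@": (0, 3),  # agent (variable grid)
    "$": (0, 4),  # box (variable grid)
    " ": (0, 0),  # empty floor
}


def parse_level(level):
    """Split a text level into parallel fixed and variable grids."""
    fixed = [[MAPPING[ch][0] for ch in row] for row in level]
    variable = [[MAPPING[ch][1] for ch in row] for row in level]
    return fixed, variable


level = [
    "#####",
    "#@$.#",
    "#####",
]
fixed, variable = parse_level(level)

assert sum(row.count(4) for row in variable) == 1  # exactly one box
assert sum(row.count(2) for row in fixed) == 1     # exactly one target
assert fixed[1][0] == 1 and variable[1][1] == 3    # wall and agent positions
```

Keeping walls/targets and agent/boxes in separate grids means the static layout can be checked once while only the variable grid changes per step.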
+ +import abc +import os +import zipfile +from os import listdir +from os.path import isfile, join +from typing import List, Tuple + +import chex +import jax +import jax.numpy as jnp +import numpy as np +import requests +from huggingface_hub import hf_hub_download +from tqdm import tqdm + +from jumanji.environments.routing.sokoban.constants import AGENT +from jumanji.environments.routing.sokoban.types import State + + +class Generator(abc.ABC): + """Defines the abstract `Generator` base class. A `Generator` is responsible + for generating a problem instance when the environment is reset. + """ + + def __init__( + self, + ) -> None: + """ """ + + self._fixed_grids: chex.Array + self._variable_grids: chex.Array + + @abc.abstractmethod + def __call__(self, rng_key: chex.PRNGKey) -> State: + """Generate a problem instance. + + Args: + key: the Jax random number generation key. + + Returns: + state: the generated problem instance. + """ + + def get_agent_coordinates(self, grid: chex.Array) -> chex.Array: + """Extracts the coordinates of the agent from a given grid with the + assumption there is only one agent in the grid. 
+
+        Args:
+            grid: Array (uint8) of shape (num_rows, num_cols)
+
+        Returns:
+            location: Array (int32) of shape (2,)
+        """
+
+        coordinates = jnp.where(grid == AGENT, size=1)
+
+        x_coord = jnp.squeeze(coordinates[0])
+        y_coord = jnp.squeeze(coordinates[1])
+
+        return jnp.array([x_coord, y_coord])
+
+
+class DeepMindGenerator(Generator):
+    def __init__(
+        self,
+        difficulty: str,
+        split: str,
+        proportion_of_files: float = 1.0,
+        verbose: bool = False,
+    ) -> None:
+        self.difficulty = difficulty
+        self.verbose = verbose
+        self.proportion_of_files = proportion_of_files
+
+        # Set the cache path to the user's ~/.cache sub-directory
+        self.cache_path = os.path.join(
+            os.path.expanduser("~"), ".cache", "sokoban_dataset"
+        )
+
+        # Downloads data if not already downloaded
+        self._download_data()
+
+        self.train_data_dir = os.path.join(
+            self.cache_path, "boxoban-levels-master", self.difficulty
+        )
+
+        if self.difficulty in ["unfiltered", "medium"]:
+            if self.difficulty == "medium" and split == "test":
+                raise Exception(
+                    "not a valid DeepMind Boxoban difficulty/split combination"
+                )
+            self.train_data_dir = os.path.join(
+                self.train_data_dir,
+                split,
+            )
+
+        # Generates the dataset of sokoban levels
+        self._fixed_grids, self._variable_grids = self._generate_dataset()
+
+    def __call__(self, rng_key: chex.PRNGKey) -> State:
+        """Generate a random Boxoban problem from the DeepMind dataset.
+
+        Args:
+            rng_key: the Jax random number generation key.
+
+        Returns:
+            fixed_grid: Array (uint8) shape (num_rows, num_cols) the fixed
+            components of the problem.
+            variable_grid: Array (uint8) shape (num_rows, num_cols) the
+            variable components of the problem.
+ """ + + key, idx_key = jax.random.split(rng_key) + idx = jax.random.randint( + idx_key, shape=(), minval=0, maxval=self._fixed_grids.shape[0] + ) + fixed_grid = self._fixed_grids.take(idx, axis=0) + variable_grid = self._variable_grids.take(idx, axis=0) + + initial_agent_location = self.get_agent_coordinates(variable_grid) + + state = State( + key=key, + fixed_grid=fixed_grid, + variable_grid=variable_grid, + agent_location=initial_agent_location, + step_count=jnp.array(0, jnp.int32), + ) + + return state + + def _generate_dataset( + self, + ) -> Tuple[chex.Array, chex.Array]: + """Parses the text files to generate a jax arrays (fixed and variable + grids representing the Boxoban dataset + + Returns: + fixed_grid: Array (uint8) shape (dataset_size, num_rows, num_cols) + the fixed components of the problem. + variable_grid: Array (uint8) shape (dataset_size, num_rows, + num_cols) the variable components of the problem. + """ + + all_files = [ + f + for f in listdir(self.train_data_dir) + if isfile(join(self.train_data_dir, f)) + ] + # Only keep a few files if specified + all_files = all_files[: int(self.proportion_of_files * len(all_files))] + + fixed_grids_list: List[chex.Array] = [] + variable_grids_list: List[chex.Array] = [] + + for file in all_files: + + source_file = join(self.train_data_dir, file) + current_map: List[str] = [] + # parses a game file containing multiple games + with open(source_file, "r") as sf: + for line in sf.readlines(): + if ";" in line and current_map: + fixed_grid, variable_grid = convert_level_to_array(current_map) + + fixed_grids_list.append(jnp.array(fixed_grid, dtype=jnp.uint8)) + variable_grids_list.append( + jnp.array(variable_grid, dtype=jnp.uint8) + ) + + current_map = [] + if "#" == line[0]: + current_map.append(line.strip()) + + fixed_grids_list.append(jnp.array(fixed_grid, dtype=jnp.uint8)) + variable_grids_list.append(jnp.array(variable_grid, dtype=jnp.uint8)) + + fixed_grids = jnp.asarray(fixed_grids_list, jnp.uint8) + 
variable_grids = jnp.asarray(variable_grids_list, jnp.uint8) + + return fixed_grids, variable_grids + + def _download_data(self) -> None: + """Downloads the deepmind boxoban dataset from github into text files""" + + # Check if the cache directory exists, if not, create it + if not os.path.exists(self.cache_path): + os.makedirs(self.cache_path) + + # Check if the dataset is already downloaded in the cache + dataset_path = os.path.join(self.cache_path, "boxoban-levels-master") + if not os.path.exists(dataset_path): + url = "https://github.com/deepmind/boxoban-levels/archive/master.zip" + if self.verbose: + print("Boxoban: Pregenerated levels not downloaded.") + print('Starting download from "{}"'.format(url)) + + response = requests.get(url, stream=True) + + if response.status_code != 200: + raise Exception("Could not download levels") + + path_to_zip_file = os.path.join( + self.cache_path, "boxoban_levels-master.zip" + ) + with open(path_to_zip_file, "wb") as handle: + for data in tqdm(response.iter_content()): + handle.write(data) + + with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref: + zip_ref.extractall(self.cache_path) + + +class HuggingFaceDeepMindGenerator(Generator): + """Instance generator that generates a random problem from the DeepMind + Boxoban dataset a popular dataset for comparing Reinforcement Learning + algorithms and Planning Algorithms. The dataset has unfiltered, medium and + hard versions. The unfiltered dataset contain train, test and valid + splits. The Medium has train and valid splits available. And the hard set + contains just a small number of problems. The problems are all guaranteed + to be solvable. + """ + + def __init__( + self, + dataset_name: str, + proportion_of_files: float = 1.0, + ) -> None: + """Instantiates a `DeepMindGenerator`. + + Args: + dataset_name: the name of the dataset to use. Choices are: + - unfiltered-train, + - unfiltered-valid, + - unfiltered-test, + - medium-train, + - medium-test, + - hard. 
+            proportion_of_files: float in (0, 1] specifying the proportion
+            of files to use from the dataset.
+        """
+
+        self.dataset_name = dataset_name
+        self.proportion_of_files = proportion_of_files
+
+        dataset_file = hf_hub_download(
+            repo_id="InstaDeepAI/boxoban-levels",
+            filename=f"{dataset_name}.npy",
+        )
+
+        with open(dataset_file, "rb") as f:
+            dataset = np.load(f)
+
+        # Convert to jax arrays and resize using proportion_of_files
+        length = int(proportion_of_files * dataset.shape[0])
+        self._fixed_grids = jnp.asarray(dataset[:length, ..., 0], jnp.uint8)
+        self._variable_grids = jnp.asarray(dataset[:length, ..., 1], jnp.uint8)
+
+    def __call__(self, rng_key: chex.PRNGKey) -> State:
+        """Generate a random Boxoban problem from the DeepMind dataset.
+
+        Args:
+            rng_key: the Jax random number generation key.
+
+        Returns:
+            fixed_grid: Array (uint8) shape (num_rows, num_cols) the fixed
+            components of the problem.
+            variable_grid: Array (uint8) shape (num_rows, num_cols) the
+            variable components of the problem.
+        """
+
+        key, idx_key = jax.random.split(rng_key)
+        idx = jax.random.randint(
+            idx_key, shape=(), minval=0, maxval=self._fixed_grids.shape[0]
+        )
+        fixed_grid = self._fixed_grids.take(idx, axis=0)
+        variable_grid = self._variable_grids.take(idx, axis=0)
+
+        initial_agent_location = self.get_agent_coordinates(variable_grid)
+
+        state = State(
+            key=key,
+            fixed_grid=fixed_grid,
+            variable_grid=variable_grid,
+            agent_location=initial_agent_location,
+            step_count=jnp.array(0, jnp.int32),
+        )
+
+        return state
+
+
+class ToyGenerator(Generator):
+    def __call__(
+        self,
+        rng_key: chex.PRNGKey,
+    ) -> State:
+        """Generate a random Boxoban problem from a toy dataset of two problems.
+
+        Args:
+            rng_key: the Jax random number generation key.
+
+        Returns:
+            fixed_grid: Array (uint8) shape (num_rows, num_cols) the fixed
+            components of the problem.
+            variable_grid: Array (uint8) shape (num_rows, num_cols) the
+            variable components of the problem.
+ """ + + key, idx_key = jax.random.split(rng_key) + + level1 = [ + "##########", + "# @ #", + "# $ . #", + "# $# . #", + "# .#$ # ", + "# . # $ # ", + "# #", + "##########", + "##########", + "##########", + ] + + level2 = [ + "##########", + "# #", + "#$ # . #", + "# # $ # .#", + "# .# $ #", + "# @ # . $#", + "# #", + "##########", + "##########", + "##########", + ] + + game1_fixed, game1_variable = convert_level_to_array(level1) + game2_fixed, game2_variable = convert_level_to_array(level2) + + games_fixed = jnp.stack([game1_fixed, game2_fixed]) + games_variable = jnp.stack([game1_variable, game2_variable]) + + game_index = jax.random.randint( + key=idx_key, + shape=(), + minval=0, + maxval=games_fixed.shape[0], + ) + + initial_agent_location = self.get_agent_coordinates(games_variable[game_index]) + + state = State( + key=key, + fixed_grid=games_fixed[game_index], + variable_grid=games_variable[game_index], + agent_location=initial_agent_location, + step_count=jnp.array(0, jnp.int32), + ) + + return state + + +class SimpleSolveGenerator(Generator): + def __call__( + self, + key: chex.PRNGKey, + ) -> State: + """Generate a trivial Boxoban problem. + + Args: + key: the Jax random number generation key. + + Returns: + fixed_grid: Array (uint8) shape (num_rows, num_cols) the fixed + components of the problem. + variable_grid: Array (uint8) shape (num_rows, num_cols) the + variable components of the problem. + """ + level1 = [ + "##########", + "# ##", + "# .... 
#", + "# $$$$ ##", + "# @ # #", + "# # # ", + "# #", + "##########", + "##########", + "##########", + ] + + game_fixed, game_variable = convert_level_to_array(level1) + + initial_agent_location = self.get_agent_coordinates(game_variable) + + state = State( + key=key, + fixed_grid=game_fixed, + variable_grid=game_variable, + agent_location=initial_agent_location, + step_count=jnp.array(0, jnp.int32), + ) + + return state + + +def convert_level_to_array(level: List[str]) -> Tuple[chex.Array, chex.Array]: + """Converts text representation of levels to a tuple of Jax arrays + representing the fixed elements of the Boxoban problem and the variable + elements + + Args: + level: List of str representing a boxoban level. + + Returns: + fixed_grid: Array (uint8) shape (num_rows, num_cols) + the fixed components of the problem. + variable_grid: Array (uint8) shape (num_rows, + num_cols) the variable components of the problem. + """ + + # Define the mappings + mapping = { + "#": (1, 0), + ".": (2, 0), + "@": (0, 3), + "$": (0, 4), + " ": (0, 0), # empty cell + } + + fixed = [[mapping[cell][0] for cell in row] for row in level] + variable = [[mapping[cell][1] for cell in row] for row in level] + + return jnp.array(fixed, jnp.uint8), jnp.array(variable, jnp.uint8) diff --git a/jumanji/environments/routing/sokoban/generator_test.py b/jumanji/environments/routing/sokoban/generator_test.py new file mode 100644 index 000000000..0c766b632 --- /dev/null +++ b/jumanji/environments/routing/sokoban/generator_test.py @@ -0,0 +1,196 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+from typing import List
+
+import chex
+import jax
+import jax.numpy as jnp
+import pytest
+
+from jumanji.environments.routing.sokoban.env import Sokoban
+from jumanji.environments.routing.sokoban.generator import (
+    DeepMindGenerator,
+    HuggingFaceDeepMindGenerator,
+)
+
+
+def test_sokoban__hugging_generator_creation() -> None:
+    """checks we can create datasets for all valid boxoban datasets and
+    perform a jitted step"""
+
+    datasets = [
+        "unfiltered-train",
+        "unfiltered-test",
+        "unfiltered-valid",
+        "medium-train",
+        "medium-valid",
+        "hard",
+    ]
+
+    for dataset in datasets:
+
+        chex.clear_trace_counter()
+
+        env = Sokoban(
+            generator=HuggingFaceDeepMindGenerator(
+                dataset_name=dataset,
+                proportion_of_files=1,
+            )
+        )
+
+        step_fn = jax.jit(chex.assert_max_traces(env.step, n=1))
+
+        key = jax.random.PRNGKey(0)
+        state, timestep = env.reset(key)
+
+        step_count = 0
+        for _ in range(120):
+            action = jnp.array(random.randint(0, 4), jnp.int32)
+            state, timestep = step_fn(state, action)
+
+            # Check step_count increases after each step
+            step_count += 1
+            assert state.step_count == step_count
+            assert timestep.observation.step_count == step_count
+
+
+def test_sokoban__hugging_generator_different_problems() -> None:
+    """checks that resetting with different keys leads to different problems"""
+
+    chex.clear_trace_counter()
+
+    env = Sokoban(
+        generator=HuggingFaceDeepMindGenerator(
+            dataset_name="unfiltered-train",
+            proportion_of_files=1,
+        )
+    )
+
+    key1 = jax.random.PRNGKey(0)
+    state1, timestep1 = env.reset(key1)
+
+    key2 = jax.random.PRNGKey(1)
+    state2, timestep2 = env.reset(key2)
+
+    # Check that different resets lead to different problems
+    assert not jnp.array_equal(state2.fixed_grid, state1.fixed_grid)
+    assert not jnp.array_equal(state2.variable_grid, state1.variable_grid)
+
+
+def test_sokoban__hugging_generator_same_problems() -> None:
+    """checks that resetting with the same key leads to the same problems"""
+
+    chex.clear_trace_counter()
+
+    env = Sokoban(
+        generator=HuggingFaceDeepMindGenerator(
+            dataset_name="unfiltered-train",
+            proportion_of_files=1,
+        )
+    )
+
+    key1 = jax.random.PRNGKey(0)
+    state1, timestep1 = env.reset(key1)
+
+    key2 = jax.random.PRNGKey(0)
+    state2, timestep2 = env.reset(key2)
+
+    assert jnp.array_equal(state2.fixed_grid, state1.fixed_grid)
+    assert jnp.array_equal(state2.variable_grid, state1.variable_grid)
+
+
+def test_sokoban__hugging_generator_proportion_of_problems() -> None:
+    """checks that generator initialises correct number of problems"""
+
+    chex.clear_trace_counter()
+
+    unfiltered_dataset_size = 900000
+
+    generator_full = HuggingFaceDeepMindGenerator(
+        dataset_name="unfiltered-train",
+        proportion_of_files=1,
+    )
+
+    assert jnp.array_equal(
+        generator_full._fixed_grids.shape,
+        (unfiltered_dataset_size, 10, 10),
+    )
+
+    generator_half_full = HuggingFaceDeepMindGenerator(
+        dataset_name="unfiltered-train",
+        proportion_of_files=0.5,
+    )
+
+    assert jnp.array_equal(
+        generator_half_full._fixed_grids.shape,
+        (unfiltered_dataset_size / 2, 10, 10),
+    )
+
+
+def test_sokoban__deepmind_generator_creation() -> None:
+    """checks we can create datasets for all valid boxoban datasets"""
+
+    # Different datasets with varying proportion of files to keep runtime low
+    valid_datasets: List[List] = [
+        ["unfiltered", "train", 0.01],
+        ["unfiltered", "test", 1],
+        ["unfiltered", "valid", 0.02],
+        ["medium", "train", 0.01],
+        ["medium", "valid", 0.02],
+        ["hard", None, 1],
+    ]
+
+    for dataset in valid_datasets:
+
+        chex.clear_trace_counter()
+
+        env = Sokoban(
+            generator=DeepMindGenerator(
+                difficulty=dataset[0],
+                split=dataset[1],
+                proportion_of_files=dataset[2],
+            )
+        )
+
+        assert env.generator._fixed_grids.shape[0] > 0
+
+
+def test_sokoban__deepmind_invalid_creation() -> None:
+    """checks that asking for an invalid difficulty, split or proportion
+    leads to an exception"""
+
+    # Invalid difficulty/split combinations that should raise an exception
+    invalid_datasets: List[List] = [
+        ["medium", "test", 0.01],
+        ["mediumy", "train", 0.01],
+        ["hardy", "train", 0.01],
+        ["unfiltered", None, 0.01],
+    ]
+
+    for dataset in invalid_datasets:
+
+        chex.clear_trace_counter()
+
+        with pytest.raises(Exception):
+            _ = Sokoban(
+                generator=DeepMindGenerator(
+                    difficulty=dataset[0],
+                    split=dataset[1],
+                    proportion_of_files=dataset[2],
+                )
+            )
diff --git a/jumanji/environments/routing/sokoban/imgs/agent.png b/jumanji/environments/routing/sokoban/imgs/agent.png
new file mode 100644
index 000000000..00298ce9b
Binary files /dev/null and b/jumanji/environments/routing/sokoban/imgs/agent.png differ
diff --git a/jumanji/environments/routing/sokoban/imgs/agent_on_target.png b/jumanji/environments/routing/sokoban/imgs/agent_on_target.png
new file mode 100644
index 000000000..3e8310d25
Binary files /dev/null and b/jumanji/environments/routing/sokoban/imgs/agent_on_target.png differ
diff --git a/jumanji/environments/routing/sokoban/imgs/box.png b/jumanji/environments/routing/sokoban/imgs/box.png
new file mode 100644
index 000000000..9a2497df0
Binary files /dev/null and b/jumanji/environments/routing/sokoban/imgs/box.png differ
diff --git a/jumanji/environments/routing/sokoban/imgs/box_on_target.png b/jumanji/environments/routing/sokoban/imgs/box_on_target.png
new file mode 100644
index 000000000..74629af03
Binary files /dev/null and b/jumanji/environments/routing/sokoban/imgs/box_on_target.png differ
diff --git a/jumanji/environments/routing/sokoban/imgs/box_target.png 
b/jumanji/environments/routing/sokoban/imgs/box_target.png new file mode 100644 index 000000000..41cbc8f4e Binary files /dev/null and b/jumanji/environments/routing/sokoban/imgs/box_target.png differ diff --git a/jumanji/environments/routing/sokoban/imgs/floor.png b/jumanji/environments/routing/sokoban/imgs/floor.png new file mode 100644 index 000000000..79f053277 Binary files /dev/null and b/jumanji/environments/routing/sokoban/imgs/floor.png differ diff --git a/jumanji/environments/routing/sokoban/imgs/wall.png b/jumanji/environments/routing/sokoban/imgs/wall.png new file mode 100644 index 000000000..8da06c8d6 Binary files /dev/null and b/jumanji/environments/routing/sokoban/imgs/wall.png differ diff --git a/jumanji/environments/routing/sokoban/reward.py b/jumanji/environments/routing/sokoban/reward.py new file mode 100644 index 000000000..775122293 --- /dev/null +++ b/jumanji/environments/routing/sokoban/reward.py @@ -0,0 +1,120 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
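The reward functions defined in this file combine three terms: a per-step penalty, a bonus per box newly placed on a target, and a completion bonus. A minimal pure-Python sketch of the dense rule, assuming the bonus values implied by the tests (+1 per box, +10 on completion, -0.1 per step; the constant names mirror `sokoban.constants` but the values here are assumptions):

```python
# Sketch of the dense reward: the net change in boxes-on-targets is
# rewarded, a large bonus is paid when all N_BOXES land on targets,
# and every step costs a small penalty.
SINGLE_BOX_BONUS = 1.0
LEVEL_COMPLETE_BONUS = 10.0
STEP_BONUS = -0.1
N_BOXES = 4


def dense_reward(boxes_on_targets_before, boxes_on_targets_after):
    """Reward for one transition, given box-on-target counts before/after."""
    level_completed = boxes_on_targets_after == N_BOXES
    return (
        SINGLE_BOX_BONUS * (boxes_on_targets_after - boxes_on_targets_before)
        + LEVEL_COMPLETE_BONUS * level_completed
        + STEP_BONUS
    )


assert abs(dense_reward(0, 0) - (-0.1)) < 1e-9   # ordinary step
assert abs(dense_reward(1, 2) - 0.9) < 1e-9      # box pushed onto a target
assert abs(dense_reward(2, 1) - (-1.1)) < 1e-9   # box pushed off a target
assert abs(dense_reward(3, 4) - 10.9) < 1e-9     # final box solves the level
```

The sparse variant keeps only the `LEVEL_COMPLETE_BONUS` term, so the agent is rewarded exclusively for solving the level.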
+ +import abc + +import chex +import jax.numpy as jnp + +from jumanji.environments.routing.sokoban.constants import ( + BOX, + LEVEL_COMPLETE_BONUS, + N_BOXES, + SINGLE_BOX_BONUS, + STEP_BONUS, + TARGET, +) +from jumanji.environments.routing.sokoban.types import State + + +class RewardFn(abc.ABC): + @abc.abstractmethod + def __call__( + self, + state: State, + action: chex.Numeric, + next_state: State, + ) -> chex.Numeric: + """Compute the reward based on the current state, + the chosen action, the next state. + """ + + def count_targets(self, state: State) -> chex.Array: + """ + Calculates the number of boxes on targets. + + Args: + state: `State` object representing the current state of the + environment. + + Returns: + n_targets: Array (int32) of shape () specifying the number of boxes + on targets. + """ + + mask_box = state.variable_grid == BOX + mask_target = state.fixed_grid == TARGET + + num_boxes_on_targets = jnp.sum(mask_box & mask_target) + + return num_boxes_on_targets + + +class SparseReward(RewardFn): + def __call__( + self, + state: State, + action: chex.Array, + next_state: State, + ) -> chex.Array: + """ + Implements the sparse reward function in the Sokoban environment. + + Args: + state: `State` object The current state of the environment. + action: Array (int32) shape () representing the action taken. + next_state: `State` object The next state of the environment. + + Returns: + reward: Array (float32) of shape () specifying the reward received + at transition + """ + + next_num_box_target = self.count_targets(next_state) + + level_completed = next_num_box_target == N_BOXES + + return LEVEL_COMPLETE_BONUS * level_completed + + +class DenseReward(RewardFn): + def __call__( + self, + state: State, + action: chex.Array, + next_state: State, + ) -> chex.Array: + """ + Implements the dense reward function in the Sokoban environment. + + Args: + state: `State` object The current state of the environment. 
+ action: Array (int32) shape () representing the action taken. + next_state: `State` object representing the next state of the + environment. + + Returns: + reward: Array (float32) of shape () specifying the reward received + at the transition. + """ + + num_box_target = self.count_targets(state) + next_num_box_target = self.count_targets(next_state) + + level_completed = next_num_box_target == N_BOXES + + return ( + SINGLE_BOX_BONUS * (next_num_box_target - num_box_target) + + LEVEL_COMPLETE_BONUS * level_completed + + STEP_BONUS + ) diff --git a/jumanji/environments/routing/sokoban/reward_test.py b/jumanji/environments/routing/sokoban/reward_test.py new file mode 100644 index 000000000..c8ae5b08b --- /dev/null +++ b/jumanji/environments/routing/sokoban/reward_test.py @@ -0,0 +1,74 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import random + +import chex +import jax +import jax.numpy as jnp +import pytest + +from jumanji.environments.routing.sokoban.env import Sokoban +from jumanji.environments.routing.sokoban.generator import SimpleSolveGenerator +from jumanji.types import TimeStep + + +@pytest.fixture(scope="session") +def sokoban_simple() -> Sokoban: + env = Sokoban(generator=SimpleSolveGenerator()) + return env + + +def test_sokoban__reward_function_random(sokoban_simple: Sokoban) -> None: + """Check the reward function is correct when acting randomly in the + trivial problem, where accidentally pushing boxes onto targets is likely. + Every step should give -0.1, each box pushed onto a target adds 1, each + box removed from a target subtracts 1, and solving adds an additional + 10.""" + + def check_correct_reward( + timestep: TimeStep, + num_boxes_on_targets_new: chex.Array, + num_boxes_on_targets: chex.Array, + ) -> None: + + if num_boxes_on_targets_new == jnp.array(4, jnp.int32): + assert timestep.reward == jnp.array(10.9, jnp.float32) + elif num_boxes_on_targets_new - num_boxes_on_targets > jnp.array(0, jnp.int32): + assert timestep.reward == jnp.array(0.9, jnp.float32) + elif num_boxes_on_targets_new - num_boxes_on_targets < jnp.array(0, jnp.int32): + assert timestep.reward == jnp.array(-1.1, jnp.float32) + else: + assert timestep.reward == jnp.array(-0.1, jnp.float32) + + for i in range(5): + chex.clear_trace_counter() + step_fn = jax.jit(chex.assert_max_traces(sokoban_simple.step, n=1)) + + key = jax.random.PRNGKey(i) + reset_key, step_key = jax.random.split(key) + state, timestep = sokoban_simple.reset(reset_key) + + num_boxes_on_targets = sokoban_simple.reward_fn.count_targets(state) + + for _ in range(120): + action = jnp.array(random.randint(0, 4), jnp.int32) + state, timestep = step_fn(state, action) + + num_boxes_on_targets_new = sokoban_simple.reward_fn.count_targets(state) + + check_correct_reward( + timestep, num_boxes_on_targets_new, num_boxes_on_targets + ) + + num_boxes_on_targets = 
num_boxes_on_targets_new diff --git a/jumanji/environments/routing/sokoban/types.py b/jumanji/environments/routing/sokoban/types.py new file mode 100644 index 000000000..eeb561e76 --- /dev/null +++ b/jumanji/environments/routing/sokoban/types.py @@ -0,0 +1,53 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, NamedTuple + +import chex + +if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239 + from dataclasses import dataclass +else: + from chex import dataclass + + +@dataclass +class State: + """ + key: random key used for auto-reset. + fixed_grid: Array (uint8) shape (n_rows, n_cols) array representing the + fixed elements of a sokoban problem. + variable_grid: Array (uint8) shape (n_rows, n_cols) array representing the + variable elements of a sokoban problem. + agent_location: Array (int32) shape (2,) + step_count: Array (int32) shape () + """ + + key: chex.PRNGKey + fixed_grid: chex.Array + variable_grid: chex.Array + agent_location: chex.Array + step_count: chex.Array + + +class Observation(NamedTuple): + """ + The observation returned by the sokoban environment. + grid: Array (uint8) shape (n_rows, n_cols, 2) array representing the + variable and fixed grids. + step_count: Array (int32) shape () the index of the current step. 
+ """ + + grid: chex.Array + step_count: chex.Array diff --git a/jumanji/environments/routing/sokoban/viewer.py b/jumanji/environments/routing/sokoban/viewer.py new file mode 100644 index 000000000..db5716704 --- /dev/null +++ b/jumanji/environments/routing/sokoban/viewer.py @@ -0,0 +1,219 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Sequence, Tuple + +import chex +import matplotlib.animation +import matplotlib.cm +import matplotlib.pyplot as plt +import numpy as np +import pkg_resources +from numpy.typing import NDArray +from PIL import Image + +import jumanji.environments +from jumanji.viewer import Viewer + + +class BoxViewer(Viewer): + FIGURE_SIZE = (10.0, 10.0) + + def __init__( + self, + name: str, + grid_combine: Callable, + ) -> None: + """ + Viewer for a `Sokoban` environment using images from + https://github.com/mpSchrader/gym-sokoban. + + Args: + name: the window name to be used when initialising the window. 
+ grid_combine: function for combining fixed_grid and variable grid + """ + self._name = name + self.NUM_COLORS = 10 + self.grid_combine = grid_combine + self._display = self._display_rgb_array + self._animation: Optional[matplotlib.animation.Animation] = None + + image_names = [ + "floor", + "wall", + "box_target", + "agent", + "box", + "agent_on_target", + "box_on_target", + ] + + def get_image(image_name: str) -> Image.Image: + img_path = pkg_resources.resource_filename( + "jumanji", f"environments/routing/sokoban/imgs/{image_name}.png" + ) + return Image.open(img_path) + + self.images = [get_image(image_name) for image_name in image_names] + + def render(self, state: chex.Array) -> Optional[NDArray]: + """Render the given state of the `Sokoban` environment. + + Args: + state: the environment state to render. + """ + + self._clear_display() + fig, ax = self._get_fig_ax() + ax.clear() + self._add_grid_image(state, ax) + return self._display(fig) + + def animate( + self, + states: Sequence[chex.Array], + interval: int = 200, + save_path: Optional[str] = None, + ) -> matplotlib.animation.FuncAnimation: + """Create an animation from a sequence of environment states. + + Args: + states: sequence of environment states corresponding to + consecutive timesteps. + interval: delay between frames in milliseconds, default to 200. + save_path: the path where the animation file should be saved. If + it is None, the plot will not be saved. + + Returns: + Animation that can be saved as a GIF, MP4, or rendered with HTML. + """ + fig, ax = plt.subplots( + num=f"{self._name}Animation", figsize=BoxViewer.FIGURE_SIZE + ) + plt.close(fig) + + def make_frame(state_index: int) -> None: + ax.clear() + state = states[state_index] + self._add_grid_image(state, ax) + + # Create the animation object. + self._animation = matplotlib.animation.FuncAnimation( + fig, + make_frame, + frames=len(states), + interval=interval, + ) + + # Save the animation as a gif. 
+ if save_path: + self._animation.save(save_path) + + return self._animation + + def close(self) -> None: + """Perform any necessary cleanup. + + Environments will automatically :meth:`close()` themselves when + garbage collected or when the program exits. + """ + plt.close(self._name) + + def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: + """ + Fetch or create a matplotlib figure and its associated axes. + + Returns: + fig: (plt.Figure) A matplotlib figure object. + axes: (plt.Axes) The axes associated with the figure. + """ + recreate = not plt.fignum_exists(self._name) + fig = plt.figure(self._name, BoxViewer.FIGURE_SIZE) + if recreate: + if not plt.isinteractive(): + fig.show() + ax = fig.add_subplot() + else: + ax = fig.get_axes()[0] + return fig, ax + + def _add_grid_image(self, state: chex.Array, ax: plt.Axes) -> None: + """ + Add a grid image to the provided axes. + + Args: + state: `State` object representing a state of Sokoban. + ax: (plt.Axes) object where the state image will be added. + """ + grid = self.grid_combine(state.variable_grid, state.fixed_grid) + + self._draw_grid(grid, ax) + ax.set_axis_off() + ax.set_aspect(1) + ax.relim() + ax.autoscale_view() + + def _draw_grid(self, grid: chex.Array, ax: plt.Axes) -> None: + """ + Draw a grid onto the provided axes. + + Args: + grid: Array of shape (n_rows, n_cols) representing the combined + fixed and variable grids. + ax: (plt.Axes) The axes on which to draw the grid. + """ + + cols, rows = grid.shape + + for col in range(cols): + for row in range(rows): + self._draw_grid_cell(grid[row, col], 9 - row, col, ax) + + def _draw_grid_cell( + self, cell_value: int, row: int, col: int, ax: plt.Axes + ) -> None: + """ + Draw a single cell of the grid. + + Args: + cell_value: int representing the cell's value determining its image. + row: int representing the cell's row index. + col: int representing the cell's col index. + ax: (plt.Axes) The axes on which to draw the cell. 
+ """ + cell_value = int(cell_value) + image = self.images[cell_value] + ax.imshow(image, extent=(col, col + 1, row, row + 1)) + + def _clear_display(self) -> None: + """ + Clear the current notebook display if the environment is a notebook. + """ + + if jumanji.environments.is_notebook(): + import IPython.display + + IPython.display.clear_output(True) + + def _display_rgb_array(self, fig: plt.Figure) -> NDArray: + """ + Convert the given figure to an RGB array. + + Args: + fig: (plt.Figure) The figure to be converted. + + Returns: + NDArray: The RGB array representation of the figure. + """ + fig.canvas.draw() + return np.asarray(fig.canvas.buffer_rgba()) diff --git a/jumanji/environments/routing/tsp/env.py b/jumanji/environments/routing/tsp/env.py index 83236f872..0428e646c 100644 --- a/jumanji/environments/routing/tsp/env.py +++ b/jumanji/environments/routing/tsp/env.py @@ -80,7 +80,7 @@ class TSP(Environment[State]): ```python from jumanji.environments import TSP env = TSP() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec().generate_value() diff --git a/jumanji/training/configs/config.yaml b/jumanji/training/configs/config.yaml index a110c89b8..458d60b07 100644 --- a/jumanji/training/configs/config.yaml +++ b/jumanji/training/configs/config.yaml @@ -1,6 +1,6 @@ defaults: - _self_ - - env: snake # [bin_pack, cleaner, connector, cvrp, game_2048, graph_coloring, job_shop, knapsack, maze, minesweeper, mmst, multi_cvrp, pac_man, robot_warehouse, rubiks_cube, snake, sudoku, tetris, tsp] + - env: snake # [bin_pack, cleaner, connector, cvrp, game_2048, graph_coloring, job_shop, knapsack, maze, minesweeper, mmst, multi_cvrp, pac_man, robot_warehouse, rubiks_cube, snake, sokoban, sudoku, tetris, tsp] agent: random # [random, a2c] diff --git a/jumanji/training/configs/env/sokoban.yaml b/jumanji/training/configs/env/sokoban.yaml new file mode 100644 index 000000000..6baf4f2d0 --- 
/dev/null +++ b/jumanji/training/configs/env/sokoban.yaml @@ -0,0 +1,26 @@ +name: sokoban +registered_version: Sokoban-v0 + +network: + channels: [256,256,512,512] + policy_layers: [64, 64] + value_layers: [128, 128] + +training: + num_epochs: 1000 + num_learner_steps_per_epoch: 500 + n_steps: 20 + total_batch_size: 128 + +evaluation: + eval_total_batch_size: 1024 + greedy_eval_total_batch_size: 1024 + +a2c: + normalize_advantage: True + discount_factor: 0.97 + bootstrapping_factor: 0.95 + l_pg: 1.0 + l_td: 1.0 + l_en: 0.01 + learning_rate: 3e-4 diff --git a/jumanji/training/networks/__init__.py b/jumanji/training/networks/__init__.py index 1fb1df18d..82ad0ae65 100644 --- a/jumanji/training/networks/__init__.py +++ b/jumanji/training/networks/__init__.py @@ -78,6 +78,10 @@ make_actor_critic_networks_snake, ) from jumanji.training.networks.snake.random import make_random_policy_snake +from jumanji.training.networks.sokoban.actor_critic import ( + make_actor_critic_networks_sokoban, +) +from jumanji.training.networks.sokoban.random import make_random_policy_sokoban from jumanji.training.networks.sudoku.actor_critic import ( make_cnn_actor_critic_networks_sudoku, make_equivariant_actor_critic_networks_sudoku, diff --git a/jumanji/training/networks/sokoban/__init__.py b/jumanji/training/networks/sokoban/__init__.py new file mode 100644 index 000000000..21db9ec1c --- /dev/null +++ b/jumanji/training/networks/sokoban/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jumanji/training/networks/sokoban/actor_critic.py b/jumanji/training/networks/sokoban/actor_critic.py new file mode 100644 index 000000000..968180c60 --- /dev/null +++ b/jumanji/training/networks/sokoban/actor_critic.py @@ -0,0 +1,115 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Sequence + +import chex +import haiku as hk +import jax +import jax.numpy as jnp + +from jumanji.environments.routing.sokoban import Observation, Sokoban +from jumanji.training.networks.actor_critic import ( + ActorCriticNetworks, + FeedForwardNetwork, +) +from jumanji.training.networks.parametric_distribution import ( + CategoricalParametricDistribution, +) + + +def make_actor_critic_networks_sokoban( + sokoban: Sokoban, + channels: Sequence[int], + policy_layers: Sequence[int], + value_layers: Sequence[int], +) -> ActorCriticNetworks: + """Make actor-critic networks for the `Sokoban` environment.""" + num_actions = sokoban.action_spec().num_values + parametric_action_distribution = CategoricalParametricDistribution( + num_actions=num_actions + ) + + policy_network = make_sokoban_cnn( + num_outputs=num_actions, + mlp_units=policy_layers, + channels=channels, + time_limit=sokoban.time_limit, + ) + value_network = make_sokoban_cnn( + num_outputs=1, + mlp_units=value_layers, + channels=channels, + 
time_limit=sokoban.time_limit, + ) + + return ActorCriticNetworks( + policy_network=policy_network, + value_network=value_network, + parametric_action_distribution=parametric_action_distribution, + ) + + +def make_sokoban_cnn( + num_outputs: int, + mlp_units: Sequence[int], + channels: Sequence[int], + time_limit: int, +) -> FeedForwardNetwork: + def network_fn(observation: Observation) -> chex.Array: + + # Iterate over the channels sequence to create convolutional layers + layers = [] + for i, conv_n_channels in enumerate(channels): + layers.append(hk.Conv2D(conv_n_channels, (3, 3), stride=2 if i == 0 else 1)) + layers.append(jax.nn.relu) + + layers.append(hk.Flatten()) + + torso = hk.Sequential(layers) + + x_processed = preprocess_input(observation.grid) + + embedding = torso(x_processed) + + norm_step_count = jnp.expand_dims(observation.step_count / time_limit, axis=-1) + embedding = jnp.concatenate([embedding, norm_step_count], axis=-1) + head = hk.nets.MLP((*mlp_units, num_outputs), activate_final=False) + if num_outputs == 1: + value = jnp.squeeze(head(embedding), axis=-1) + return value + else: + logits = head(embedding) + + return logits + + init, apply = hk.without_apply_rng(hk.transform(network_fn)) + return FeedForwardNetwork(init=init, apply=apply) + + +def preprocess_input( + input_array: chex.Array, +) -> chex.Array: + + one_hot_array_fixed = jnp.equal(input_array[..., 0:1], jnp.array([3, 4])).astype( + jnp.float32 + ) + + one_hot_array_variable = jnp.equal(input_array[..., 1:2], jnp.array([1, 2])).astype( + jnp.float32 + ) + + total = jnp.concatenate((one_hot_array_fixed, one_hot_array_variable), axis=-1) + + return total diff --git a/jumanji/training/networks/sokoban/random.py b/jumanji/training/networks/sokoban/random.py new file mode 100644 index 000000000..8b428174f --- /dev/null +++ b/jumanji/training/networks/sokoban/random.py @@ -0,0 +1,35 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import chex +import jax +import jax.numpy as jnp + +from jumanji.environments.routing.sokoban import Observation +from jumanji.training.networks.protocols import RandomPolicy + + +def categorical_random( + observation: Observation, + key: chex.PRNGKey, +) -> chex.Array: + logits = jnp.zeros(shape=(observation.grid.shape[0], 4)) + + action = jax.random.categorical(key, logits) + return action + + +def make_random_policy_sokoban() -> RandomPolicy: + """Make random policy for the `Sokoban` environment.""" + return categorical_random diff --git a/jumanji/training/setup_train.py b/jumanji/training/setup_train.py index dee956345..e2d2b9890 100644 --- a/jumanji/training/setup_train.py +++ b/jumanji/training/setup_train.py @@ -40,6 +40,7 @@ RobotWarehouse, RubiksCube, Snake, + Sokoban, Sudoku, Tetris, ) @@ -178,6 +179,9 @@ def _setup_random_policy( # noqa: CCR001 elif cfg.env.name == "maze": assert isinstance(env.unwrapped, Maze) random_policy = networks.make_random_policy_maze() + elif cfg.env.name == "sokoban": + assert isinstance(env.unwrapped, Sokoban) + random_policy = networks.make_random_policy_sokoban() elif cfg.env.name == "connector": assert isinstance(env.unwrapped, Connector) random_policy = networks.make_random_policy_connector() @@ -326,6 +330,14 @@ def _setup_actor_critic_neworks( # noqa: CCR001 policy_layers=cfg.env.network.policy_layers, value_layers=cfg.env.network.value_layers, ) + elif cfg.env.name == 
"sokoban": + assert isinstance(env.unwrapped, Sokoban) + actor_critic_networks = networks.make_actor_critic_networks_sokoban( + sokoban=env.unwrapped, + channels=cfg.env.network.channels, + policy_layers=cfg.env.network.policy_layers, + value_layers=cfg.env.network.value_layers, + ) elif cfg.env.name == "cleaner": assert isinstance(env.unwrapped, Cleaner) actor_critic_networks = networks.make_actor_critic_networks_cleaner( diff --git a/mkdocs.yml b/mkdocs.yml index 33c290c91..39dbca1cd 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -38,6 +38,7 @@ nav: - RobotWarehouse: environments/robot_warehouse.md - Snake: environments/snake.md - TSP: environments/tsp.md + - Sokoban: environments/sokoban.md - PacMan: environments/pac_man.md - User Guides: - Advanced Usage: guides/advanced_usage.md @@ -68,6 +69,7 @@ nav: - RobotWarehouse: api/environments/robot_warehouse.md - Snake: api/environments/snake.md - TSP: api/environments/tsp.md + - Sokoban: api/environments/sokoban.md - PacMan: api/environments/pac_man.md - Wrappers: api/wrappers.md - Types: api/types.md diff --git a/pyproject.toml b/pyproject.toml index a6321cfd9..72558473a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,5 +41,6 @@ module = [ "haiku.*", "hydra.*", "omegaconf.*", + "huggingface_hub.*", ] ignore_missing_imports = true diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index df33e2e2b..ea0b437c2 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -24,3 +24,6 @@ pytest-xdist pytype scipy>=1.7.3 testfixtures +types-Pillow +types-requests +types-setuptools
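The arithmetic that `reward_test.py` asserts can be checked against the `DenseReward` logic added in `reward.py` with a small pure-Python sketch. This is not the Jumanji implementation: the `BOX` and `TARGET` cell codes below are placeholders (the real values live in `sokoban/constants.py`), and the bonus constants are inferred from the expected rewards in the test (-0.1 per step, +1 per box pushed onto a target, +10 on completion).

```python
# Pure-Python sketch of the dense reward; BOX/TARGET cell codes are
# placeholders, bonus values inferred from reward_test.py.
BOX, TARGET = 2, 3
N_BOXES = 4
SINGLE_BOX_BONUS, LEVEL_COMPLETE_BONUS, STEP_BONUS = 1.0, 10.0, -0.1


def count_targets(variable_grid, fixed_grid):
    """Number of boxes currently sitting on target cells."""
    return sum(
        v == BOX and f == TARGET
        for v_row, f_row in zip(variable_grid, fixed_grid)
        for v, f in zip(v_row, f_row)
    )


def dense_reward(num_before, num_after):
    """Reward for one transition, mirroring DenseReward.__call__."""
    level_completed = num_after == N_BOXES
    return (
        SINGLE_BOX_BONUS * (num_after - num_before)
        + LEVEL_COMPLETE_BONUS * level_completed
        + STEP_BONUS
    )


# Pushing the 4th box onto its target: +1 (box) + 10 (solve) - 0.1 (step) = 10.9
reward = dense_reward(3, 4)
```

This reproduces the four cases distinguished in `test_sokoban__reward_function_random`: 10.9 on solving, 0.9 for a box gained, -1.1 for a box lost, and -0.1 otherwise.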