This repository contains a PyTorch implementation of the Multi-Object Network (MONet). MONet is a model trained to explain a scene in a fixed number of steps, which allows it to reconstruct each object separately, even when it is partly occluded by other objects.
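
For intuition about the "fixed number of steps", here is a minimal sketch of the masking recursion described in the MONet paper. It is not this repository's code. It follows one pixel only, random numbers stand in for the attention network's outputs, and the name `decompose_scope` is hypothetical. Each step claims a share of the still-unexplained "scope", and the last step takes whatever remains, so the masks always sum to 1.

```python
import random


def decompose_scope(alphas):
    """Turn per-step attention values in [0, 1] into masks that sum to 1.

    alphas holds K-1 values for a single pixel; the K-th (final) step
    receives the leftover scope instead of an attention value.
    """
    scope = 1.0  # fraction of the pixel not yet explained
    masks = []
    for alpha in alphas:
        masks.append(scope * alpha)  # this step claims part of the remaining scope
        scope *= 1.0 - alpha         # what is left for the later steps
    masks.append(scope)              # the last step explains whatever remains
    return masks


random.seed(0)
num_slots = 5  # matches the default number of slots mentioned in this README
alphas = [random.random() for _ in range(num_slots - 1)]  # stand-in for network output
masks = decompose_scope(alphas)
print(len(masks), sum(masks))
```

In the full model the same recursion runs over every pixel, and each mask is paired with a VAE that reconstructs the part of the image that the mask selects.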
Create a conda environment with all the requirements (edit environment.yml if you want to change the name of the environment):

    conda env create -f environment.yml

Activate the environment:

    source activate pytorch

We use Sacred to log the experiments and as a command-line interface. To generate the sprites dataset, run the following from the data folder:

    python data.py generate_sprites_multi

With the default options, the training script trains MONet with 5 slots, using a VAE with a latent dimension of 10. Training takes around 4 hours on a GPU:

    python train.py

Also check out the notebooks folder for examples with pretrained models.