Skip to content

Repository containing the code for the paper "Explicitly Disentangled Representations in Object-Centric Learning".

Notifications You must be signed in to change notification settings

riccardomajellaro/disentangled-slot-attention

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

36 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Disentangled Slot Attention (DISA)

Repository containing the code for the paper "Explicitly Disentangled Representations in Object-Centric Learning".

DISA-Architecture



Prepare datasets

Use the commands below to download and read the TFRecord of a dataset. Replace RECORD_PATH with the path to the directory where you want to download the TFRecord, and DATA_PATH with the path to the directory where you want to store the dataset.

Tetrominoes

wget https://storage.googleapis.com/multi-object-datasets/tetrominoes/tetrominoes_train.tfrecords -P RECORD_PATH
python data/tetrominoes.py --tfrecord_path RECORD_PATH --data_path DATA_PATH

Multi-dSprites

wget https://storage.googleapis.com/multi-object-datasets/multi_dsprites/multi_dsprites_colored_on_colored.tfrecords -P RECORD_PATH
python data/multidsprites.py --tfrecord_path RECORD_PATH --data_path DATA_PATH

CLEVR

wget https://storage.googleapis.com/multi-object-datasets/clevr_with_masks/clevr_with_masks_train.tfrecords -P RECORD_PATH
python data/clevr.py --tfrecord_path RECORD_PATH --data_path DATA_PATH

CLEVR6

To filter CLEVR into CLEVR6 (maximum of 6 objects + the background), replace CLEVR_PATH with the path to the directory where the CLEVR dataset is already stored, and replace CLEVR6_PATH with the path to the directory where you want to store the filtered CLEVR6 dataset.

python data/clevr6.py --clevr_path CLEVR_PATH --clevr6_path CLEVR6_PATH

Train

Before running the commands below, move to the DISA directory. In addition, set the key data_path of the configurations in configs/objdisc_configs.json and configs/proppred_configs.json to the paths to the related datasets.

Object discovery

Use the command below to train a model on the object discovery task. Replace CONFIG_NAME with the name of the desired configuration (configs/objdisc_configs.json) to run.

python -m training.train --config CONFIG_NAME

If you need to use Distributed Data Parallel (DDP), replace training.train with training.train_dist and add the key num_gpus set to the desired integer in the configuration file.

Property prediction

Use the command below to run the property prediction task. Replace CONFIG_NAME with the name of the desired configuration (configs/proppred_configs.json) to run.

python -m training.prop_pred --config CONFIG_NAME

Evaluate

Before running the commands below, move to the DISA directory.

Object discovery

Use the command below to evaluate a model on the object discovery task (BG-ARI, FG-ARI, MSE). Replace CONFIG_NAME with the name of the desired configuration (configs/objdisc_configs.json) to run, and CKPT_NAME with the name of the checkpoint to load (without .ckpt at the end).

python -m evaluation.obj_disc --config CONFIG_NAME --init_ckpt --CKPT_NAME

If you evaluate a model that samples initial slots and/or position/scale embeddings, add the key reps and set it to, e.g., 10 in order to evaluate each image in the test set 10 times.

Property prediction

Use the command below to evaluate a configuration on the property prediction task. Replace CONFIG_NAME with the name of the desired configuration (configs/proppred_configs.json) to run.

python -m evaluation.prop_pred --config CONFIG_NAME

About

Repository containing the code for the paper "Explicitly Disentangled Representations in Object-Centric Learning".

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published