Contrastive Unpaired Translation (CUT)

This is a TensorFlow 2 implementation of Contrastive Learning for Unpaired Image-to-Image Translation.

Contrastive Unpaired Translation (CUT) is built on a contrastive learning framework: the goal is to associate corresponding input and output patches. A "query" is a patch from the output image; its positive is the input patch at the same location, and the negatives are non-corresponding input patches. Compared to CycleGAN, CUT enables one-sided translation while improving quality and reducing training time.
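
For intuition, below is a minimal sketch of the patch-wise InfoNCE-style objective described above, written in TensorFlow 2. The function name, feature shapes, and temperature value are illustrative assumptions, not this repository's actual API.

import tensorflow as tf

def patch_nce_loss(feat_src, feat_tgt, temperature=0.07):
    """Illustrative patch-wise contrastive loss.

    feat_src: [N, C] features of N input patches (positives/negatives)
    feat_tgt: [N, C] features of the output patches at the same locations ("queries")
    The i-th input patch is the positive for the i-th query; all other
    input patches in the batch act as negatives.
    """
    feat_src = tf.math.l2_normalize(feat_src, axis=1)
    feat_tgt = tf.math.l2_normalize(feat_tgt, axis=1)

    # Similarity of every query to every input patch: [N, N]
    logits = tf.matmul(feat_tgt, feat_src, transpose_b=True) / temperature

    # Diagonal entries correspond to the positive (matching) patches.
    labels = tf.range(tf.shape(logits)[0])
    loss = tf.keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True)
    return tf.reduce_mean(loss)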

Translated examples of summer2winter

Training

Use train.py to train a CUT/FastCUT model on a given dataset. Training takes 340 ms (CUDA ops) / 400 ms (TensorFlow ops) per step on a GTX 1080 Ti.

Example usage for training on horse2zebra-dataset:

python train.py --mode cut                                    \
                --save_n_epoch 10                             \
                --train_src_dir ./datasets/horse2zebra/trainA \
                --train_tar_dir ./datasets/horse2zebra/trainB \
                --test_src_dir ./datasets/horse2zebra/testA   \
                --test_tar_dir ./datasets/horse2zebra/testB

Inference

Use inference.py to translate images from the source domain to the target domain. The pre-trained weights are located here.

Example usage:

python inference.py --mode cut                            \
                    --weights ./output/checkpoints        \
                    --input ./datasets/horse2zebra/testA
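
For scripted use, here is a minimal sketch of running the translation directly from Python. The checkpoint layout, image size, and file paths are illustrative assumptions, not the repository's actual interface.

import tensorflow as tf

# Restore the trained generator; the SavedModel path here is an assumption,
# not necessarily the format produced by train.py.
generator = tf.keras.models.load_model("./output/checkpoints/generator")

def translate(path, size=256):
    """Load one source-domain image, run the generator, save the result."""
    img = tf.io.decode_image(tf.io.read_file(path), channels=3, expand_animations=False)
    img = tf.image.resize(img, (size, size))
    img = tf.cast(img, tf.float32) / 127.5 - 1.0            # scale to [-1, 1]
    out = generator(img[tf.newaxis, ...], training=False)   # [1, H, W, 3]
    out = tf.cast((out[0] + 1.0) * 127.5, tf.uint8)          # back to [0, 255]
    tf.io.write_file("translated.png", tf.io.encode_png(out))

translate("./datasets/horse2zebra/testA/example.jpg")  # illustrative input path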

Qualitative comparisons between this implementation and the results from the paper.

Requirements

You will need the following to run the above:

  • TensorFlow >= 2.0
  • Python 3, NumPy 1.18, Matplotlib 3.3.1
  • If you want to use custom TensorFlow ops:

Acknowledgements