Skip to content

Commit

Permalink
Add argument parser
Browse files Browse the repository at this point in the history
  • Loading branch information
ThanosM97 committed Aug 18, 2022
1 parent e071542 commit a863380
Showing 1 changed file with 155 additions and 3 deletions.
158 changes: 155 additions & 3 deletions generation/distylegan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
(https://github.com/NVlabs/stylegan2-ada-pytorch), for the task of
conditional image generation on CIFAR-10.
"""
import argparse
import json
import random
from datetime import datetime
Expand Down Expand Up @@ -65,6 +66,7 @@ def __init__(
transform: Callable = None,
num_test: int = 30,
device: str = None,
**kwargs
) -> None:
""" Initialize the DiStyleGAN model.
Expand Down Expand Up @@ -400,7 +402,8 @@ def train(
lr_decay: int = 350000,
checkpoint_interval: int = 20,
checkpoint_path: str = None,
num_workers: int = 0
num_workers: int = 0,
**kwargs
):
"""Train DiStyleGAN.
Expand Down Expand Up @@ -568,7 +571,8 @@ def generate(
nsamples: int,
label: "int | list[int]" = None,
save: str = None,
batch_size: int = 32
batch_size: int = 32,
**kwargs
) -> torch.Tensor:
"""Generate images using a pre-trained model's checkpoint.
Expand All @@ -578,7 +582,8 @@ def generate(
- nsamples (int) : number of samples to generate
- label (int, list[int], optional) : class label for the samples
(Default: None, random labels)
- save (str) : path to save the generated images (Default: None)
- save (str, optional) : path to save the generated images
(Default: None)
- batch_size (int, optional) : number of samples per batch
(Default: 32)
"""
Expand Down Expand Up @@ -643,3 +648,150 @@ def generate(
all_images = torch.stack(all_images)

return all_images


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="""Train DiStyleGAN or generate images
using a pretrained model.""")
subparsers = parser.add_subparsers(help="commands", dest="command")

# Create the parser for the "train" command
parser_train = subparsers.add_parser('train', help="""Train DiStyleGAN from
scratch, or use a checkpoint""")
train_required = parser_train.add_argument_group(
'Required arguments for the training procedure')
train_optional_network = parser_train.add_argument_group(
'Optional arguments about the network configuration')
train_optional = parser_train.add_argument_group(
'Optional arguments about the training procedure')

train_required.add_argument(
'--dataset', type=str, required=True,
help="""Path to the dataset directory of the fake CIFAR10 data
generated by the teacher network""")
train_required.add_argument(
'--save', type=str, required=True,
help="""Path to save checkpoints and results""")

train_optional_network.add_argument(
'--c_dim', type=int, default=10,
help='Condition dimension (Default: 10)')
train_optional_network.add_argument(
'--lambda_ganD', type=float, default=0.2,
help="""Weight for the adversarial GAN loss of the Discriminator
(Default: 0.2)""")
train_optional_network.add_argument(
'--lambda_ganG', type=float, default=0.01,
help="""Weight for the adversarial distillation loss of the Generator
(Default: 0.01)""")
train_optional_network.add_argument(
'--lambda_pixel', type=float, default=0.2,
help='Weight for the pixel loss of the Generator (Default: 0.2)')
train_optional_network.add_argument(
'--nc', type=int, default=3,
help="""Number of channels for the images (Default: 3)""")
train_optional_network.add_argument(
'--ndf', type=int, default=128,
help="""Number of discriminator filters in the first convolutional
layer (Default: 128)""")
train_optional_network.add_argument(
'--ngf', type=int, default=256,
help="""Number of generator filters in the first convolutional layer
(Default: 256)""")
train_optional_network.add_argument(
'--project_dim', type=int, default=128,
help="""Dimension to project the input condition (Default: 128)""")
train_optional_network.add_argument(
'--transform', type=callable, default=None,
help="""Optional transform to be applied on a sample image
(Default: None)""")
train_optional_network.add_argument(
'--z_dim', type=int, default=512,
help='Noise dimension (Default: 512)')

train_optional.add_argument(
'--adam_momentum', type=float, default=0.5,
help="""Momentum value for the Adam optimizers' betas
(Default: 0.5)""")
train_optional.add_argument(
'--batch_size', type=int, default=128,
help="""Number of samples per batch (Default: 128)""")
train_optional.add_argument(
'--checkpoint_interval', type=int, default=20,
help="""Checkpoints will be saved every `checkpoint_interval` epochs
(Default: 20)""")
train_optional.add_argument('--checkpoint_path', type=str, default=None,
help="""Path to previous checkpoint""")
train_optional.add_argument(
'--device', type=str, default=None,
help="""Device to use for training ('cpu' or 'cuda') (Default: If there
is a CUDA device available, it will be used for training)""")
train_optional.add_argument(
'--epochs', type=int, default=150,
help="""Number of training epochs (Default: 150)""")
train_optional.add_argument(
'--gstep', type=int, default=10,
help="""The number of discriminator updates after which the generator
is updated using the full loss (Default: 10)""")
train_optional.add_argument(
'--lr_D', type=float, default=0.0002,
help="""Learning rate for the discriminator's Adam optimizer
(Default: 0.0002)""")
train_optional.add_argument(
'--lr_G', type=float, default=0.0002,
help="""Learning rate for the generator's Adam optimizer
(Default: 0.0002)""")
train_optional.add_argument(
'--lr_decay', type=int, default=350000,
help="""Iteration to start decaying the learning rates for the
Generator and the Discriminator (Default: 350000) """)
train_optional.add_argument(
'--num_test', type=int, default=30,
help="""Number of generated images for evaluation (Default: 30)""")
train_optional.add_argument(
'--num_workers', type=int, default=0,
help="""number of subprocesses to use for data loading (Default: 0,
whichs means that the data will be loaded in the main process.)""")
train_optional.add_argument(
'--real_dataset', type=str, default=None,
help="""Path to the dataset directory of the real CIFAR10 data.
(Default: None, it will be downloaded and saved in the parent
directory of input `dataset` path)""")

# Create the parser for the "generate" command
parser_generate = subparsers.add_parser(
'generate', help="""Generate images using a
pretrained DiStyleGAN model""")
generate_required = parser_generate.add_argument_group(
'Required arguments for the generation procedure')
generate_optional = parser_generate.add_argument_group(
'Optional arguments about the generation procedure')

generate_required.add_argument(
'--checkpoint_path', type=str, required=True,
help="""Path to previous checkpoint (the directory must contain the
generator.pt and config.json files)""")
generate_required.add_argument(
'--nsamples', type=int, required=True,
help="""Number of samples to generate per label""")
generate_required.add_argument(
'--save', type=str, required=True,
help="""Path to save the generated images to""")

generate_optional.add_argument(
'--label', nargs="*", default=None, type=int, choices=range(0, 10),
help="""Class label(s) for the samples
(Default: None, random labels) --> e.g. --label 0 3 7""")
generate_optional.add_argument(
'--batch_size', type=int, default=32,
help="""Number of samples per batch (Default: 32)""")

# Parse arguments
args = vars(parser.parse_args())

model = DiStyleGAN(**args)
if args["command"] == "train":
model.train(**args)
else:
_ = model.generate(**args)

0 comments on commit a863380

Please sign in to comment.