
ViT_PyTorch

This is a simple PyTorch implementation of Vision Transformer (ViT) described in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale".

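The core idea of the paper is that an image is cut into fixed-size 16x16 patches, each of which becomes one "word" (token) in the Transformer's input sequence. A minimal sketch of the resulting sequence length (illustrative arithmetic, not this repo's code):

```python
# How ViT turns an image into a token sequence: a 384x384 input with
# 16x16 patches yields a 24x24 grid, i.e. 576 patch tokens, plus one
# prepended [CLS] classification token.

def vit_sequence_length(image_size: int, patch_size: int) -> int:
    assert image_size % patch_size == 0, "image size must be divisible by patch size"
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    return num_patches + 1  # +1 for the [CLS] token

# For the B_16_384 configuration (16x16 patches, 384x384 input):
print(vit_sequence_length(384, 16))  # 577 = 24*24 patches + [CLS]
```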

Usage

Download the pre-trained JAX weights and convert them to PyTorch format:

python convert.py jax/weights/path/ converted/weights/path --model_name B_16_384

Train ViT

You can train ViT on your own dataset with the following command:

python train.py /train/data/folder/ --valid_dir /validation/data/folder/ --pretrained_weights /pre-trained/weights/path 

Check out the Colab notebook for preparing data, fine-tuning the model, and running inference.
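With the default `freeze_extractor=True` setting (see the arguments below), only the classification head is trained. A minimal sketch of what that amounts to in plain PyTorch, using a hypothetical stand-in model rather than this repo's actual classes:

```python
import torch.nn as nn

# Hypothetical stand-in for the real model: "extractor" plays the role of
# the ViT encoder, "head" the classification layer.
model = nn.Sequential()
model.add_module("extractor", nn.Linear(768, 768))
model.add_module("head", nn.Linear(768, 10))

# freeze_extractor=True: no gradients for the backbone parameters.
for param in model.extractor.parameters():
    param.requires_grad = False

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['head.weight', 'head.bias']
```

Freezing the extractor makes fine-tuning much cheaper, since the optimizer only updates the head's parameters.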

Details about training arguments:

| Argument | Description | Type | Default |
| --- | --- | --- | --- |
| train_dir | Directory of training data. | str | required argument |
| valid_dir | Directory of validation data. | str | None |
| valid_rate | Proportion of validation samples split from the training data. | float | None |
| output_dir | Directory of output results, where trained weights and training history are stored. | str | None |
| model_config | Model architecture configuration (config path or arch name, e.g. "B_16_384"). | str | B_16_384 |
| pretrained_weights | Filename of pre-trained weights. Trains from scratch if 'None'. | str | None |
| freeze_extractor | If True, freeze the feature extractor weights and fine-tune only the classification head. | bool | True |
| batch_size | Batch size. | int | 64 |
| init_lr | Initial learning rate. | float | 1e-3 |
| weight_decay | Weight decay (L2 penalty). | float | 1e-5 |
| beta1 | Adam 'betas' param 1. | float | 0.9 |
| beta2 | Adam 'betas' param 2. | float | 0.999 |
| max_epoch | Maximum number of training epochs. | int | 100 |
| patient | Patience (epochs without improvement) for early stopping. | int | None |
| monitor | Metric to be monitored ('loss' or 'acc'). | str | loss |
| min_delta | Minimum change in the monitored metric to qualify as an improvement. | float | 0.0 |
| save_best | Whether to save weights from the epoch with the best monitored metric. | bool | True |
| warmup | Number of warmup epochs. | int | 0 |
| scheduler | Learning rate scheduler ('cosine', 'step' or 'exp'). | str | None |
| t_max | Maximum number of iterations (cosine scheduler). | int | 10 |
| eta_min | Minimum learning rate (cosine scheduler). | float | 0.0 |
| step_size | Period of learning rate decay (step scheduler). | int | 10 |
| gamma | Multiplicative factor of learning rate decay (step/exp scheduler). | float | 0.1 |
| image_size | Input image size. | int | 384 |
| crop_margin | Margin for random cropping. | int | 32 |
| horizontal_flip | Horizontal flip probability. | float | 0.5 |
| rotation | Degree range for random rotation. | float | 10. |
| device | Computation device ('cuda' or 'cpu'). | str | cuda |
| random_seed | Random seed used throughout this repo. | int | 427 |
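To see how `init_lr`, `eta_min`, and `t_max` interact under the 'cosine' scheduler, here is the closed form of PyTorch's `CosineAnnealingLR` curve in plain Python (a sketch assuming no warmup and no restarts, using the table's defaults):

```python
import math

# Cosine annealing: the learning rate decays from init_lr at epoch 0
# down to eta_min at epoch t_max along a half-cosine curve.
def cosine_lr(epoch: int, init_lr: float = 1e-3, eta_min: float = 0.0,
              t_max: int = 10) -> float:
    return eta_min + 0.5 * (init_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

print(cosine_lr(0))   # 0.001  (starts at init_lr)
print(cosine_lr(10))  # ~0.0   (reaches eta_min at t_max)
```

The 'step' scheduler instead multiplies the learning rate by `gamma` every `step_size` epochs, and 'exp' multiplies it by `gamma` every epoch.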
