
MAE-MegEngine

The MegEngine Implementation of MAE(Masked Auto Encoder).

Usage

Make sure you run on a GPU device: MegEngine's gather API produces different outputs on GPU and CPU, which affects the results of this implementation.

Install Dependencies

pip install -r requirements.txt

If you don't want to compare the output error between the MegEngine implementation and the PyTorch one, you can skip requirements.txt and install MegEngine directly from the command line:

python3 -m pip install --upgrade pip 
python3 -m pip install megengine -f https://megengine.org.cn/whl/mge.html

Note:

The PyTorch implementation is based on timm==0.3.2, which needs a fix to work with PyTorch 1.8.1+.
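
A commonly applied patch (an assumption here, not taken from this repo: the failure is usually the removed torch._six.container_abcs import) is to replace the import at the top of timm/models/layers/helpers.py:

# timm/models/layers/helpers.py
import torch
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
    from torch._six import container_abcs  # old location, removed in torch >= 1.8
else:
    import collections.abc as container_abcs  # standard-library replacement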

Convert Weights

Convert trained weights from PyTorch to MegEngine; the converted weights will be saved in ./pretrained/. You need to specify the model architecture to convert and the path to a checkpoint offered by the official repo.

Pre-trained checkpoints:

  • ViT-Base: download
  • ViT-Large: download
  • ViT-Huge: download

Visualization checkpoints:

  • ViT-Base: download
  • ViT-Large: download
  • ViT-Large-GanLoss: download
  • ViT-Huge: download

python convert_weights.py -m mae_vit_base_patch16 -c /local/path/to/ckpt
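
Under the hood the conversion amounts to loading the PyTorch checkpoint, turning every tensor into a NumPy array (renaming keys where the two implementations differ), and saving the result with MegEngine. A minimal sketch, assuming the official checkpoint keeps its weights under a 'model' key and with an output file name that is only illustrative:

import torch
import megengine as mge

ckpt = torch.load('/local/path/to/ckpt', map_location='cpu')
torch_state = ckpt.get('model', ckpt)
# convert every tensor to a NumPy array so MegEngine can load it
mge_state = {k: v.numpy() for k, v in torch_state.items()}
mge.save(mge_state, './pretrained/mae_vit_base_patch16.pkl')

The real convert_weights.py also handles any key-name mismatches between the two implementations; the sketch only shows the overall flow.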

Compare

Run python compare.py.

By default, the compare script converts the PyTorch state_dict to the format that MegEngine needs.

If you want to compare the error using converted checkpoints, you need to load them manually.
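
In that case the comparison itself is just running both models on the same input and looking at the difference of the outputs. A minimal sketch, assuming torch_model and mge_model are hypothetical, already-built models carrying the same converted weights, and keeping in mind that MAE's random masking must use the same mask on both sides for the numbers to be meaningful:

import numpy as np
import torch
import megengine as mge

x = np.random.randn(1, 3, 224, 224).astype('float32')
# both forwards are assumed to return (loss, pred, mask) as in the official MAE
with torch.no_grad():
    torch_loss, torch_pred, _ = torch_model(torch.from_numpy(x))
mge_loss, mge_pred, _ = mge_model(mge.tensor(x))
print('max abs error of pred:', np.abs(torch_pred.numpy() - mge_pred.numpy()).max())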

Visualize

Just read and run visualize.py.

Load From Hub

Import from megengine.hub:

Way 1:

from functools import partial
import megengine.module as M
from megengine import hub

modelhub = hub.import_module(
    repo_info='asthestarsfalll/MAE-MegEngine:main', git_host='github.com')

# build an MAE model and customize it on your own
model = modelhub.MAE(
    patch_size=16, embed_dim=768, depth=12, num_heads=12,
    decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
    mlp_ratio=4, norm_layer=partial(M.LayerNorm, eps=1e-6))

# load pretrained model
pretrained_model = modelhub.mae_vit_base_patch16(pretrained=True)

Way 2:

from megengine import hub

# load pretrained model 
model_name = 'mae_vit_base_patch16'
pretrained_model = hub.load(
    repo_info='asthestarsfalll/MAE-MegEngine:main', entry=model_name, git_host='github.com', pretrained=True)

Currently, only mae_vit_base_patch16 has pretrained weights available.

But you can still load the other models without pretrained weights like this:

model = modelhub.mae_vit_large_patch16()
# or
model_name = 'mae_vit_large_patch16'
model = hub.load(
    repo_info='asthestarsfalll/MAE-MegEngine:main', entry=model_name, git_host='github.com')
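
Once a model is loaded, inference follows the official MAE: the forward pass takes a batch of images and (assuming this port keeps the official signature) returns the reconstruction loss, the predicted patches, and the binary mask. A minimal sketch with a random 224x224 input:

import numpy as np
import megengine as mge

imgs = mge.tensor(np.random.randn(1, 3, 224, 224).astype('float32'))
# signature assumed to mirror the official PyTorch MAE: forward(imgs, mask_ratio=0.75)
loss, pred, mask = pretrained_model(imgs)
print(loss.item(), pred.shape, mask.shape)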

TODO

  • Add interfaces for visualization.
  • Possibly some downstream tasks.
  • A brief introduction to MAE.

Reference

The official implementation of MAE
