Skip to content

Commit

Permalink
Update to new version of commode-utils
Browse files Browse the repository at this point in the history
  • Loading branch information
SpirinEgor committed Sep 20, 2021
1 parent 84a8d01 commit be05daf
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 11 deletions.
8 changes: 3 additions & 5 deletions code2seq/utils/train.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from os.path import basename

import torch
from commode_utils.callback import UploadCheckpointCallback, PrintEpochResultCallback
from commode_utils.callback import PrintEpochResultCallback, ModelCheckpointWithUpload
from omegaconf import DictConfig
from pytorch_lightning import seed_everything, Trainer, LightningModule, LightningDataModule
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger


Expand All @@ -18,15 +18,14 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict
wandb_logger = WandbLogger(project=f"{model_name} -- {dataset_name}", log_model=False, offline=config.log_offline)

# define model checkpoint callback
checkpoint_callback = ModelCheckpoint(
checkpoint_callback = ModelCheckpointWithUpload(
dirpath=wandb_logger.experiment.dir,
filename="{epoch:02d}-val_loss={val/loss:.4f}",
monitor="val/loss",
every_n_epochs=params.save_every_epoch,
save_top_k=-1,
auto_insert_metric_name=False,
)
upload_checkpoint_callback = UploadCheckpointCallback(wandb_logger.experiment.dir)
# define early stopping callback
early_stopping_callback = EarlyStopping(patience=params.patience, monitor="val/loss", verbose=True, mode="min")
# define callback for printing intermediate result
Expand All @@ -48,7 +47,6 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict
lr_logger,
early_stopping_callback,
checkpoint_callback,
upload_checkpoint_callback,
print_epoch_result_callback,
],
resume_from_checkpoint=config.get("checkpoint", None),
Expand Down
10 changes: 5 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ black==21.7b0
mypy==0.910

torch==1.9.0
pytorch-lightning==1.4.2
torchmetrics==0.5.0
pytorch-lightning==1.4.7
torchmetrics==0.5.1

tqdm==4.62.1
wandb==0.12.0
tqdm==4.62.2
wandb==0.12.2
omegaconf==2.1.1
commode-utils==0.3.8
commode-utils==0.3.9
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from setuptools import setup, find_packages

VERSION = "1.0.1"
VERSION = "1.0.2"

with open("README.md") as readme_file:
readme = readme_file.read()
Expand Down

0 comments on commit be05daf

Please sign in to comment.