Skip to content

Commit

Permalink
Merge pull request #112 from JetBrains-Research/reset-metrics
Browse files Browse the repository at this point in the history
Reset metrics after each epoch
  • Loading branch information
SpirinEgor committed Sep 25, 2021
2 parents be05daf + f1a11b8 commit 1dd7abe
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 1 deletion.
1 change: 1 addition & 0 deletions code2seq/model/code2class.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def _shared_epoch_end(self, outputs: EPOCH_OUTPUT, step: str):
mean_loss = torch.stack([out[f"{step}/loss"] for out in outputs]).mean()
accuracy = self.__metrics[f"{step}_acc"].compute()
log = {f"{step}/loss": mean_loss, f"{step}/accuracy": accuracy}
self.__metrics[f"{step}_acc"].reset()
self.log_dict(log, on_step=False, on_epoch=True)

def training_epoch_end(self, outputs: EPOCH_OUTPUT):
Expand Down
1 change: 1 addition & 0 deletions code2seq/model/code2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def _shared_epoch_end(self, step_outputs: EPOCH_OUTPUT, step: str):
f"{step}/precision": metric.precision,
f"{step}/recall": metric.recall,
}
self.__metrics[f"{step}_f1"].reset()
self.log_dict(log, on_step=False, on_epoch=True)

def training_epoch_end(self, step_outputs: EPOCH_OUTPUT):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from setuptools import setup, find_packages

VERSION = "1.0.2"
VERSION = "1.0.3"

with open("README.md") as readme_file:
readme = readme_file.read()
Expand Down

0 comments on commit 1dd7abe

Please sign in to comment.