Skip to content

Commit

Permalink
Switch to weighted metrics in the compile step of the supervised model.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 654919610
  • Loading branch information
raj-sinha authored and The spade_anomaly_detection Authors committed Jul 22, 2024
1 parent 1d4e648 commit bce9591
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions spade_anomaly_detection/supervised_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,16 @@ def save(self, save_location: str) -> None:
save_location: String denoting a Google Cloud Storage location, or local
disk path. Note that local assets will be deleted when the VM running
this container is shutdown at the end of the training job.
Raises:
ValueError: If the supervised model was not initialized.
"""
if self.supervised_model is not None:
self.supervised_model.save(save_location)
logging.info('Saved model assets to %s', save_location)
if self.supervised_model is None:
raise ValueError('Supervised model was not initialized.')
else:
logging.warning('No model to save.')
self.supervised_model.save(save_location) # pytype: disable=attribute-error

logging.info('Saved model assets to %s', save_location)


@dataclasses.dataclass
Expand Down Expand Up @@ -133,7 +137,7 @@ def __init__(
**dataclasses.asdict(self.supervised_parameters)
)
self.supervised_model.compile(
metrics=[
weighted_metrics=[
tf.keras.metrics.AUC(name='Supervised_Model_AUC'),
tf.keras.metrics.Precision(
thresholds=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
Expand Down

0 comments on commit bce9591

Please sign in to comment.