diff --git a/atomai/trainers/trainer.py b/atomai/trainers/trainer.py index b19c74a..2cf9de5 100644 --- a/atomai/trainers/trainer.py +++ b/atomai/trainers/trainer.py @@ -67,8 +67,8 @@ class BaseTrainer: """ def __init__(self): set_train_rng(1) - if torch.backends.mps.is_available(): - self.device = torch.device("mps") # backend for Apple silicon GPUs + if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + self.device = torch.device('mps') # backend for Apple silicon GPUs elif torch.cuda.is_available(): self.device = 'cuda' else: @@ -367,8 +367,8 @@ def print_statistics(self, e: int, **kwargs) -> None: accuracy_metrics = "Accuracy" if torch.cuda.is_available(): gpu_usage = gpu_usage_map(torch.cuda.current_device()) - elif torch.backends.mps.is_available(): - gpu_usage = gpu_usage_map(torch.mps.current_device()) + elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + gpu_usage = ['N/A ', ' N/A'] else: gpu_usage = ['N/A ', ' N/A'] if self.compute_accuracy: diff --git a/atomai/utils/nn.py b/atomai/utils/nn.py index 6532a57..619d6b5 100644 --- a/atomai/utils/nn.py +++ b/atomai/utils/nn.py @@ -48,7 +48,7 @@ def load_weights(model: Type[torch.nn.Module], torch.manual_seed(0) if torch.cuda.device_count() > 0: checkpoint = torch.load(weights_path) - elif torch.backends.mps.is_available(): + elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): checkpoint = torch.load(weights_path, map_location='mps') else: checkpoint = torch.load(weights_path, map_location='cpu')