import os

import torch

# NOTE(review): this local definition presumably replaces the helper that used
# to be imported from utils.utils; the original import is kept for reference.
# from utils.utils import get_default_device


def get_default_device() -> torch.device:
    """Select the torch device to run on.

    Preference order:

    1. ``cuda`` when ``torch.cuda.is_available()`` is true.
    2. ``mps`` when the MPS backend is available *and* the environment
       variable ``PYTORCH_ENABLE_MPS_FALLBACK`` is exactly ``"1"``.
    3. ``cpu`` otherwise.

    Returns:
        The selected ``torch.device``.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')

    # Older PyTorch builds (before the MPS backend was introduced in 1.12) do
    # not expose ``torch.backends.mps``; look it up defensively so they fall
    # through to CPU instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        # Not all operations are implemented in MPS yet, so MPS is only used
        # when the user has opted in to the fallback via the environment.
        fallback_enabled = os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK", "0") == "1"
        if fallback_enabled:
            return torch.device('mps')

    return torch.device('cpu')


# Resolved once at import time and reported on stdout.
device = get_default_device()
print(f"DiffDock Device: {device}")