DawnC commited on
Commit
7e0e5aa
1 Parent(s): 104c504

Update device_manager.py

Browse files
Files changed (1) hide show
  1. device_manager.py +43 -32
device_manager.py CHANGED
@@ -1,6 +1,8 @@
1
  import torch
2
  import os
3
  import logging
 
 
4
 
5
  logging.basicConfig(level=logging.INFO)
6
  logger = logging.getLogger(__name__)
@@ -17,40 +19,49 @@ class DeviceManager:
17
  def __init__(self):
18
  if self._initialized:
19
  return
20
-
21
  self._initialized = True
22
- self._current_device = None
23
- self.initialize_device()
24
-
25
- def initialize_device(self):
26
  try:
27
- if os.environ.get('SPACE_ID'):
28
- # 嘗試初始化 CUDA 設備
29
- if torch.cuda.is_available():
30
- self._current_device = torch.device('cuda')
31
- # 設置 CUDA 設備為可見
32
- os.environ['CUDA_VISIBLE_DEVICES'] = '0'
33
- logger.info("CUDA device initialized successfully")
34
- else:
35
- raise RuntimeError("CUDA not available")
36
- else:
37
- raise RuntimeError("Not in Spaces environment")
38
  except Exception as e:
39
- logger.warning(f"Using CPU due to: {e}")
40
- self._current_device = torch.device('cpu')
 
 
41
 
42
- def get_optimal_device(self):
43
- if self._current_device is None:
44
- self.initialize_device()
45
- return self._current_device
46
-
47
- def to_device(tensor_or_model, device=None):
48
- """Helper function to move tensors or models to the appropriate device"""
49
- if device is None:
50
- device = DeviceManager().get_optimal_device()
51
 
52
- try:
53
- return tensor_or_model.to(device)
54
- except Exception as e:
55
- logger.warning(f"Failed to move to {device}, using CPU: {e}")
56
- return tensor_or_model.to('cpu')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import os
3
  import logging
4
+ import spaces
5
+ from functools import wraps
6
 
7
  logging.basicConfig(level=logging.INFO)
8
  logger = logging.getLogger(__name__)
 
19
  def __init__(self):
20
  if self._initialized:
21
  return
 
22
  self._initialized = True
23
+ self.device = self._initialize_device()
24
+
25
+ def _initialize_device(self):
26
+ """初始化並確定使用的設備"""
27
  try:
28
+ # 檢查是否在 Spaces 環境且有 ZeroGPU
29
+ if os.environ.get('SPACE_ID') and torch.cuda.is_available():
30
+ logger.info("ZeroGPU environment detected")
31
+ return 'cuda'
 
 
 
 
 
 
 
32
  except Exception as e:
33
+ logger.warning(f"Unable to initialize ZeroGPU: {e}")
34
+
35
+ logger.info("Using CPU")
36
+ return 'cpu'
37
 
38
+ def get_device(self):
39
+ """獲取當前設備"""
40
+ return self.device
 
 
 
 
 
 
41
 
42
+ def to_device(self, model_or_tensor):
43
+ """將模型或張量移到正確的設備上"""
44
+ try:
45
+ if hasattr(model_or_tensor, 'to'):
46
+ return model_or_tensor.to(self.device)
47
+ except Exception as e:
48
+ logger.warning(f"Failed to move to {self.device}, using CPU: {e}")
49
+ self.device = 'cpu'
50
+ return model_or_tensor.to('cpu')
51
+ return model_or_tensor
52
+
53
+ def adaptive_gpu(duration=60):
54
+ """結合 spaces.GPU 和 CPU 降級的裝飾器"""
55
+ def decorator(func):
56
+ @wraps(func)
57
+ async def wrapper(*args, **kwargs):
58
+ device_mgr = DeviceManager()
59
+ if device_mgr.get_device() == 'cuda':
60
+ # 在 ZeroGPU 環境中使用 spaces.GPU
61
+ decorated = spaces.GPU(duration=duration)(func)
62
+ return await decorated(*args, **kwargs)
63
+ else:
64
+ # 在 CPU 環境中直接執行
65
+ return await func(*args, **kwargs)
66
+ return wrapper
67
+ return decorator