DawnC committed
Commit eadb15b
1 Parent(s): 8a2180c

Update app.py

Files changed (1)
  1. app.py +46 -45
app.py CHANGED
@@ -39,24 +39,34 @@ import spaces
 import torch.cuda.amp
 
 
-@spaces.GPU
+@spaces.GPU(duration=30)  # Request smaller GPU time chunk
 def get_device():
+    """
+    Initialize device configuration with automatic CPU fallback.
+    Attempts GPU first, falls back to CPU if necessary.
+    """
     print("Initializing device configuration...")
 
     try:
-        torch.cuda.init()
-        # Use mixed precision
-        torch.set_float32_matmul_precision('medium')
-
+        # Attempt GPU initialization with optimizations
         if torch.cuda.is_available():
             device = torch.device('cuda')
-            torch.cuda.set_device(device)
-            print(f"Successfully initialized CUDA device")
+            torch.cuda.init()
+            torch.set_float32_matmul_precision('medium')
+
+            # Add CUDA optimizations
+            torch.backends.cudnn.benchmark = True
+            torch.backends.cudnn.deterministic = False
+
+            print(f"Successfully initialized CUDA device: {torch.cuda.get_device_name(device)}")
             return device
-    except Exception as e:
+
+    except (spaces.zero.gradio.HTMLError, RuntimeError) as e:
         print(f"GPU initialization error: {str(e)}")
 
-        print("Using CPU fallback")
+        # CPU fallback with optimizations
+        print("Using CPU mode")
+        torch.set_num_threads(4)  # Optimize CPU performance
         return torch.device('cpu')
 
 device = get_device()
@@ -152,50 +162,41 @@ class BaseModel(nn.Module):
 
 def load_model(model_path, model_instance, device):
     """
-    Optimized model-loading function with support for ZeroGPU and mixed-precision computation
-
-    Args:
-        model_path: path to the model checkpoint file
-        model_instance: an instance of BaseModel
-        device: compute device (CPU or GPU)
-
-    Returns:
-        the model instance with loaded weights
+    Enhanced model loading function with device handling.
+    Maintains original function signature for compatibility.
     """
     try:
-        print(f"Loading the model onto device: {device}")
+        print(f"Loading model to device: {device}")
 
-        # Use mixed-precision computation to optimize memory usage
-        with torch.cuda.amp.autocast(enabled=device.type == 'cuda'):
-            # Load the checkpoint, using weights_only=True to avoid warnings
-            checkpoint = torch.load(
-                model_path,
-                map_location=device,
-                weights_only=True
-            )
+        # Load checkpoint with optimizations
+        checkpoint = torch.load(
+            model_path,
+            map_location=device,
+            weights_only=True
+        )
+
+        # Load model weights
+        model_instance.load_state_dict(checkpoint['base_model'], strict=False)
+        model_instance = model_instance.to(device)
+        model_instance.eval()
+
+        print("Model loading successful")
+        return model_instance
 
-            # Load the model weights
-            model_instance.load_state_dict(checkpoint['base_model'], strict=False)
+    except RuntimeError as e:
+        if "CUDA out of memory" in str(e):
+            print("GPU memory exceeded, falling back to CPU")
+            device = torch.device('cpu')
+            model_instance = model_instance.cpu()
 
-            # Make sure the model is on the correct device
-            if device.type == 'cuda':
-                model_instance = model_instance.to(device)
-
-            # Set to evaluation mode
+            # Retry loading on CPU
+            checkpoint = torch.load(model_path, map_location='cpu')
+            model_instance.load_state_dict(checkpoint['base_model'], strict=False)
             model_instance.eval()
-
-            print("Model loaded successfully")
             return model_instance
-
-    except Exception as e:
-        print(f"Model loading error: {str(e)}")
-        print("Trying the basic loading method...")
 
-        # If the optimized load fails, fall back to the basic loading method
-        checkpoint = torch.load(model_path, map_location=device)
-        model_instance.load_state_dict(checkpoint['base_model'], strict=False)
-        model_instance.eval()
-        return model_instance
+        print(f"Model loading error: {str(e)}")
+        raise
 
 # Initialize model
 num_classes = len(dog_breeds)
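
For reference, a minimal sketch of how the updated helpers would typically be wired together in app.py. The checkpoint filename and the BaseModel constructor arguments are illustrative assumptions, not taken from this commit:

import torch

# Hypothetical usage of the updated helpers; "best_model.pth" and the
# BaseModel(num_classes=...) signature are assumptions, not part of this commit.
device = get_device()                           # CUDA if the ZeroGPU slice is granted, else CPU
model = BaseModel(num_classes=len(dog_breeds))  # assumed constructor signature
model = load_model("best_model.pth", model, device)

# load_model may have moved the model to CPU after a CUDA out-of-memory error,
# so downstream code should read the device back from the model itself.
device = next(model.parameters()).device

Note that on a ZeroGPU Space, @spaces.GPU(duration=30) only attaches a GPU while the decorated function runs, so inference entry points would typically carry their own @spaces.GPU decorator as well.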