Spaces:
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -39,24 +39,34 @@ import spaces
|
|
39 |
import torch.cuda.amp
|
40 |
|
41 |
|
42 |
-
@spaces.GPU
|
43 |
def get_device():
|
|
|
|
|
|
|
|
|
44 |
print("Initializing device configuration...")
|
45 |
|
46 |
try:
|
47 |
-
|
48 |
-
# 使用 mixed precision
|
49 |
-
torch.set_float32_matmul_precision('medium')
|
50 |
-
|
51 |
if torch.cuda.is_available():
|
52 |
device = torch.device('cuda')
|
53 |
-
torch.cuda.
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
return device
|
56 |
-
|
|
|
57 |
print(f"GPU initialization error: {str(e)}")
|
58 |
|
59 |
-
|
|
|
|
|
60 |
return torch.device('cpu')
|
61 |
|
62 |
device = get_device()
|
@@ -152,50 +162,41 @@ class BaseModel(nn.Module):
|
|
152 |
|
153 |
def load_model(model_path, model_instance, device):
|
154 |
"""
|
155 |
-
|
156 |
-
|
157 |
-
Args:
|
158 |
-
model_path: 模型檔案的路徑
|
159 |
-
model_instance: BaseModel 的實例
|
160 |
-
device: 計算設備(CPU 或 GPU)
|
161 |
-
|
162 |
-
Returns:
|
163 |
-
載入權重後的模型實例
|
164 |
"""
|
165 |
try:
|
166 |
-
print(f"
|
167 |
|
168 |
-
#
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
|
177 |
-
|
178 |
-
|
|
|
|
|
|
|
179 |
|
180 |
-
#
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
# 設置為評估模式
|
185 |
model_instance.eval()
|
186 |
-
|
187 |
-
print("模型載入成功")
|
188 |
return model_instance
|
189 |
-
|
190 |
-
except Exception as e:
|
191 |
-
print(f"模型載入出錯: {str(e)}")
|
192 |
-
print("嘗試使用基本載入方式...")
|
193 |
|
194 |
-
|
195 |
-
|
196 |
-
model_instance.load_state_dict(checkpoint['base_model'], strict=False)
|
197 |
-
model_instance.eval()
|
198 |
-
return model_instance
|
199 |
|
200 |
# Initialize model
|
201 |
num_classes = len(dog_breeds)
|
|
|
39 |
import torch.cuda.amp
|
40 |
|
41 |
|
42 |
+
@spaces.GPU(duration=30) # Request smaller GPU time chunk
|
43 |
def get_device():
|
44 |
+
"""
|
45 |
+
Initialize device configuration with automatic CPU fallback.
|
46 |
+
Attempts GPU first, falls back to CPU if necessary.
|
47 |
+
"""
|
48 |
print("Initializing device configuration...")
|
49 |
|
50 |
try:
|
51 |
+
# Attempt GPU initialization with optimizations
|
|
|
|
|
|
|
52 |
if torch.cuda.is_available():
|
53 |
device = torch.device('cuda')
|
54 |
+
torch.cuda.init()
|
55 |
+
torch.set_float32_matmul_precision('medium')
|
56 |
+
|
57 |
+
# Add CUDA optimizations
|
58 |
+
torch.backends.cudnn.benchmark = True
|
59 |
+
torch.backends.cudnn.deterministic = False
|
60 |
+
|
61 |
+
print(f"Successfully initialized CUDA device: {torch.cuda.get_device_name(device)}")
|
62 |
return device
|
63 |
+
|
64 |
+
except (spaces.zero.gradio.HTMLError, RuntimeError) as e:
|
65 |
print(f"GPU initialization error: {str(e)}")
|
66 |
|
67 |
+
# CPU fallback with optimizations
|
68 |
+
print("Using CPU mode")
|
69 |
+
torch.set_num_threads(4) # Optimize CPU performance
|
70 |
return torch.device('cpu')
|
71 |
|
72 |
device = get_device()
|
|
|
162 |
|
163 |
def load_model(model_path, model_instance, device):
    """
    Load checkpoint weights into a model instance with device handling.

    Maintains the original function signature for compatibility.

    Args:
        model_path: Path to the checkpoint file; expected to contain a
            dict with a 'base_model' state-dict entry.
        model_instance: Model whose weights are populated in place.
        device: Target computation device (CPU or GPU).

    Returns:
        The model instance, in eval mode, on the resolved device.

    Raises:
        RuntimeError: Re-raised for loading failures other than CUDA OOM
            (OOM triggers an automatic CPU retry instead).
    """
    try:
        print(f"Loading model to device: {device}")

        # weights_only=True avoids unpickling arbitrary objects from the file
        checkpoint = torch.load(
            model_path,
            map_location=device,
            weights_only=True
        )

        # strict=False tolerates missing/unexpected keys in the checkpoint
        model_instance.load_state_dict(checkpoint['base_model'], strict=False)
        model_instance = model_instance.to(device)
        model_instance.eval()

        print("Model loading successful")
        return model_instance

    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            print("GPU memory exceeded, falling back to CPU")
            device = torch.device('cpu')
            model_instance = model_instance.cpu()

            # Retry loading on CPU.
            # Fix: also pass weights_only=True here — the retry previously
            # fell back to full unpickling, unlike the primary path.
            checkpoint = torch.load(
                model_path,
                map_location='cpu',
                weights_only=True
            )
            model_instance.load_state_dict(checkpoint['base_model'], strict=False)
            model_instance.eval()
            return model_instance

        print(f"Model loading error: {str(e)}")
        raise
|
|
|
|
|
|
|
200 |
|
201 |
# Initialize model
|
202 |
num_classes = len(dog_breeds)
|