Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -538,46 +538,34 @@ from urllib.parse import quote
|
|
538 |
from ultralytics import YOLO
|
539 |
import asyncio
|
540 |
import traceback
|
|
|
|
|
541 |
|
542 |
-
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
543 |
-
os.environ['HF_ZERO_GPU'] = '1' # 明確告訴系統我們要使用 ZeroGPU
|
544 |
-
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
|
545 |
|
|
|
546 |
def get_device():
|
547 |
print("Initializing device configuration...")
|
548 |
|
549 |
-
|
550 |
-
|
551 |
-
|
552 |
-
|
553 |
-
|
554 |
-
|
555 |
-
|
556 |
-
time.sleep(2)
|
557 |
-
|
558 |
device = torch.device('cuda')
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
_ = test_tensor * test_tensor
|
563 |
-
|
564 |
-
print("ZeroGPU initialization successful")
|
565 |
-
print(f"Using device: {device}")
|
566 |
-
if torch.cuda.is_available():
|
567 |
-
print(f"GPU: {torch.cuda.get_device_name(0)}")
|
568 |
-
|
569 |
return device
|
570 |
-
|
571 |
-
|
572 |
-
|
573 |
-
|
574 |
-
|
575 |
-
else:
|
576 |
-
if not torch.cuda.is_available():
|
577 |
-
print("CUDA not available, using CPU")
|
578 |
-
elif 'HF_ZERO_GPU' not in os.environ:
|
579 |
-
print("HF_ZERO_GPU not set, using CPU")
|
580 |
-
return torch.device('cpu')
|
581 |
|
582 |
device = get_device()
|
583 |
|
@@ -670,18 +658,60 @@ class BaseModel(nn.Module):
|
|
670 |
logits = self.classifier(attended_features)
|
671 |
return logits, attended_features
|
672 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
673 |
# Initialize model
|
674 |
num_classes = len(dog_breeds)
|
675 |
|
676 |
-
# Initialize base model
|
677 |
model = BaseModel(num_classes=num_classes, device=device)
|
678 |
|
679 |
-
#
|
680 |
-
|
681 |
-
checkpoint = torch.load(model_path, map_location=device)
|
682 |
-
|
683 |
-
# Load model state
|
684 |
-
model.load_state_dict(checkpoint['base_model'], strict=False)
|
685 |
model.eval()
|
686 |
|
687 |
# Image preprocessing function
|
|
|
538 |
from ultralytics import YOLO
|
539 |
import asyncio
|
540 |
import traceback
|
541 |
+
import spaces
|
542 |
+
import torch.cuda.amp
|
543 |
|
544 |
+
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
545 |
+
# os.environ['HF_ZERO_GPU'] = '1' # 明確告訴系統我們要使用 ZeroGPU
|
546 |
+
# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
|
547 |
|
548 |
+
@spaces.GPU
|
549 |
def get_device():
|
550 |
print("Initializing device configuration...")
|
551 |
|
552 |
+
try:
|
553 |
+
# 強制進行 CUDA 初始化
|
554 |
+
torch.cuda.init()
|
555 |
+
# 使用 mixed precision
|
556 |
+
torch.set_float32_matmul_precision('medium')
|
557 |
+
|
558 |
+
if torch.cuda.is_available():
|
|
|
|
|
559 |
device = torch.device('cuda')
|
560 |
+
# 設置默認的 CUDA 設備
|
561 |
+
torch.cuda.set_device(device)
|
562 |
+
print(f"Successfully initialized CUDA device")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
563 |
return device
|
564 |
+
except Exception as e:
|
565 |
+
print(f"GPU initialization error: {str(e)}")
|
566 |
+
|
567 |
+
print("Using CPU fallback")
|
568 |
+
return torch.device('cpu')
|
|
|
|
|
|
|
|
|
|
|
|
|
569 |
|
570 |
device = get_device()
|
571 |
|
|
|
658 |
logits = self.classifier(attended_features)
|
659 |
return logits, attended_features
|
660 |
|
661 |
+
def load_model(model_path, model_instance, device):
|
662 |
+
"""
|
663 |
+
優化的模型載入函數,支援 ZeroGPU 和混合精度計算
|
664 |
+
|
665 |
+
Args:
|
666 |
+
model_path: 模型檔案的路徑
|
667 |
+
model_instance: BaseModel 的實例
|
668 |
+
device: 計算設備(CPU 或 GPU)
|
669 |
+
|
670 |
+
Returns:
|
671 |
+
載入權重後的模型實例
|
672 |
+
"""
|
673 |
+
try:
|
674 |
+
print(f"正在將模型載入到設備: {device}")
|
675 |
+
|
676 |
+
# 使用混合精度計算來優化記憶體使用
|
677 |
+
with torch.cuda.amp.autocast(enabled=device.type == 'cuda'):
|
678 |
+
# 載入檢查點,使用 weights_only=True 來避免警告
|
679 |
+
checkpoint = torch.load(
|
680 |
+
model_path,
|
681 |
+
map_location=device,
|
682 |
+
weights_only=True
|
683 |
+
)
|
684 |
+
|
685 |
+
# 載入模型權重
|
686 |
+
model_instance.load_state_dict(checkpoint['base_model'], strict=False)
|
687 |
+
|
688 |
+
# 確保模型在正確的設備上
|
689 |
+
if device.type == 'cuda':
|
690 |
+
model_instance = model_instance.to(device)
|
691 |
+
|
692 |
+
# 設置為評估模式
|
693 |
+
model_instance.eval()
|
694 |
+
|
695 |
+
print("模型載入成功")
|
696 |
+
return model_instance
|
697 |
+
|
698 |
+
except Exception as e:
|
699 |
+
print(f"模型載入出錯: {str(e)}")
|
700 |
+
print("嘗試使用基本載入方式...")
|
701 |
+
|
702 |
+
# 如果優化載入失敗,退回到基本載入方式
|
703 |
+
checkpoint = torch.load(model_path, map_location=device)
|
704 |
+
model_instance.load_state_dict(checkpoint['base_model'], strict=False)
|
705 |
+
model_instance.eval()
|
706 |
+
return model_instance
|
707 |
+
|
708 |
# Initialize model
|
709 |
num_classes = len(dog_breeds)
|
710 |
|
|
|
711 |
model = BaseModel(num_classes=num_classes, device=device)
|
712 |
|
713 |
+
# 使用優化後的載入函數
|
714 |
+
model = load_model("124_best_model_dog.pth", model, device)
|
|
|
|
|
|
|
|
|
715 |
model.eval()
|
716 |
|
717 |
# Image preprocessing function
|