DawnC commited on
Commit
b9b81ad
1 Parent(s): a0f2ca9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -40
app.py CHANGED
@@ -538,46 +538,34 @@ from urllib.parse import quote
538
  from ultralytics import YOLO
539
  import asyncio
540
  import traceback
 
 
541
 
542
- os.environ['CUDA_VISIBLE_DEVICES'] = '0'
543
- os.environ['HF_ZERO_GPU'] = '1' # 明確告訴系統我們要使用 ZeroGPU
544
- os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
545
 
 
546
  def get_device():
547
  print("Initializing device configuration...")
548
 
549
- # 特別針對 ZeroGPU 的檢測邏輯
550
- if 'HF_ZERO_GPU' in os.environ and torch.cuda.is_available():
551
- try:
552
- # 強制進行 CUDA 初始化
553
- torch.cuda.init()
554
- # 等待一小段時間讓系統完成初始化
555
- import time
556
- time.sleep(2)
557
-
558
  device = torch.device('cuda')
559
-
560
- # 執行一個小的測試來確認 GPU 功能
561
- test_tensor = torch.rand(1).to(device)
562
- _ = test_tensor * test_tensor
563
-
564
- print("ZeroGPU initialization successful")
565
- print(f"Using device: {device}")
566
- if torch.cuda.is_available():
567
- print(f"GPU: {torch.cuda.get_device_name(0)}")
568
-
569
  return device
570
-
571
- except Exception as e:
572
- print(f"ZeroGPU initialization failed: {str(e)}")
573
- print("Falling back to CPU")
574
- return torch.device('cpu')
575
- else:
576
- if not torch.cuda.is_available():
577
- print("CUDA not available, using CPU")
578
- elif 'HF_ZERO_GPU' not in os.environ:
579
- print("HF_ZERO_GPU not set, using CPU")
580
- return torch.device('cpu')
581
 
582
  device = get_device()
583
 
@@ -670,18 +658,60 @@ class BaseModel(nn.Module):
670
  logits = self.classifier(attended_features)
671
  return logits, attended_features
672
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
673
  # Initialize model
674
  num_classes = len(dog_breeds)
675
 
676
- # Initialize base model
677
  model = BaseModel(num_classes=num_classes, device=device)
678
 
679
- # Load model path
680
- model_path = "124_best_model_dog.pth"
681
- checkpoint = torch.load(model_path, map_location=device)
682
-
683
- # Load model state
684
- model.load_state_dict(checkpoint['base_model'], strict=False)
685
  model.eval()
686
 
687
  # Image preprocessing function
 
538
  from ultralytics import YOLO
539
  import asyncio
540
  import traceback
541
+ import spaces
542
+ import torch.cuda.amp
543
 
544
+ # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
545
+ # os.environ['HF_ZERO_GPU'] = '1' # 明確告訴系統我們要使用 ZeroGPU
546
+ # os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
547
 
548
+ @spaces.GPU
549
  def get_device():
550
  print("Initializing device configuration...")
551
 
552
+ try:
553
+ # 強制進行 CUDA 初始化
554
+ torch.cuda.init()
555
+ # 使用 mixed precision
556
+ torch.set_float32_matmul_precision('medium')
557
+
558
+ if torch.cuda.is_available():
 
 
559
  device = torch.device('cuda')
560
+ # 設置默認的 CUDA 設備
561
+ torch.cuda.set_device(device)
562
+ print(f"Successfully initialized CUDA device")
 
 
 
 
 
 
 
563
  return device
564
+ except Exception as e:
565
+ print(f"GPU initialization error: {str(e)}")
566
+
567
+ print("Using CPU fallback")
568
+ return torch.device('cpu')
 
 
 
 
 
 
569
 
570
  device = get_device()
571
 
 
658
  logits = self.classifier(attended_features)
659
  return logits, attended_features
660
 
661
+ def load_model(model_path, model_instance, device):
662
+ """
663
+ 優化的模型載入函數,支援 ZeroGPU 和混合精度計算
664
+
665
+ Args:
666
+ model_path: 模型檔案的路徑
667
+ model_instance: BaseModel 的實例
668
+ device: 計算設備(CPU 或 GPU)
669
+
670
+ Returns:
671
+ 載入權重後的模型實例
672
+ """
673
+ try:
674
+ print(f"正在將模型載入到設備: {device}")
675
+
676
+ # 使用混合精度計算來優化記憶體使用
677
+ with torch.cuda.amp.autocast(enabled=device.type == 'cuda'):
678
+ # 載入檢查點,使用 weights_only=True 來避免警告
679
+ checkpoint = torch.load(
680
+ model_path,
681
+ map_location=device,
682
+ weights_only=True
683
+ )
684
+
685
+ # 載入模型權重
686
+ model_instance.load_state_dict(checkpoint['base_model'], strict=False)
687
+
688
+ # 確保模型在正確的設備上
689
+ if device.type == 'cuda':
690
+ model_instance = model_instance.to(device)
691
+
692
+ # 設置為評估模式
693
+ model_instance.eval()
694
+
695
+ print("模型載入成功")
696
+ return model_instance
697
+
698
+ except Exception as e:
699
+ print(f"模型載入出錯: {str(e)}")
700
+ print("嘗試使用基本載入方式...")
701
+
702
+ # 如果優化載入失敗,退回到基本載入方式
703
+ checkpoint = torch.load(model_path, map_location=device)
704
+ model_instance.load_state_dict(checkpoint['base_model'], strict=False)
705
+ model_instance.eval()
706
+ return model_instance
707
+
708
  # Initialize model
709
  num_classes = len(dog_breeds)
710
 
 
711
  model = BaseModel(num_classes=num_classes, device=device)
712
 
713
+ # 使用優化後的載入函數
714
+ model = load_model("124_best_model_dog.pth", model, device)
 
 
 
 
715
  model.eval()
716
 
717
  # Image preprocessing function