DawnC commited on
Commit
f5a076e
1 Parent(s): de441a9

Update smart_breed_matcher.py

Browse files
Files changed (1) hide show
  1. smart_breed_matcher.py +28 -2
smart_breed_matcher.py CHANGED
@@ -18,19 +18,40 @@ def gpu_init_wrapper(func):
18
  return func(*args, **kwargs)
19
  return wrapper
20
 
21
-
22
  class SmartBreedMatcher:
 
 
 
 
 
 
 
 
 
 
 
 
23
  def __init__(self, dog_data: List[Tuple]):
24
  self.dog_data = dog_data
25
- self.model = SentenceTransformer('all-mpnet-base-v2')
26
  self._embedding_cache = {}
27
  self._clear_cache()
28
 
 
 
 
 
 
29
  def _clear_cache(self):
30
  self._embedding_cache = {}
31
 
32
 
 
33
  def _get_cached_embedding(self, text: str) -> torch.Tensor:
 
 
 
 
34
  if text not in self._embedding_cache:
35
  self._embedding_cache[text] = self.model.encode(text)
36
  return self._embedding_cache[text]
@@ -75,6 +96,8 @@ class SmartBreedMatcher:
75
  List[Tuple[str, float]]: 相似品種列表,包含品種名稱和相似度分數
76
  """
77
  try:
 
 
78
  target_breed = next((breed for breed in self.dog_data if breed[1] == breed_name), None)
79
  if not target_breed:
80
  return []
@@ -868,8 +891,11 @@ class SmartBreedMatcher:
868
  }
869
 
870
  @gpu_init_wrapper
 
871
  def match_user_preference(self, description: str, top_n: int = 10) -> List[Dict]:
872
  try:
 
 
873
  # 獲取場景權重
874
  weights = self._detect_scenario(description)
875
  matches = []
 
18
  return func(*args, **kwargs)
19
  return wrapper
20
 
 
21
  class SmartBreedMatcher:
22
+ def _safe_prediction(self, func):
23
+ @wraps(func)
24
+ def wrapper(*args, **kwargs):
25
+ try:
26
+ return func(*args, **kwargs)
27
+ except RuntimeError as e:
28
+ if "CUDA" in str(e):
29
+ print("GPU 操作失敗,嘗試使用 CPU")
30
+ return func(*args, **kwargs)
31
+ raise
32
+ return wrapper
33
+
34
  def __init__(self, dog_data: List[Tuple]):
35
  self.dog_data = dog_data
36
+ self.model = None
37
  self._embedding_cache = {}
38
  self._clear_cache()
39
 
40
+ def _initialize_model(self):
41
+ """延遲初始化模型,只在需要時才創建"""
42
+ if self.model is None:
43
+ self.model = SentenceTransformer('all-mpnet-base-v2')
44
+
45
  def _clear_cache(self):
46
  self._embedding_cache = {}
47
 
48
 
49
+ @spaces.GPU
50
  def _get_cached_embedding(self, text: str) -> torch.Tensor:
51
+ """使用 GPU 裝飾器確保在正確的時機初始化 CUDA"""
52
+ if self.model is None:
53
+ self._initialize_model()
54
+
55
  if text not in self._embedding_cache:
56
  self._embedding_cache[text] = self.model.encode(text)
57
  return self._embedding_cache[text]
 
96
  List[Tuple[str, float]]: 相似品種列表,包含品種名稱和相似度分數
97
  """
98
  try:
99
+ if self.model is None:
100
+ self._initialize_model()
101
  target_breed = next((breed for breed in self.dog_data if breed[1] == breed_name), None)
102
  if not target_breed:
103
  return []
 
891
  }
892
 
893
  @gpu_init_wrapper
894
+ @_safe_prediction
895
  def match_user_preference(self, description: str, top_n: int = 10) -> List[Dict]:
896
  try:
897
+ if self.model is None:
898
+ self._initialize_model()
899
  # 獲取場景權重
900
  weights = self._detect_scenario(description)
901
  matches = []