opex792 commited on
Commit
414ad6b
·
verified ·
1 Parent(s): 94d93d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -7
app.py CHANGED
@@ -12,6 +12,7 @@ from urllib.parse import urlparse
12
  import logging
13
  from sklearn.preprocessing import normalize
14
  from concurrent.futures import ThreadPoolExecutor
 
15
 
16
  # Настройка логирования
17
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -31,12 +32,20 @@ db_params = {
31
  "sslmode": "require"
32
  }
33
 
34
- # Загружаем модель
35
  model_name = "BAAI/bge-m3"
36
  logging.info(f"Загрузка модели {model_name}...")
37
  model = SentenceTransformer(model_name)
38
  logging.info("Модель загружена успешно.")
39
 
 
 
 
 
 
 
 
 
40
  # Имена таблиц
41
  embeddings_table = "movie_embeddings"
42
  query_cache_table = "query_cache"
@@ -207,7 +216,7 @@ def process_batch(batch):
207
 
208
  try:
209
  for movie in batch:
210
- embedding_string = f"Название: {movie['name']}\nГод: {movie['year']}\nЖанры: {movie['genreslist']}\nОписание: {movie['description']}"
211
  string_crc32 = calculate_crc32(embedding_string)
212
 
213
  # Проверяем существующий эмбеддинг
@@ -289,6 +298,24 @@ def get_movie_embeddings(conn):
289
  logging.error(f"Ошибка при загрузке эмбеддингов фильмов: {e}")
290
  return movie_embeddings
291
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  def search_movies(query, top_k=20):
293
  """Выполняет поиск фильмов по запросу."""
294
  global search_in_progress
@@ -332,23 +359,26 @@ def search_movies(query, top_k=20):
332
  FROM {embeddings_table} m, query_embedding
333
  ORDER BY similarity DESC
334
  LIMIT %s
335
- """, (query_crc32, top_k))
336
 
337
  results = cur.fetchall()
338
- logging.info(f"Найдено {len(results)} результатов поиска.")
339
  except Exception as e:
340
  logging.error(f"Ошибка при выполнении поискового запроса: {e}")
341
  results = []
342
 
 
 
 
343
  output = ""
344
- for movie_id, similarity in results:
345
  # Находим фильм по ID
346
  movie = next((m for m in movies_data if m['id'] == movie_id), None)
347
  if movie:
348
  output += f"<h3>{movie['name']} ({movie['year']})</h3>\n"
349
- output += f"<p><strong>Жанры:</strong> {movie['genreslist']}</p>\n"
350
  output += f"<p><strong>Описание:</strong> {movie['description']}</p>\n"
351
- output += f"<p><strong>Релевантность:</strong> {similarity:.4f}</p>\n"
352
  output += "<hr>\n"
353
 
354
  search_time = time.time() - start_time
 
12
  import logging
13
  from sklearn.preprocessing import normalize
14
  from concurrent.futures import ThreadPoolExecutor
15
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
16
 
17
  # Настройка логирования
18
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
32
  "sslmode": "require"
33
  }
34
 
35
+ # Загружаем модель эмбеддингов
36
  model_name = "BAAI/bge-m3"
37
  logging.info(f"Загрузка модели {model_name}...")
38
  model = SentenceTransformer(model_name)
39
  logging.info("Модель загружена успешно.")
40
 
41
+ # Загружаем модель реранкера
42
+ reranker_name = 'BAAI/bge-reranker-v2-m3'
43
+ logging.info(f"Загрузка модели реранкера {reranker_name}...")
44
+ reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_name)
45
+ reranker_model = AutoModelForSequenceClassification.from_pretrained(reranker_name)
46
+ reranker_model.eval()
47
+ logging.info("Модель реранкера загружена успешно.")
48
+
49
  # Имена таблиц
50
  embeddings_table = "movie_embeddings"
51
  query_cache_table = "query_cache"
 
216
 
217
  try:
218
  for movie in batch:
219
+ embedding_string = f"Название: {movie['name']}\nГод: {movie['year']}\nЖанры: {movie['genresList']}\nОписание: {movie['description']}"
220
  string_crc32 = calculate_crc32(embedding_string)
221
 
222
  # Проверяем существующий эмбеддинг
 
298
  logging.error(f"Ошибка при загрузке эмбеддингов фильмов: {e}")
299
  return movie_embeddings
300
 
301
+ def rerank_results(query, results):
302
+ """Переранжирует результаты поиска с помощью реранкера."""
303
+ pairs = []
304
+ movie_ids = []
305
+ for movie_id, _ in results:
306
+ movie = next((m for m in movies_data if m['id'] == movie_id), None)
307
+ if movie:
308
+ movie_info = f"Название: {movie['name']}\nГод: {movie['year']}\nЖанры: {movie['genresList']}\nОписание: {movie['description']}"
309
+ pairs.append([query, movie_info])
310
+ movie_ids.append(movie_id)
311
+
312
+ with torch.no_grad():
313
+ inputs = reranker_tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
314
+ scores = reranker_model(**inputs, return_dict=True).logits.view(-1, ).float()
315
+
316
+ reranked_results = sorted(zip(movie_ids, scores.tolist()), key=lambda x: x[1], reverse=True)
317
+ return reranked_results
318
+
319
  def search_movies(query, top_k=20):
320
  """Выполняет поиск фильмов по запросу."""
321
  global search_in_progress
 
359
  FROM {embeddings_table} m, query_embedding
360
  ORDER BY similarity DESC
361
  LIMIT %s
362
+ """, (query_crc32, top_k * 2)) # Увеличиваем лимит для последующего переранжирования
363
 
364
  results = cur.fetchall()
365
+ logging.info(f"Найдено {len(results)} предварительных результатов поиска.")
366
  except Exception as e:
367
  logging.error(f"Ошибка при выполнении поискового запроса: {e}")
368
  results = []
369
 
370
+ # Переранжируем результаты
371
+ reranked_results = rerank_results(query, results)
372
+
373
  output = ""
374
+ for movie_id, score in reranked_results[:top_k]:
375
  # Находим фильм по ID
376
  movie = next((m for m in movies_data if m['id'] == movie_id), None)
377
  if movie:
378
  output += f"<h3>{movie['name']} ({movie['year']})</h3>\n"
379
+ output += f"<p><strong>Жанры:</strong> {movie['genresList']}</p>\n"
380
  output += f"<p><strong>Описание:</strong> {movie['description']}</p>\n"
381
+ output += f"<p><strong>Релевантность (reranker score):</strong> {score:.4f}</p>\n"
382
  output += "<hr>\n"
383
 
384
  search_time = time.time() - start_time