Update app.py
Browse files
app.py
CHANGED
@@ -12,6 +12,7 @@ from urllib.parse import urlparse
|
|
12 |
import logging
|
13 |
from sklearn.preprocessing import normalize
|
14 |
from concurrent.futures import ThreadPoolExecutor
|
|
|
15 |
|
16 |
# Настройка логирования
|
17 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
@@ -31,12 +32,20 @@ db_params = {
|
|
31 |
"sslmode": "require"
|
32 |
}
|
33 |
|
34 |
-
# Загружаем модель
|
35 |
model_name = "BAAI/bge-m3"
|
36 |
logging.info(f"Загрузка модели {model_name}...")
|
37 |
model = SentenceTransformer(model_name)
|
38 |
logging.info("Модель загружена успешно.")
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
# Имена таблиц
|
41 |
embeddings_table = "movie_embeddings"
|
42 |
query_cache_table = "query_cache"
|
@@ -207,7 +216,7 @@ def process_batch(batch):
|
|
207 |
|
208 |
try:
|
209 |
for movie in batch:
|
210 |
-
embedding_string = f"Название: {movie['name']}\nГод: {movie['year']}\nЖанры: {movie['
|
211 |
string_crc32 = calculate_crc32(embedding_string)
|
212 |
|
213 |
# Проверяем существующий эмбеддинг
|
@@ -289,6 +298,24 @@ def get_movie_embeddings(conn):
|
|
289 |
logging.error(f"Ошибка при загрузке эмбеддингов фильмов: {e}")
|
290 |
return movie_embeddings
|
291 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
292 |
def search_movies(query, top_k=20):
|
293 |
"""Выполняет поиск фильмов по запросу."""
|
294 |
global search_in_progress
|
@@ -332,23 +359,26 @@ def search_movies(query, top_k=20):
|
|
332 |
FROM {embeddings_table} m, query_embedding
|
333 |
ORDER BY similarity DESC
|
334 |
LIMIT %s
|
335 |
-
""", (query_crc32, top_k))
|
336 |
|
337 |
results = cur.fetchall()
|
338 |
-
logging.info(f"Найдено {len(results)} результатов поиска.")
|
339 |
except Exception as e:
|
340 |
logging.error(f"Ошибка при выполнении поискового запроса: {e}")
|
341 |
results = []
|
342 |
|
|
|
|
|
|
|
343 |
output = ""
|
344 |
-
for movie_id,
|
345 |
# Находим фильм по ID
|
346 |
movie = next((m for m in movies_data if m['id'] == movie_id), None)
|
347 |
if movie:
|
348 |
output += f"<h3>{movie['name']} ({movie['year']})</h3>\n"
|
349 |
-
output += f"<p><strong>Жанры:</strong> {movie['
|
350 |
output += f"<p><strong>Описание:</strong> {movie['description']}</p>\n"
|
351 |
-
output += f"<p><strong
|
352 |
output += "<hr>\n"
|
353 |
|
354 |
search_time = time.time() - start_time
|
|
|
12 |
import logging
|
13 |
from sklearn.preprocessing import normalize
|
14 |
from concurrent.futures import ThreadPoolExecutor
|
15 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
16 |
|
17 |
# Настройка логирования
|
18 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
32 |
"sslmode": "require"
|
33 |
}
|
34 |
|
35 |
+
# Загружаем модель эмбеддингов
|
36 |
model_name = "BAAI/bge-m3"
|
37 |
logging.info(f"Загрузка модели {model_name}...")
|
38 |
model = SentenceTransformer(model_name)
|
39 |
logging.info("Модель загружена успешно.")
|
40 |
|
41 |
+
# Загружаем модель реранкера
|
42 |
+
reranker_name = 'BAAI/bge-reranker-v2-m3'
|
43 |
+
logging.info(f"Загрузка модели реранкера {reranker_name}...")
|
44 |
+
reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_name)
|
45 |
+
reranker_model = AutoModelForSequenceClassification.from_pretrained(reranker_name)
|
46 |
+
reranker_model.eval()
|
47 |
+
logging.info("Модель реранкера загружена успешно.")
|
48 |
+
|
49 |
# Имена таблиц
|
50 |
embeddings_table = "movie_embeddings"
|
51 |
query_cache_table = "query_cache"
|
|
|
216 |
|
217 |
try:
|
218 |
for movie in batch:
|
219 |
+
embedding_string = f"Название: {movie['name']}\nГод: {movie['year']}\nЖанры: {movie['genresList']}\nОписание: {movie['description']}"
|
220 |
string_crc32 = calculate_crc32(embedding_string)
|
221 |
|
222 |
# Проверяем существующий эмбеддинг
|
|
|
298 |
logging.error(f"Ошибка при загрузке эмбеддингов фильмов: {e}")
|
299 |
return movie_embeddings
|
300 |
|
301 |
+
def rerank_results(query, results):
|
302 |
+
"""Переранжирует результаты поиска с помощью реранкера."""
|
303 |
+
pairs = []
|
304 |
+
movie_ids = []
|
305 |
+
for movie_id, _ in results:
|
306 |
+
movie = next((m for m in movies_data if m['id'] == movie_id), None)
|
307 |
+
if movie:
|
308 |
+
movie_info = f"Название: {movie['name']}\nГод: {movie['year']}\nЖанры: {movie['genresList']}\nОписание: {movie['description']}"
|
309 |
+
pairs.append([query, movie_info])
|
310 |
+
movie_ids.append(movie_id)
|
311 |
+
|
312 |
+
with torch.no_grad():
|
313 |
+
inputs = reranker_tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
|
314 |
+
scores = reranker_model(**inputs, return_dict=True).logits.view(-1, ).float()
|
315 |
+
|
316 |
+
reranked_results = sorted(zip(movie_ids, scores.tolist()), key=lambda x: x[1], reverse=True)
|
317 |
+
return reranked_results
|
318 |
+
|
319 |
def search_movies(query, top_k=20):
|
320 |
"""Выполняет поиск фильмов по запросу."""
|
321 |
global search_in_progress
|
|
|
359 |
FROM {embeddings_table} m, query_embedding
|
360 |
ORDER BY similarity DESC
|
361 |
LIMIT %s
|
362 |
+
""", (query_crc32, top_k * 2)) # Увеличиваем лимит для последующего переранжирования
|
363 |
|
364 |
results = cur.fetchall()
|
365 |
+
logging.info(f"Найдено {len(results)} предварительных результатов поиска.")
|
366 |
except Exception as e:
|
367 |
logging.error(f"Ошибка при выполнении поискового запроса: {e}")
|
368 |
results = []
|
369 |
|
370 |
+
# Переранжируем результаты
|
371 |
+
reranked_results = rerank_results(query, results)
|
372 |
+
|
373 |
output = ""
|
374 |
+
for movie_id, score in reranked_results[:top_k]:
|
375 |
# Находим фильм по ID
|
376 |
movie = next((m for m in movies_data if m['id'] == movie_id), None)
|
377 |
if movie:
|
378 |
output += f"<h3>{movie['name']} ({movie['year']})</h3>\n"
|
379 |
+
output += f"<p><strong>Жанры:</strong> {movie['genresList']}</p>\n"
|
380 |
output += f"<p><strong>Описание:</strong> {movie['description']}</p>\n"
|
381 |
+
output += f"<p><strong>Релевантность (reranker score):</strong> {score:.4f}</p>\n"
|
382 |
output += "<hr>\n"
|
383 |
|
384 |
search_time = time.time() - start_time
|