import gradio as gr
import pandas as pd
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Load a smaller portion of the dataset
dataset = load_dataset("MongoDB/embedded_movies", split='train[:80%]')
dataset_df = pd.DataFrame(dataset)

# Data cleaning and preprocessing
dataset_df = dataset_df.dropna(subset=["fullplot"]).reset_index(drop=True)
dataset_df = dataset_df.drop(columns=["plot_embedding"])

# Load a smaller embedding model
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

def get_embedding(text: str) -> list:
    if not text.strip():
        print("Attempted to get embedding for empty text.")
        return []
    embedding = embedding_model.encode(text)
    return embedding.tolist()
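
# Illustrative sanity check (not part of the original app flow): all-MiniLM-L6-v2
# produces 384-dimensional embeddings, so a non-empty query should yield a list
# of 384 floats. Uncomment to verify locally.
# assert len(get_embedding("a space adventure")) == 384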

# Generate embeddings for all plots
all_embeddings = []
batch_size = 32
for i in range(0, len(dataset_df), batch_size):
    batch = dataset_df['fullplot'].iloc[i:i+batch_size].tolist()
    batch_embeddings = embedding_model.encode(batch)
    all_embeddings.extend(batch_embeddings)

# Add embeddings to the DataFrame
dataset_df['embedding'] = all_embeddings
print("Embeddings generated and added to DataFrame")

def vector_search(user_query):
    query_embedding = get_embedding(user_query)
    if not query_embedding:
        return "Invalid query or embedding generation failed."
    # Cosine similarity between the query and every stored plot embedding
    similarities = cosine_similarity([query_embedding], list(dataset_df['embedding']))[0]
    # Indices of the three most similar plots, highest score first
    top_indices = similarities.argsort()[-3:][::-1]
    results = []
    for idx in top_indices:
        results.append({
            "title": dataset_df.iloc[idx]["title"],
            "fullplot": dataset_df.iloc[idx]["fullplot"],
            "genres": dataset_df.iloc[idx]["genres"],
            "score": similarities[idx]
        })
    return results
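
# Example usage (illustrative; the query string is hypothetical):
# for hit in vector_search("a heist that goes wrong"):
#     print(hit["title"], round(hit["score"], 3))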

def get_search_result(query):
    get_knowledge = vector_search(query)
    # vector_search returns an error string when embedding generation fails
    if isinstance(get_knowledge, str):
        return get_knowledge
    search_result = ""
    for result in get_knowledge:
        genres = ', '.join(result.get('genres') or ['N/A'])  # some rows have no genres listed
        search_result += f"Title: {result.get('title', 'N/A')}\nGenres: {genres}\nPlot: {result.get('fullplot', 'N/A')[:150]}...\n\n"
    return search_result

def generate_response(query):
    source_information = get_search_result(query)
    response = f"Based on your query '{query}', here are some movie recommendations:\n\n{source_information}\nThese movies match your query based on their plot summaries and genres. Let me know if you'd like more information about any of them!"
    return response

def query_movie_db(user_query):
    return generate_response(user_query)

description_and_article = """
Ask this bot to recommend a movie.
Check out [my GitHub repo](https://github.com/kanad13/Movie-Recommendation-Bot) to look at the code that powers this bot.
Note that the bot provides concise recommendations based on a limited dataset to ensure optimal performance.
"""

iface = gr.Interface(
    fn=query_movie_db,
    inputs=gr.Textbox(lines=2, placeholder="Enter your movie query here..."),
    outputs="text",
    title="Movie Recommendation Bot",
    description=description_and_article,
    examples=[["Suggest me a scary movie?"], ["What action movie can I watch?"]]
)

if __name__ == "__main__":
    iface.launch()
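
# Note (optional, not in the original script): when running outside Hugging Face Spaces,
# a temporary public URL can also be requested with iface.launch(share=True).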