import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import json
from typing import List, Dict

def load_model_and_tokenizer():
    """Load the tokenizer and encoder from the local model directories bundled with the Space."""
    tokenizer = AutoTokenizer.from_pretrained("models/tokenizer")
    model = AutoModel.from_pretrained("models/model")
    return tokenizer, model

def load_data(file_path: str = "data.json") -> List[Dict]:
    """Read the course catalog and flatten it into one record per subcourse."""
    with open(file_path, "r") as f:
        data = json.load(f)

    flattened_courses = []
    for course_category in data["courses"]:
        for subcourse in course_category["subcourses"]:
            # Skip empty or malformed entries that lack any required field.
            if not subcourse or not all(
                key in subcourse for key in ["name", "description", "link"]
            ):
                continue
            flattened_courses.append(
                {
                    "course_type": course_category["course_type"],
                    "name": subcourse["name"],
                    "description": subcourse["description"],
                    "link": subcourse["link"],
                    # Single text field that gets embedded for search.
                    "content": f"{course_category['course_type']} - {subcourse['name']}: {subcourse['description']}",
                }
            )
    return flattened_courses
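
# Expected shape of data.json, inferred from the parsing above (the field
# names come from the code; the exact nesting shown here is an assumption):
#
# {
#   "courses": [
#     {
#       "course_type": "...",
#       "subcourses": [
#         {"name": "...", "description": "...", "link": "..."}
#       ]
#     }
#   ]
# }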

def get_embedding(text: str, tokenizer, model) -> torch.Tensor:
    """Embed a text as the mask-weighted mean of its token embeddings."""
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=512, padding=True
    )
    with torch.no_grad():
        outputs = model(**inputs)
    attention_mask = inputs["attention_mask"]
    token_embeddings = outputs.last_hidden_state
    # Broadcast the mask over the hidden dimension so padding tokens
    # contribute nothing to the average.
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )
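
# The pooling above is attention-mask-weighted mean pooling: padding tokens are
# zeroed out before averaging, so for token vectors t_i and mask bits m_i the
# sentence embedding is sum_i(m_i * t_i) / max(sum_i(m_i), 1e-9).
# Shape walk-through for a single input of L tokens with hidden size H:
#   token_embeddings:    (1, L, H)
#   input_mask_expanded: (1, L, H)
#   returned embedding:  (1, H)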

def precompute_embeddings(documents: List[Dict], tokenizer, model) -> torch.Tensor:
    """Embed every course document once, up front, into an (N, hidden) matrix."""
    embeddings = []
    for doc in documents:
        embedding = get_embedding(doc["content"], tokenizer, model)
        embeddings.append(embedding)
    return torch.cat(embeddings, dim=0)

def semantic_search(
    query: str,
    doc_embeddings: torch.Tensor,
    documents: List[Dict],
    tokenizer,
    model,
    top_k: int = 3,
) -> List[Dict]:
    """Return the top_k courses most similar to the query by cosine similarity."""
    query_embedding = get_embedding(query, tokenizer, model)
    # (1, hidden) against (N, hidden) broadcasts to one score per document.
    similarities = F.cosine_similarity(query_embedding, doc_embeddings)
    top_k_indices = torch.topk(similarities, min(top_k, len(documents))).indices
    results = []
    for idx in top_k_indices:
        doc = documents[idx.item()]
        results.append(
            {
                "course_type": doc["course_type"],
                "name": doc["name"],
                "description": doc["description"],
                "link": doc["link"],
            }
        )
    return results
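
# Minimal standalone usage sketch (assumes the local model directories and the
# data.json described above exist; variable names here are illustrative):
#
#   tok, mdl = load_model_and_tokenizer()
#   docs = load_data()
#   embs = precompute_embeddings(docs, tok, mdl)
#   for hit in semantic_search("python for beginners", embs, docs, tok, mdl):
#       print(hit["name"], "->", hit["link"])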

# One-time initialization at import time, so the Gradio handler only has to
# run the search itself.
try:
    print("Loading model and tokenizer...")
    tokenizer, model = load_model_and_tokenizer()
    print("Loading documents...")
    documents = load_data()
    if not documents:
        raise ValueError("No valid courses found in the data file")
    print("Precomputing embeddings...")
    doc_embeddings = precompute_embeddings(documents, tokenizer, model)
    print("Initialization complete!")
except Exception as e:
    print(f"Error during initialization: {str(e)}")
    raise

def search_interface(query: str) -> str:
    """Gradio handler: run the search and format the hits as Markdown."""
    if not query.strip():
        return "Please enter a search query."
    results = semantic_search(query, doc_embeddings, documents, tokenizer, model)
    output = "# Search Results:\n\n"
    for i, result in enumerate(results, 1):
        output += f"## {i}. {result['course_type']} - {result['name']}\n"
        output += f"**Description:** {result['description']}\n"
        output += f"_[Link to Course]({result['link']})_\n\n"
    return output

app = gr.Interface(
    fn=search_interface,
    inputs=gr.Textbox(
        lines=2,
        placeholder="Enter your search query here (e.g., 'machine learning', 'python for beginners', 'deep learning')",
    ),
    outputs=gr.Markdown(
        value="## Search Results will be displayed here.",
        line_breaks=True,
        label="Search Results",
        show_label=True,
    ),
    title="Analytics Vidhya Course Search Engine",
    description="Search for courses using semantic similarity. Results are ordered by relevance.",
    allow_flagging="never",
)

app.launch()