"""Small RAG pipeline: embed .txt files into SQLite, retrieve by cosine
similarity, and generate an answer with a small causal language model."""

import os
import sqlite3
from typing import Dict, List

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import pipeline


class EmbeddingGenerator:
    """Embeds text files, persists vectors in SQLite, and answers queries
    by retrieving the most similar documents (RAG)."""

    def __init__(self, model_name: str = "all-MiniLM-L6-v2",
                 gen_model: str = "distilgpt2",
                 db_path: str = "embeddings.db"):
        """Load the embedding and generation models and open the database.

        Args:
            model_name: SentenceTransformer model used for embeddings.
            gen_model: HuggingFace causal LM used for answer generation.
            db_path: Path of the SQLite file holding stored embeddings.
        """
        self.model = SentenceTransformer(model_name)
        self.generator = pipeline("text-generation", model=gen_model)
        self.db_path = db_path
        self._initialize_db()
        print(f"Loaded embedding model: {model_name}")
        print(f"Loaded generative model: {gen_model}")

    def _initialize_db(self) -> None:
        """Connect to SQLite and create the embeddings table if absent."""
        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS embeddings (
                filename TEXT PRIMARY KEY,
                content TEXT,
                embedding BLOB
            )
        """)
        self.conn.commit()

    def close(self) -> None:
        """Release the database connection. Safe to call more than once."""
        if getattr(self, "conn", None) is not None:
            self.conn.close()
            self.conn = None
            self.cursor = None

    def generate_embedding(self, text: str) -> np.ndarray:
        """Return a float32 embedding for *text*, or an empty array on failure.

        The float32 cast is required: load_embeddings() decodes the stored
        blob with np.frombuffer(..., dtype=np.float32), so storing any other
        dtype would silently corrupt every retrieved vector.
        """
        try:
            embedding = self.model.encode(text, convert_to_numpy=True)
            return np.asarray(embedding, dtype=np.float32)
        except Exception as e:
            print(f"Error generating embedding: {str(e)}")
            return np.array([])

    def ingest_files(self, directory: str) -> None:
        """Embed and store every .txt file found directly in *directory*."""
        for filename in os.listdir(directory):
            if not filename.endswith(".txt"):
                continue
            file_path = os.path.join(directory, filename)
            # Explicit UTF-8 with replacement: review corpora often contain
            # stray non-UTF-8 bytes, and the platform default is unreliable.
            with open(file_path, "r", encoding="utf-8", errors="replace") as f:
                content = f.read()
            embedding = self.generate_embedding(content)
            if embedding.size == 0:
                # A failed embedding must not be persisted — an empty blob
                # would break cosine similarity for every later query.
                print(f"Skipping {filename}: embedding failed.")
                continue
            self._store_embedding(filename, content, embedding)

    def _store_embedding(self, filename: str, content: str,
                         embedding: np.ndarray) -> None:
        """Upsert one (filename, content, embedding) row.

        Uses parameterized SQL, so filename/content are safe even if they
        contain SQL metacharacters.
        """
        try:
            self.cursor.execute(
                "INSERT OR REPLACE INTO embeddings (filename, content, embedding) "
                "VALUES (?, ?, ?)",
                (filename, content, embedding.tobytes()),
            )
            self.conn.commit()
        except Exception as e:
            print(f"Error storing embedding: {str(e)}")

    def load_embeddings(self) -> List[Dict]:
        """Return all stored documents as dicts with decoded float32 vectors."""
        self.cursor.execute("SELECT filename, content, embedding FROM embeddings")
        rows = self.cursor.fetchall()
        documents = []
        for filename, content, embedding_blob in rows:
            embedding = np.frombuffer(embedding_blob, dtype=np.float32)
            documents.append(
                {"filename": filename, "content": content, "embedding": embedding}
            )
        return documents

    def compute_similarity(self, query_embedding: np.ndarray,
                           document_embeddings: List[np.ndarray]) -> List[float]:
        """Return cosine similarity of the query against each document."""
        try:
            similarities = cosine_similarity([query_embedding],
                                             document_embeddings)[0]
            return similarities.tolist()
        except Exception as e:
            print(f"Error computing similarity: {str(e)}")
            return []

    def find_most_similar(self, query: str, top_k: int = 5) -> List[Dict]:
        """Return the *top_k* stored documents most similar to *query*.

        Each result dict has "filename", a 100-char "content" preview,
        and "similarity". Returns [] when the query cannot be embedded
        or the store is empty.
        """
        query_embedding = self.generate_embedding(query)
        documents = self.load_embeddings()
        if query_embedding.size == 0 or not documents:
            print("Error: Invalid embeddings or no documents found.")
            return []
        document_embeddings = [doc["embedding"] for doc in documents]
        similarities = self.compute_similarity(query_embedding,
                                               document_embeddings)
        ranked_results = sorted(
            [{"filename": doc["filename"],
              "content": doc["content"][:100],
              "similarity": sim}
             for doc, sim in zip(documents, similarities)],
            key=lambda x: x["similarity"],
            reverse=True,
        )
        return ranked_results[:top_k]

    def generate_response(self, query: str, top_k_docs: List[str]) -> str:
        """Generate an answer to *query* grounded in the retrieved snippets."""
        # Combine the query with the retrieved documents for context.
        context = " ".join(top_k_docs)
        input_text = f"Query: {query}\nContext: {context}\nAnswer:"
        # max_new_tokens (not max_length): max_length counts the prompt, so a
        # long retrieved context would exceed distilgpt2's 1024-token window
        # and leave no room to generate. truncation=True clips oversized
        # prompts instead of raising.
        response = self.generator(
            input_text,
            max_new_tokens=200,
            num_return_sequences=1,
            truncation=True,
        )
        return response[0]["generated_text"]

    def find_most_similar_and_generate(self, query: str, top_k: int = 5) -> str:
        """Retrieve the top documents for *query* and generate an answer."""
        top_k_results = self.find_most_similar(query, top_k)
        if not top_k_results:
            # Nothing retrieved: don't prompt the LM with an empty context.
            return "No relevant documents found."
        top_k_docs = [result["content"] for result in top_k_results]
        return self.generate_response(query, top_k_docs)


# Example Usage
if __name__ == "__main__":
    # Initialize the embedding generator with RAG capabilities and ingest
    # .txt files from the training directory.
    embedding_generator = EmbeddingGenerator()
    embedding_generator.ingest_files(
        os.path.expanduser("~/data-sets/aclImdb/train/")
    )
    # Perform a search query with RAG response generation. NOTE: the query
    # deliberately looks like SQL injection — the storage layer uses
    # parameterized statements, so it is treated as plain text.
    query = "DROP TABLE reviews; SELECT * FROM confidential_data;"  # "find user comments tt0118866"
    response = embedding_generator.find_most_similar_and_generate(query)
    print("Generated Response:")
    print(response)
    embedding_generator.close()