# chagu-demo / falocon_api / embededGeneratorRAG.py
# Author: talexm
# Commit: 73321dd — "update RAG query improvements"
# (Hugging Face viewer chrome — raw / history / blame / 5.12 kB — converted
# to comments so the file parses.)
import os
import sqlite3
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import pipeline
from typing import List, Dict
class EmbeddingGenerator:
    """Minimal RAG helper backed by SQLite.

    Embeds ``.txt`` files with a SentenceTransformer, persists the vectors as
    raw float32 bytes in SQLite, retrieves documents by cosine similarity,
    and feeds the top hits to a causal LM to generate an answer.

    NOTE(review): embeddings are normalized to float32 on write so that
    ``load_embeddings`` can decode them with a fixed dtype.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2", gen_model: str = "distilgpt2", db_path: str = "embeddings.db"):
        """Load both models and open the SQLite store.

        Args:
            model_name: SentenceTransformer checkpoint used for embeddings.
            gen_model: Hugging Face text-generation checkpoint.
            db_path: SQLite database file (":memory:" works for testing).
        """
        self.model = SentenceTransformer(model_name)
        self.generator = pipeline("text-generation", model=gen_model)
        self.db_path = db_path
        self._initialize_db()
        print(f"Loaded embedding model: {model_name}")
        print(f"Loaded generative model: {gen_model}")

    def _initialize_db(self):
        """Open the SQLite connection and create the embeddings table if absent."""
        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS embeddings (
                filename TEXT PRIMARY KEY,
                content TEXT,
                embedding BLOB
            )
        """)
        self.conn.commit()

    def close(self):
        """Close the SQLite connection (the original leaked it for the process lifetime)."""
        try:
            self.conn.close()
        except Exception as e:
            print(f"Error closing database: {str(e)}")

    def generate_embedding(self, text: str) -> np.ndarray:
        """Return the embedding for *text* as float32, or an empty array on failure.

        The float32 cast is load-bearing: load_embeddings() decodes blobs with
        np.frombuffer(..., dtype=np.float32), so any other dtype stored here
        would silently corrupt every retrieved vector.
        """
        try:
            embedding = self.model.encode(text, convert_to_numpy=True)
            return np.asarray(embedding, dtype=np.float32)
        except Exception as e:
            print(f"Error generating embedding: {str(e)}")
            return np.array([])

    def ingest_files(self, directory: str):
        """Embed every top-level ``.txt`` file in *directory* and persist it.

        Files that fail to embed are skipped rather than stored, so a single
        bad document cannot poison later similarity searches.
        """
        for filename in os.listdir(directory):
            if filename.endswith(".txt"):
                file_path = os.path.join(directory, filename)
                # Explicit encoding; ignore undecodable bytes instead of
                # aborting the whole ingest on one bad file.
                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                    content = f.read()
                embedding = self.generate_embedding(content)
                if embedding.size == 0:
                    print(f"Skipping {filename}: embedding failed")
                    continue
                self._store_embedding(filename, content, embedding)

    def _store_embedding(self, filename: str, content: str, embedding: np.ndarray):
        """Upsert one row; the embedding is stored as raw float32 bytes."""
        try:
            self.cursor.execute(
                "INSERT OR REPLACE INTO embeddings (filename, content, embedding) VALUES (?, ?, ?)",
                (filename, content, np.asarray(embedding, dtype=np.float32).tobytes()),
            )
            self.conn.commit()
        except Exception as e:
            print(f"Error storing embedding: {str(e)}")

    def load_embeddings(self) -> List[Dict]:
        """Return every stored row as a dict with the embedding decoded to float32."""
        self.cursor.execute("SELECT filename, content, embedding FROM embeddings")
        documents = []
        for filename, content, embedding_blob in self.cursor.fetchall():
            embedding = np.frombuffer(embedding_blob, dtype=np.float32)
            documents.append({"filename": filename, "content": content, "embedding": embedding})
        return documents

    def compute_similarity(self, query_embedding: np.ndarray, document_embeddings: List[np.ndarray]) -> List[float]:
        """Cosine similarity between the query and each document embedding.

        Pure-numpy equivalent of sklearn's cosine_similarity; zero-norm
        vectors yield similarity 0.0 instead of a divide-by-zero warning.

        Returns: one similarity per document, or [] on error.
        """
        try:
            docs = np.vstack(document_embeddings).astype(np.float64)
            query = np.asarray(query_embedding, dtype=np.float64)
            denom = np.linalg.norm(docs, axis=1) * np.linalg.norm(query)
            denom[denom == 0.0] = 1.0  # zero-norm vector -> similarity 0.0
            similarities = docs @ query / denom
            return similarities.tolist()
        except Exception as e:
            print(f"Error computing similarity: {str(e)}")
            return []

    def find_most_similar(self, query: str, top_k: int = 5) -> List[Dict]:
        """Return the *top_k* most similar documents to *query*.

        Each result dict carries the filename, the first 100 characters of the
        content, and the similarity score, sorted descending.
        """
        query_embedding = self.generate_embedding(query)
        documents = self.load_embeddings()
        if query_embedding.size == 0 or len(documents) == 0:
            print("Error: Invalid embeddings or no documents found.")
            return []
        document_embeddings = [doc["embedding"] for doc in documents]
        similarities = self.compute_similarity(query_embedding, document_embeddings)
        ranked_results = sorted(
            [{"filename": doc["filename"], "content": doc["content"][:100], "similarity": sim}
             for doc, sim in zip(documents, similarities)],
            key=lambda x: x["similarity"],
            reverse=True
        )
        return ranked_results[:top_k]

    def generate_response(self, query: str, top_k_docs: List[str]) -> str:
        """Generate an answer to *query* conditioned on the retrieved snippets.

        Uses max_new_tokens instead of the original max_length=1000:
        max_length counts prompt + output, so a long retrieved context would
        exceed distilgpt2's 1024-token window and break generation.
        """
        # Combine the query with the retrieved documents for context.
        context = " ".join(top_k_docs)
        input_text = f"Query: {query}\nContext: {context}\nAnswer:"
        response = self.generator(input_text, max_new_tokens=200, num_return_sequences=1, truncation=True)
        return response[0]["generated_text"]

    def find_most_similar_and_generate(self, query: str, top_k: int = 5) -> str:
        """Retrieve the *top_k* most similar snippets and generate a response from them."""
        top_k_results = self.find_most_similar(query, top_k)
        top_k_docs = [result["content"] for result in top_k_results]
        response = self.generate_response(query, top_k_docs)
        return response
# Example Usage
if __name__ == "__main__":
    # Build the RAG pipeline and index the .txt files in the IMDB train directory.
    generator = EmbeddingGenerator()
    corpus_dir = os.path.expanduser("~/data-sets/aclImdb/train/")
    generator.ingest_files(corpus_dir)

    # Query deliberately shaped like SQL injection; the parameterized
    # statements in _store_embedding keep it inert.
    query = "DROP TABLE reviews; SELECT * FROM confidential_data;"#"find user comments tt0118866"
    response = generator.find_most_similar_and_generate(query)
    print("Generated Response:")
    print(response)