# Source: Hugging Face Space file "app.py" (uploader: guirnd).
# Last commit: ad762da (verified), message "Update app.py"; file size 8.79 kB.
import os
import re
import soundfile as sf
import torch
import torchaudio
import torchaudio.transforms as T
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, AutoModel
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain, StuffDocumentsChain, RetrievalQA
from langchain.llms import LlamaCpp
import gradio as gr
class PDFProcessor:
    """Load a PDF from disk and cut it into small chunks for indexing."""

    def __init__(self, pdf_path):
        """Remember which PDF this processor works on.

        Args:
            pdf_path: Path handed to ``PyPDFLoader`` when the PDF is loaded.
        """
        self.pdf_path = pdf_path

    def load_and_split_pdf(self):
        """Read the PDF and return it as a list of chunked documents.

        Returns:
            The output of ``RecursiveCharacterTextSplitter.split_documents``
            for the loaded PDF, using ``chunk_size=300`` and
            ``chunk_overlap=20``.
        """
        pages = PyPDFLoader(self.pdf_path).load()
        splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=20)
        return splitter.split_documents(pages)
class FAISSManager:
    """Build, persist and reload FAISS vector stores, with an in-memory cache."""

    # Embedding model used both when building a new index and when loading one.
    EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
    # Files that must be present in a directory for it to count as a saved index.
    INDEX_FILENAMES = ("index.faiss", "index.pkl")

    def __init__(self):
        # Maps a vectorstore path to the vectorstore already built or loaded
        # for it, so repeated queries do not reload the index from disk.
        self.vectorstore_cache = {}

    def _make_embeddings(self):
        """Create the embedding model shared by the build and load paths."""
        return HuggingFaceEmbeddings(model_name=self.EMBEDDING_MODEL_NAME)

    def _index_files_exist(self, file_path):
        """Return True when every expected index file exists under *file_path*."""
        return all(
            os.path.exists(os.path.join(file_path, name))
            for name in self.INDEX_FILENAMES
        )

    def build_faiss_index(self, docs):
        """Embed *docs* and return a new FAISS vectorstore built from them."""
        return FAISS.from_documents(docs, self._make_embeddings())

    def save_faiss_index(self, vectorstore, file_path):
        """Persist *vectorstore* to the directory *file_path*."""
        vectorstore.save_local(file_path)
        print(f"Vectorstore saved to {file_path}")

    def load_faiss_index(self, file_path):
        """Load a previously saved vectorstore from *file_path*.

        Raises:
            FileNotFoundError: If ``index.faiss`` or ``index.pkl`` is missing
                from *file_path*.
        """
        if not self._index_files_exist(file_path):
            raise FileNotFoundError(f"Could not find FAISS index or metadata files in {file_path}")
        # SECURITY NOTE: allow_dangerous_deserialization=True lets the loader
        # unpickle index.pkl. Only point this at index directories written by
        # this application; never at files from an untrusted source.
        vectorstore = FAISS.load_local(
            file_path,
            self._make_embeddings(),
            allow_dangerous_deserialization=True,
        )
        print(f"Vectorstore loaded from {file_path}")
        return vectorstore

    def build_faiss_index_with_cache_and_file(self, pdf_processor, vectorstore_path):
        """Return the vectorstore for *vectorstore_path*, creating it if needed.

        Lookup order: the in-memory cache, then the index saved on disk, and
        finally a fresh build from the PDF (which is then saved to disk).

        Args:
            pdf_processor: Object exposing ``pdf_path`` and
                ``load_and_split_pdf()``; only used when a new index is built.
            vectorstore_path: Directory holding (or receiving) the saved index.

        Returns:
            The loaded or newly built vectorstore.
        """
        cached = self.vectorstore_cache.get(vectorstore_path)
        if cached is not None:
            return cached

        # Check for the index files themselves: a directory that exists but is
        # empty or incomplete is rebuilt instead of failing inside the loader.
        if self._index_files_exist(vectorstore_path):
            print(f"Loading vectorstore from file {vectorstore_path}")
            vectorstore = self.load_faiss_index(vectorstore_path)
        else:
            print(f"Building new vectorstore for {pdf_processor.pdf_path}")
            docs = pdf_processor.load_and_split_pdf()
            vectorstore = self.build_faiss_index(docs)
            self.save_faiss_index(vectorstore, vectorstore_path)

        self.vectorstore_cache[vectorstore_path] = vectorstore
        return vectorstore
class LLMChainFactory:
    """Create document-stuffing chains from a fixed prompt template."""

    def __init__(self, prompt_template):
        """Store the template text used for every chain this factory creates.

        Args:
            prompt_template: Template string; it is used with the input
                variables ``documents`` and ``question``.
        """
        self.prompt_template = prompt_template

    def create_llm_chain(self, llm, max_tokens=80):
        """Wrap *llm* in a ``StuffDocumentsChain`` driven by the stored prompt.

        Args:
            llm: Language model passed to ``LLMChain``. Its ``max_tokens``
                attribute is overwritten with *max_tokens*.
            max_tokens: Value assigned to ``llm.max_tokens``.

        Returns:
            A ``StuffDocumentsChain`` whose document variable is ``documents``.
        """
        chain = LLMChain(
            llm=llm,
            prompt=PromptTemplate(
                template=self.prompt_template,
                input_variables=["documents", "question"],
            ),
        )
        # This sets the attribute on the llm object held by the chain.
        chain.llm.max_tokens = max_tokens
        return StuffDocumentsChain(llm_chain=chain, document_variable_name="documents")
class LLMManager:
    """Own the LlamaCpp model and run retrieval-augmented question answering."""

    # Prompt given to the LLM; {documents} and {question} are filled by the chain.
    _PROMPT_TEMPLATE = (
        "You are a helpful AI. Based on the context below, answer the question politely.\n"
        "Context: {documents}\n"
        "Question: {question}\n"
        "Answer:"
    )

    def __init__(self, model_path):
        """Load the model file at *model_path* and set ``max_tokens`` to 80."""
        self.llm = LlamaCpp(model_path=model_path)
        self.llm.max_tokens = 80

    def create_rag_chain(self, llm_chain_factory, vectorstore):
        """Combine a retriever over *vectorstore* with a chain from the factory.

        Args:
            llm_chain_factory: Object whose ``create_llm_chain(llm)`` returns
                the combine-documents chain.
            vectorstore: Object whose ``as_retriever()`` supplies the retriever.

        Returns:
            A ``RetrievalQA`` chain.
        """
        return RetrievalQA(
            retriever=vectorstore.as_retriever(),
            combine_documents_chain=llm_chain_factory.create_llm_chain(self.llm),
        )

    def main_rag_pipeline(self, pdf_processor, query, vectorstore_manager, vectorstore_file):
        """Answer *query* using the vectorstore obtained for the given PDF.

        Args:
            pdf_processor: Passed through to the vectorstore manager.
            query: Question text handed to the RAG chain.
            vectorstore_manager: Object providing
                ``build_faiss_index_with_cache_and_file``.
            vectorstore_file: Path of the saved vectorstore.

        Returns:
            Whatever ``rag_chain.run(query)`` returns.
        """
        store = vectorstore_manager.build_faiss_index_with_cache_and_file(
            pdf_processor, vectorstore_file
        )
        factory = LLMChainFactory(prompt_template=self._PROMPT_TEMPLATE)
        return self.create_rag_chain(factory, store).run(query)
class WhisperManager:
    """Speech-to-text wrapper around the ``openai/whisper-small`` checkpoint."""

    # Sampling rate the audio is resampled to before feature extraction.
    TARGET_SAMPLE_RATE = 16000

    def __init__(self):
        """Load the Whisper model and processor and prepare the decoder prompt."""
        self.model_id = "openai/whisper-small"
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(self.model_id)
        self.whisper_processor = WhisperProcessor.from_pretrained(self.model_id)
        # Decoder prompt requesting an English transcription.
        self.forced_decoder_ids = self.whisper_processor.get_decoder_prompt_ids(
            language="english", task="transcribe"
        )

    def transcribe_speech(self, filepath):
        """Transcribe the audio file at *filepath* and return the cleaned text.

        Args:
            filepath: Path to an audio file readable by ``torchaudio.load``.

        Returns:
            The decoded transcription with any ``<...>`` markers removed and
            surrounding whitespace stripped.

        Raises:
            ValueError: If *filepath* is empty, ``None`` or not an existing file.
        """
        # A missing recording arrives as None; report it as the same
        # ValueError as any other bad path instead of a TypeError from os.path.
        if not filepath or not os.path.isfile(filepath):
            raise ValueError(f"Invalid file path: {filepath}")

        waveform, sample_rate = torchaudio.load(filepath)

        # Average the channels so stereo recordings become one mono signal;
        # for a single-channel file this just drops the channel dimension.
        if waveform.dim() > 1:
            waveform = waveform.mean(dim=0)

        if sample_rate != self.TARGET_SAMPLE_RATE:
            resampler = T.Resample(orig_freq=sample_rate, new_freq=self.TARGET_SAMPLE_RATE)
            waveform = resampler(waveform)

        input_features = self.whisper_processor(
            waveform.numpy(),
            sampling_rate=self.TARGET_SAMPLE_RATE,
            return_tensors="pt",
        ).input_features
        generated_ids = self.whisper_model.generate(
            input_features, forced_decoder_ids=self.forced_decoder_ids
        )
        transcribed_text = self.whisper_processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]
        # Remove any leftover angle-bracket markers from the decoded text.
        return re.sub(r"<[^>]*>", "", transcribed_text).strip()
class SpeechT5Manager:
    """Text-to-speech wrapper built on SpeechT5 with a HiFi-GAN vocoder."""

    def __init__(self):
        """Load the TTS models and one fixed speaker embedding."""
        self.SpeechT5_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
        self.SpeechT5_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
        self.vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
        # NOTE(review): this model is loaded but not referenced by any method
        # of this class as visible here; confirm whether it is still needed.
        self.speaker_embedding_model = AutoModel.from_pretrained("microsoft/speecht5_vc")

        # Use entry 7000 of the x-vector dataset as the speaker embedding,
        # with a leading dimension added by unsqueeze(0).
        xvectors = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
        speaker_vector = torch.tensor(xvectors[7000]["xvector"])
        self.pretrained_speaker_embeddings = speaker_vector.unsqueeze(0)

    def text_to_speech(self, text, output_file="output_speechT5.wav"):
        """Synthesize *text* and write the audio to *output_file*.

        Args:
            text: Text to speak.
            output_file: Destination path for the audio, written at 16000 Hz.

        Returns:
            *output_file*, unchanged.
        """
        encoded = self.SpeechT5_processor(text=[text], return_tensors="pt")
        waveform = self.SpeechT5_model.generate_speech(
            encoded["input_ids"],
            self.pretrained_speaker_embeddings,
            vocoder=self.vocoder,
        )
        sf.write(output_file, waveform.numpy(), 16000)
        return output_file
# --- Gradio Interface ---
def asr_to_text(audio_file):
    """Return the transcription of *audio_file* from the module-level ``whisper_manager``."""
    return whisper_manager.transcribe_speech(audio_file)
def process_with_llm_and_tts(transcribed_text):
    """Answer *transcribed_text* with the RAG pipeline and synthesize the reply.

    Uses the module-level ``llm_manager``, ``pdf_processor``,
    ``vectorstore_manager``, ``vectorstore_file`` and ``speech_manager``.

    Returns:
        A ``(response_text, audio_path)`` tuple: the LLM answer and the value
        returned by ``speech_manager.text_to_speech`` for that answer.
    """
    answer = llm_manager.main_rag_pipeline(
        pdf_processor,
        transcribed_text,
        vectorstore_manager,
        vectorstore_file,
    )
    return answer, speech_manager.text_to_speech(answer)
# Instantiate Managers
# These module-level objects are shared by the Gradio callbacks defined above.
# Document that is split into chunks and indexed for retrieval.
pdf_processor = PDFProcessor('./files/LawsoftheGame2024_25.pdf')
# Builds, saves and loads the FAISS vectorstore.
vectorstore_manager = FAISSManager()
# Constructs the LlamaCpp model from the given GGUF file.
llm_manager = LLMManager(model_path="./files/mistral-7b-instruct-v0.2.Q2_K.gguf")
# Loads the Whisper model and processor used for speech-to-text.
whisper_manager = WhisperManager()
# Loads the SpeechT5 models and speaker embedding used for text-to-speech.
speech_manager = SpeechT5Manager()
# Path where the FAISS vectorstore is saved to and loaded from.
vectorstore_file = "./vectorstore_faiss"
# Define Gradio Interface
# Two-step UI: (1) transcribe the recorded question into an editable textbox,
# (2) send that text to the LLM and play the synthesized answer.
# NOTE(review): indentation was lost in this copy of the file. The statements
# below presumably belong inside the `with gr.Blocks()` block, with each
# `with gr.Row():` holding the two components that follow it -- confirm
# against the original before restoring the nesting.
with gr.Blocks() as demo:
gr.Markdown("<h1 style='text-align: center;'>RAG Powered Voice Assistant</h1>") #removed emojis
gr.Markdown("<h1 style='text-align: center;'>Ask me anything about the rules of Football!</h1>")
# Step 1: Audio input and ASR output
with gr.Row():
# type="filepath" makes the callback receive a path to the recorded audio.
audio_input = gr.Audio(type="filepath", label="Speak your question")
# Editable so the user can correct the transcription before submitting it.
asr_output = gr.Textbox(label="ASR Output (Edit if necessary)", interactive=True)
# Button to process audio (ASR)
asr_button = gr.Button("1 - Transform Voice to Text")
# Step 2: LLM Response and TTS output
with gr.Row():
llm_response = gr.Textbox(label="LLM Response")
tts_audio_output = gr.Audio(label="TTS Audio")
# Button to process text with LLM
llm_button = gr.Button("2 - Submit Question")
# When ASR button is clicked, the audio is transcribed
asr_button.click(fn=asr_to_text, inputs=audio_input, outputs=asr_output)
# When LLM button is clicked, the text is processed with the LLM and converted to speech
# The current textbox content is sent, including any manual edits.
llm_button.click(fn=process_with_llm_and_tts, inputs=asr_output, outputs=[llm_response, tts_audio_output])
# Disclaimer
gr.Markdown(
"<p style='text-align: center; color: gray;'>This application runs on a machine with limited (but awesome) resources, so LLM completion may take up to 2 minutes.</p>"
)
gr.Markdown(
"<p style='text-align: center; color: gray;'>Disclaimer: This application was developed solely for educational purposes to demonstrate AI capabilities and should not be used as a source of information or for any other purpose.</p>"
)
# Start the Gradio app with debug mode enabled.
demo.launch(debug=True)