File size: 8,787 Bytes
fbbb97a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad762da
 
 
fbbb97a
 
 
 
f8ba7a5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import os
import re
import soundfile as sf
import torch
import torchaudio
import torchaudio.transforms as T
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, AutoModel
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain, StuffDocumentsChain, RetrievalQA
from langchain.llms import LlamaCpp
import gradio as gr

class PDFProcessor:
    """Load a PDF from disk and split it into overlapping text chunks for indexing."""

    def __init__(self, pdf_path, chunk_size=300, chunk_overlap=20):
        """
        Args:
            pdf_path: Path of the PDF file to load.
            chunk_size: Target chunk length passed to RecursiveCharacterTextSplitter.
            chunk_overlap: Length shared between consecutive chunks.
        """
        self.pdf_path = pdf_path
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def load_and_split_pdf(self):
        """Load every page of ``self.pdf_path`` and split it into chunks.

        Returns:
            The list of chunk documents produced by the text splitter.
        """
        loader = PyPDFLoader(self.pdf_path)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
        )
        return text_splitter.split_documents(documents)

class FAISSManager:
    """Build, persist and load FAISS vector stores, memoizing them per path."""

    DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

    def __init__(self, embedding_model_name=DEFAULT_EMBEDDING_MODEL):
        # Maps a vectorstore directory path to the FAISS store already loaded
        # or built for it, so repeated queries don't re-read the index from disk.
        self.vectorstore_cache = {}
        self.embedding_model_name = embedding_model_name
        # Created on first use and shared by every build/load call instead of
        # instantiating a new embedding model each time.
        self._embeddings = None

    def _get_embeddings(self):
        """Return the shared embedding model, instantiating it on first use."""
        if self._embeddings is None:
            self._embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model_name)
        return self._embeddings

    def build_faiss_index(self, docs):
        """Embed ``docs`` and return a new in-memory FAISS vector store."""
        return FAISS.from_documents(docs, self._get_embeddings())

    def save_faiss_index(self, vectorstore, file_path):
        """Persist ``vectorstore`` into the directory ``file_path``."""
        vectorstore.save_local(file_path)
        print(f"Vectorstore saved to {file_path}")

    def load_faiss_index(self, file_path):
        """Load a FAISS store previously written by ``save_faiss_index``.

        Raises:
            FileNotFoundError: if ``index.faiss`` or ``index.pkl`` is missing
                from ``file_path``.
        """
        required_files = (
            os.path.join(file_path, "index.faiss"),
            os.path.join(file_path, "index.pkl"),
        )
        if not all(os.path.exists(path) for path in required_files):
            raise FileNotFoundError(f"Could not find FAISS index or metadata files in {file_path}")
        # SECURITY: allow_dangerous_deserialization opts in to unpickling
        # index.pkl. Only point this at directories this app wrote itself,
        # never at untrusted files.
        vectorstore = FAISS.load_local(file_path, self._get_embeddings(), allow_dangerous_deserialization=True)
        print(f"Vectorstore loaded from {file_path}")
        return vectorstore

    def build_faiss_index_with_cache_and_file(self, pdf_processor, vectorstore_path):
        """Return the vector store for ``vectorstore_path``.

        Lookup order: the in-memory cache, then a saved index on disk, and
        only then a fresh build from ``pdf_processor`` (which is also saved
        to ``vectorstore_path``). The result is cached for later calls.
        """
        cached = self.vectorstore_cache.get(vectorstore_path)
        if cached is not None:
            return cached

        if os.path.exists(vectorstore_path):
            print(f"Loading vectorstore from file {vectorstore_path}")
            vectorstore = self.load_faiss_index(vectorstore_path)
        else:
            print(f"Building new vectorstore for {pdf_processor.pdf_path}")
            docs = pdf_processor.load_and_split_pdf()
            vectorstore = self.build_faiss_index(docs)
            self.save_faiss_index(vectorstore, vectorstore_path)

        self.vectorstore_cache[vectorstore_path] = vectorstore
        return vectorstore

class LLMChainFactory:
    """Assemble a "stuff" documents chain around a fixed prompt template."""

    def __init__(self, prompt_template):
        # Template text; expected to use the {documents} and {question} placeholders
        # declared as input variables in create_llm_chain.
        self.prompt_template = prompt_template

    def create_llm_chain(self, llm, max_tokens=80):
        """Wrap ``llm`` in a StuffDocumentsChain driven by ``self.prompt_template``.

        All retrieved documents are inserted into the ``{documents}`` slot.

        Note: ``max_tokens`` is assigned onto the chain's LLM object, which may
        be the very ``llm`` passed in, so the cap can leak to other users of it.
        """
        template = PromptTemplate(
            template=self.prompt_template,
            input_variables=["documents", "question"],
        )
        inner_chain = LLMChain(llm=llm, prompt=template)
        inner_chain.llm.max_tokens = max_tokens
        return StuffDocumentsChain(
            llm_chain=inner_chain,
            document_variable_name="documents",
        )

class LLMManager:
    """Own the local LlamaCpp model and answer questions through a RAG chain."""

    # Prompt for the "stuff" chain. Built from adjacent literals so the
    # Context/Question/Answer lines carry no stray source-code indentation.
    RAG_PROMPT_TEMPLATE = (
        "You are a helpful AI. Based on the context below, answer the question politely.\n"
        "Context: {documents}\n"
        "Question: {question}\n"
        "Answer:"
    )

    def __init__(self, model_path):
        """Load the model file at ``model_path`` with LlamaCpp."""
        self.llm = LlamaCpp(model_path=model_path)
        self.llm.max_tokens = 80
        # vectorstore path -> ready-to-run RetrievalQA chain, built on first use.
        self._rag_chains = {}

    def create_rag_chain(self, llm_chain_factory, vectorstore):
        """Build a RetrievalQA chain that stuffs retrieved chunks into the prompt."""
        retriever = vectorstore.as_retriever()
        combine_documents_chain = llm_chain_factory.create_llm_chain(self.llm)
        return RetrievalQA(combine_documents_chain=combine_documents_chain, retriever=retriever)

    def main_rag_pipeline(self, pdf_processor, query, vectorstore_manager, vectorstore_file):
        """Answer ``query`` against the PDF's vector store and return the text.

        The RetrievalQA chain is built once per ``vectorstore_file`` and reused,
        so later questions skip reloading the index and rebuilding the chain.
        """
        rag_chain = self._rag_chains.get(vectorstore_file)
        if rag_chain is None:
            vectorstore = vectorstore_manager.build_faiss_index_with_cache_and_file(pdf_processor, vectorstore_file)
            llm_chain_factory = LLMChainFactory(prompt_template=self.RAG_PROMPT_TEMPLATE)
            rag_chain = self.create_rag_chain(llm_chain_factory, vectorstore)
            self._rag_chains[vectorstore_file] = rag_chain
        return rag_chain.run(query)

class WhisperManager:
    """Transcribe speech from audio files with openai/whisper-small."""

    def __init__(self):
        self.model_id = "openai/whisper-small"
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(self.model_id)
        self.whisper_processor = WhisperProcessor.from_pretrained(self.model_id)
        # Pin decoding to English transcription via forced decoder prompt ids.
        self.forced_decoder_ids = self.whisper_processor.get_decoder_prompt_ids(language="english", task="transcribe")

    def transcribe_speech(self, filepath):
        """Return the transcription of the audio file at ``filepath``.

        Multi-channel audio is averaged down to mono and resampled to 16 kHz
        before feature extraction.

        Raises:
            ValueError: if ``filepath`` is empty/None (e.g. nothing was
                recorded in the UI) or does not name an existing file.
        """
        if not filepath or not os.path.isfile(filepath):
            raise ValueError(f"Invalid file path: {filepath}")
        waveform, sample_rate = torchaudio.load(filepath)
        # torchaudio.load returns (channels, frames). squeeze() alone would
        # leave a 2-D tensor for stereo files; average channels into one track.
        waveform = waveform.mean(dim=0)
        target_sample_rate = 16000
        if sample_rate != target_sample_rate:
            resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
            waveform = resampler(waveform)
        input_features = self.whisper_processor(
            waveform.numpy(), sampling_rate=target_sample_rate, return_tensors="pt"
        ).input_features
        generated_ids = self.whisper_model.generate(input_features, forced_decoder_ids=self.forced_decoder_ids)
        transcribed_text = self.whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        # Strip any remaining tag-like "<...>" fragments from the decoded text.
        return re.sub(r"<[^>]*>", "", transcribed_text).strip()

class SpeechT5Manager:
    """Synthesize speech from text with SpeechT5 and a HiFi-GAN vocoder."""

    def __init__(self, speaker_index=7000):
        """
        Args:
            speaker_index: Row of the ``Matthijs/cmu-arctic-xvectors`` validation
                split whose x-vector is used as the speaker embedding.
        """
        self.SpeechT5_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
        self.SpeechT5_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
        self.vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
        # The speecht5_vc model plays no part in text_to_speech. It is loaded
        # only if something actually reads ``speaker_embedding_model``.
        self._speaker_embedding_model = None
        embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
        # unsqueeze(0) adds a leading batch dimension for generate_speech.
        self.pretrained_speaker_embeddings = torch.tensor(embeddings_dataset[speaker_index]["xvector"]).unsqueeze(0)

    @property
    def speaker_embedding_model(self):
        """The ``microsoft/speecht5_vc`` model, loaded lazily on first access."""
        if self._speaker_embedding_model is None:
            self._speaker_embedding_model = AutoModel.from_pretrained("microsoft/speecht5_vc")
        return self._speaker_embedding_model

    @speaker_embedding_model.setter
    def speaker_embedding_model(self, model):
        self._speaker_embedding_model = model

    def text_to_speech(self, text, output_file="output_speechT5.wav"):
        """Synthesize ``text`` and write the audio to ``output_file`` at 16000 Hz.

        Returns:
            ``output_file``, the path the audio was written to.
        """
        # NOTE(review): the default path is shared by every call, so concurrent
        # requests overwrite each other's audio. Consider a per-call temp file.
        inputs = self.SpeechT5_processor(text=[text], return_tensors="pt")
        speech = self.SpeechT5_model.generate_speech(inputs["input_ids"], self.pretrained_speaker_embeddings, vocoder=self.vocoder)
        sf.write(output_file, speech.numpy(), 16000)
        return output_file

# --- Gradio Interface ---
def asr_to_text(audio_file):
    """Gradio handler: turn the recorded/uploaded audio file into text."""
    return whisper_manager.transcribe_speech(audio_file)

def process_with_llm_and_tts(transcribed_text):
    """Gradio handler: answer the question with RAG, then voice the answer.

    Returns:
        A ``(response_text, audio_path)`` pair for the LLM Response textbox
        and the TTS Audio component.

    Raises:
        gr.Error: if the question box is empty. The UI then shows a message
            instead of running a slow LLM call on no input.
    """
    if not transcribed_text or not transcribed_text.strip():
        raise gr.Error("Please record or type a question first.")
    response_text = llm_manager.main_rag_pipeline(pdf_processor, transcribed_text, vectorstore_manager, vectorstore_file)
    audio_output = speech_manager.text_to_speech(response_text)
    return response_text, audio_output

# Instantiate Managers.
# The LLM, Whisper and SpeechT5 managers load their models in their constructors.
vectorstore_file = "./vectorstore_faiss"
pdf_processor = PDFProcessor(pdf_path='./files/LawsoftheGame2024_25.pdf')
vectorstore_manager = FAISSManager()
llm_manager = LLMManager(model_path="./files/mistral-7b-instruct-v0.2.Q2_K.gguf")
whisper_manager = WhisperManager()
speech_manager = SpeechT5Manager()

# Define Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>RAG Powered Voice Assistant</h1>")
    gr.Markdown("<h1 style='text-align: center;'>Ask me anything about the rules of Football!</h1>")

    # Step 1: Audio input and ASR output
    with gr.Row():
        audio_input = gr.Audio(type="filepath", label="Speak your question")
        asr_output = gr.Textbox(label="ASR Output (Edit if necessary)", interactive=True)

    # Button to process audio (ASR)
    asr_button = gr.Button("1 - Transform Voice to Text")

    # Step 2: LLM Response and TTS output
    with gr.Row():
        llm_response = gr.Textbox(label="LLM Response")
        tts_audio_output = gr.Audio(label="TTS Audio")

    # Button to process text with LLM
    llm_button = gr.Button("2 - Submit Question")

    # Step 1 handler: transcribe the recorded file into the editable textbox.
    asr_button.click(fn=asr_to_text, inputs=audio_input, outputs=asr_output)

    # Step 2 handler: the (possibly hand-edited) textbox text goes through RAG + TTS.
    llm_button.click(fn=process_with_llm_and_tts, inputs=asr_output, outputs=[llm_response, tts_audio_output])

    # Disclaimer
    gr.Markdown(
        "<p style='text-align: center; color: gray;'>This application runs on a machine with limited (but awesome) resources, so LLM completion may take up to 2 minutes.</p>"
    )
    gr.Markdown(
        "<p style='text-align: center; color: gray;'>Disclaimer: This application was developed solely for educational purposes to demonstrate AI capabilities and should not be used as a source of information or for any other purpose.</p>"
    )

if __name__ == "__main__":
    demo.launch(debug=True)