amtam0 committed
Commit eaffd42
1 Parent(s): ccd397c

add new files

Dockerfile ADDED
@@ -0,0 +1,14 @@
+ FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04
+ ENV DEBIAN_FRONTEND=noninteractive
+ # Install necessary dependencies
+ RUN apt-get update && apt-get install -y python3-pip ffmpeg
+
+ # Set the working directory
+ WORKDIR /app
+
+ # Copy the app code and requirements file
+ COPY stt.py /app
+
+ # Install the app dependencies
+ #RUN pip3 install --no-cache-dir -r requirements.txt
+ RUN pip3 install faster-whisper==0.6.0 flask==2.3.2
app2.py ADDED
@@ -0,0 +1,231 @@
+ from llama_cpp import Llama
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import FAISS, Chroma
+ from faster_whisper import WhisperModel
+ import os
+ import gradio as gr
+ import torch
+ import base64
+ import json
+ import chromadb
+ import requests
+ import gc
+
+ GPU = torch.cuda.device_count() > 0
+ n_threads = os.cpu_count() // 2
+
+ def load_llm(model_name):
+     """Load (or reload) the llama.cpp model, releasing the previous one first."""
+     global llm
+     try:
+         del llm
+     except NameError:
+         pass
+     torch.cuda.empty_cache()
+     gc.collect()
+     llm = Llama(model_path=model_name,
+                 n_threads=n_threads, n_gpu_layers=80, n_ctx=3000)
+     return llm
+
+ def load_faiss_db():
+     new_db = FAISS.load_local("faiss_MH_c2000_o100", hf_embs)
+     return new_db
+
+ def load_chroma_db():
+     ABS_PATH = os.getcwd()  # os.path.dirname(os.path.abspath(__file__))
+     DB_DIR = os.path.join(ABS_PATH, "chroma_MH_c1000_o0")
+     print("DB_DIR", DB_DIR)
+     client_settings = chromadb.config.Settings(
+         chroma_db_impl="duckdb+parquet",
+         persist_directory=DB_DIR,
+         anonymized_telemetry=False
+     )
+     vectorstore = Chroma(
+         collection_name="langchain_store",
+         embedding_function=hf_embs,
+         client_settings=client_settings,
+         persist_directory=DB_DIR,
+     )
+     return vectorstore
+
+ def init_prompt_template(context, question):
+     # Earlier template variants are kept for reference; only the last
+     # (zephyr-style) assignment is returned.
+     prompt_template = f"""<s>[INST]
+ As a health insurance assistant, use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
+ {context}
+ Question: {question}
+ Concise answer in French:
+ [/INST]"""
+     prompt_template = f"""As a health insurance assistant, use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
+ {context}
+ Question: {question}
+ Concise answer in French:"""
+     prompt_template = f"""Answer the question based only on the following context:
+ {context}
+
+ Question: {question}
+
+ Answer in the following language: French
+ """
+     prompt_template = f"""<|system|>
+ Answer the question based only on the following context:
+ {context}</s>
+ <|user|>
+ {question}</s>
+ <|assistant|>
+ """
+     return prompt_template
+
+ def wav_to_base64(file_path):
+     base64_data = base64.b64encode(open(file_path, "rb").read()).decode("utf-8")
+     return base64_data
+
+ def search_llm(question, max_tokens=10, temp=0, k_chunks=1, top_k=40,
+                top_p=0.95):
+     results = {}
+     context = ""
+
+     new_db = new_db_faiss
+     # if db_type == "faiss":
+     #     new_db = new_db_faiss
+     # else:
+     #     new_db = new_db_chroma
+     docs = new_db.similarity_search_with_score(question,
+                                                k=int(k_chunks))
+     contexts = [el[0].page_content for el in docs]
+     scores = [el[1] for el in docs]
+     context = "\n".join(contexts)
+     score = sum(scores) / len(scores)
+     score = round(score, 3)
+     url = docs[0][0].metadata
+
+     prompt_template = init_prompt_template(context, question)
+
+     output = llm(prompt_template,
+                  max_tokens=int(max_tokens),
+                  stop=["Question:", "\n"],
+                  echo=True,
+                  temperature=temp,
+                  top_k=int(top_k),
+                  top_p=top_p)
+     # first_response = output["choices"][0]["text"].split("answer in French:")[-1].strip()
+     first_response = output["choices"][0]["text"].split("<|assistant|>")[-1].strip()
+     results["Response"] = first_response
+     # results["prompt_template"] = prompt_template
+     results["context"] = context
+     results["source"] = url
+     results["context_score"] = score
+     return results["Response"], results["source"], results["context"], results["context_score"]
+
+ def stt(path):
+     injson = {}
+     injson["data"] = wav_to_base64(path)
+     results = requests.post(url="http://0.0.0.0:5566/api",
+                             json=injson,
+                             verify=False)
+     transcription = results.json()["transcription"]
+     query = transcription if "?" in transcription else transcription + "?"
+     return query
+
+ def STT_LLM(path, max_tokens, temp, k_chunks, top_k, top_p, db_type=None):
+     """Transcribe the audio file, then run the retrieval + LLM pipeline on the transcription."""
+     query = stt(path)
+     Response, url, context, contextScore = search_llm(query, max_tokens, temp, k_chunks, top_k, top_p)
+     return query, Response, url["source"], context, str(contextScore)
+
+ def LLM(content, max_tokens, temp, k_chunks, top_k, top_p, db_type=None):
+     Response, url, context, contextScore = search_llm(content, max_tokens, temp, k_chunks, top_k,
+                                                       top_p)
+     url = url["source"]
+     return Response, url, context, str(contextScore)
+
+
+ embs_name = "sentence-transformers/all-mpnet-base-v2"
+ hf_embs = HuggingFaceEmbeddings(model_name=embs_name,
+                                 model_kwargs={"device": "cuda"})
+ new_db_faiss = load_faiss_db()
+ new_db_chroma = load_chroma_db()
+ ### Load models
+ # stt (loaded for completeness; transcription actually goes through the stt.py service)
+ wspr = WhisperModel("small", device="cuda" if GPU else "cpu", compute_type="int8")
+ # llm
+ # model_name = "mistral-7b-instruct-v0.1.Q4_K_M.gguf"
+ model_name = "zephyr-7b-beta.Q4_K_M.gguf"
+
+ llm = load_llm(model_name)
+
+ demo = gr.Blocks()
+ with demo:
+     with gr.Tab(model_name):
+         with gr.Row():
+             with gr.Column():
+                 with gr.Box():
+                     content = gr.Text(label="Posez votre question")
+                     audio_path = gr.Audio(source="microphone",
+                                           format="mp3",
+                                           type="filepath",
+                                           label="Posez votre question (Whisper-small)")
+                 with gr.Row():
+                     max_tokens = gr.Number(label="Max_tokens", value=100, maximum=1000, minimum=1)
+                     temp = gr.Number(label="Temperature", value=0.1, maximum=1.0, minimum=0.0, step=0.1)
+                     k_chunks = gr.Number(label="k_chunks", value=2, maximum=5, minimum=1)
+                     top_k = gr.Number(label="top_k", value=100, maximum=1000, minimum=1)
+                     top_p = gr.Number(label="top_p", value=0.95, maximum=1.0, minimum=0.0)
+                 # with gr.Box():
+                 #     db_type = gr.Dropdown(choices=["faiss", "chromadb"], label="Vector DB", value="faiss")
+                 #     # llm_name = gr.Dropdown(choices=["vicuna-7b-v1.3.ggmlv3.q4_1.bin",
+                 #     #                                 "vicuna-7b-v1.3.ggmlv3.q5_1.bin"],
+                 #     #                        label="llm", value="vicuna-7b-v1.3.ggmlv3.q4_1.bin")
+                 #     b3 = gr.Button("update model")
+                 #     # b3.click(load_llm, inputs=llm_name, outputs=None)
+             with gr.Column():
+                 # transcription = gr.Text(label="transcription")
+                 Response = gr.Text(label="Réponse")
+                 url = gr.Text(label="url source")
+                 context = gr.Text(label="contexte (chunks)")
+                 contextScore = gr.Text(label="contexte score (L2 distance)")
+         with gr.Box():
+             b2 = gr.Button("reconnaissance vocale")
+             b1 = gr.Button("search llm")
+         b1.click(LLM, inputs=[content, max_tokens, temp, k_chunks, top_k, top_p],  # db_type
+                  outputs=[Response, url, context, contextScore])
+         b2.click(stt, inputs=audio_path, outputs=content)
+
+     # with gr.Tab("gptq"):
+     #     with gr.Row():
+     #         with gr.Column():
+     #             with gr.Box():
+     #                 content = gr.Text(label="Posez votre question")
+     #                 audio_path = gr.Audio(source="microphone",
+     #                                       format="mp3",
+     #                                       type="filepath",
+     #                                       label="Posez votre question (Whisper-small)")
+     #             with gr.Row():
+     #                 max_tokens = gr.Number(label="Max_tokens", value=100, maximum=1000, minimum=1)
+     #                 temp = gr.Number(label="Temperature", value=0.1, maximum=1.0, minimum=0.0)
+     #                 k_chunks = gr.Number(label="k_chunks", value=2, maximum=3, minimum=1)
+     #                 top_k = gr.Number(label="top_k", value=100, maximum=1000, minimum=1)
+     #                 top_p = gr.Number(label="top_p", value=0.95, maximum=1.0, minimum=0.0)
+     #             with gr.Box():
+     #                 db_type = gr.Dropdown(choices=["faiss", "chromadb"], label="Vector DB", value="faiss")
+     #                 llm_name = gr.Dropdown(choices=["llama-2-7b.ggmlv3.q4_1.bin",
+     #                                                 "vicuna-7b-v1.3.ggmlv3.q4_1.bin"],
+     #                                        label="llm", value="llama-2-7b.ggmlv3.q4_1.bin")
+     #                 b3 = gr.Button("update model")
+     #                 # b3.click(stt, inputs=llm_name, outputs=None)
+     #         with gr.Column():
+     #             # transcription = gr.Text(label="transcription")
+     #             Response = gr.Text(label="Réponse")
+     #             url = gr.Text(label="url source")
+     #             context = gr.Text(label="contexte (chunks)")
+     #             contextScore = gr.Text(label="contexte score (L2 distance)")
+     #         with gr.Box():
+     #             b2 = gr.Button("reconnaissance vocale")
+     #             b1 = gr.Button("search llm")
+     #         b1.click(LLM, inputs=[content, max_tokens, temp, k_chunks, top_k, top_p, db_type],
+     #                  outputs=[Response, url, context, contextScore])
+     #         b2.click(stt, inputs=audio_path, outputs=content)
+
+ if __name__ == "__main__":
+     demo.launch(share=True, enable_queue=True, show_api=True)
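
Note: the stt() helper in app2.py assumes the Flask service from stt.py is already running on port 5566 (as in the Docker image above). A minimal client sketch of that round trip, assuming a local file named test.wav exists (an illustrative name, not part of the repo):

import base64
import requests

# Hypothetical client: base64-encode a local wav file and POST it to the
# stt.py service, then print the returned transcription.
with open("test.wav", "rb") as f:
    payload = {"data": base64.b64encode(f.read()).decode("utf-8")}
resp = requests.post("http://0.0.0.0:5566/api", json=payload)
print(resp.json()["transcription"])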
build_db.py ADDED
@@ -0,0 +1,96 @@
+ import requests
+ import xml.etree.ElementTree as ET
+ from langchain.document_loaders import UnstructuredURLLoader
+ from langchain.vectorstores import Chroma, FAISS
+ from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
+ from langchain.embeddings import HuggingFaceEmbeddings
+ import chromadb
+ import logging
+ import os
+
+ # Works with a sitemap.xml URL only
+
+ # logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+ # Params to edit
+ Chunk_size = 2000
+ Chunk_overlap = 100
+ Sitemap_url = "https://www.malakoffhumanis.com/sitemap.xml"
+
+ def langchain_web_scraper(sitemap_url, chunk_size=1000, chunk_overlap=100):
+     """Fetch a sitemap, load every listed page and split the text into chunks."""
+     # Fetch the sitemap.xml file
+     response = requests.get(sitemap_url)
+     tree = ET.fromstring(response.content)
+     # Extract URLs from the sitemap
+     urls = []
+     for url in tree.findall("{http://www.sitemaps.org/schemas/sitemap/0.9}url"):
+         loc = url.find("{http://www.sitemaps.org/schemas/sitemap/0.9}loc").text
+         # if "" in loc:
+         urls.append(loc)
+     print("len(urls)", len(urls))
+     # scraping
+     loaders = UnstructuredURLLoader(urls=urls)
+     data = loaders.load()
+
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
+                                                    chunk_overlap=chunk_overlap
+                                                    # separators=[" ", "\n"]
+                                                    )
+
+     documents = text_splitter.split_documents(data)
+     return documents
+
+ def store_vdb_faiss(documents=[], hf_embs=None, save_path="faiss_MHCOM"):
+     """Embed the documents and persist them in a local FAISS index."""
+     db = FAISS.from_documents(documents, hf_embs)
+     db.save_local(save_path)
+
+ def store_vdb_chroma(documents=[], hf_embs=None, save_path="chroma_MHCOM"):
+     """Embed the documents and persist them in a local Chroma collection."""
+     ABS_PATH = os.path.dirname(os.path.abspath(__file__))
+     DB_DIR = os.path.join(ABS_PATH, save_path)
+
+     client_settings = chromadb.config.Settings(
+         chroma_db_impl="duckdb+parquet",
+         persist_directory=DB_DIR,
+         anonymized_telemetry=False
+     )
+     vectorstore = Chroma(
+         collection_name="langchain_store",
+         embedding_function=hf_embs,
+         client_settings=client_settings,
+         persist_directory=DB_DIR,
+     )
+     vectorstore.add_documents(documents=documents, embedding=hf_embs)
+     vectorstore.persist()
+
+
+ def main():
+     print("scraping website")
+     documents = langchain_web_scraper(sitemap_url=Sitemap_url,
+                                       chunk_size=Chunk_size,
+                                       chunk_overlap=Chunk_overlap)
+     # store in vector DB FAISS
+     print("load embeddings")
+     embeddings_model_name = "sentence-transformers/all-mpnet-base-v2"
+     hf_embs = HuggingFaceEmbeddings(model_name=embeddings_model_name,
+                                     model_kwargs={"device": "cuda"})
+
+     print("storing chunks in vector db")
+     store_vdb_faiss(documents=documents,
+                     hf_embs=hf_embs,
+                     save_path="faiss_MH_c{}_o{}".format(str(Chunk_size),
+                                                         str(Chunk_overlap)))
+
+     # store_vdb_chroma(documents=documents,
+     #                  hf_embs=hf_embs,
+     #                  save_path="chroma_MH_c{}_o{}".format(str(Chunk_size),
+     #                                                       str(Chunk_overlap)))
+
+ if __name__ == '__main__':
+     main()
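
Note: once build_db.py has persisted the index, it can be reloaded the same way app2.py does. A minimal sketch, assuming the default Chunk_size/Chunk_overlap above and the same embedding model; the query string is only an example:

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

# Reload the persisted index (name matches Chunk_size=2000, Chunk_overlap=100)
# and run a similarity search over the scraped pages.
hf_embs = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
db = FAISS.load_local("faiss_MH_c2000_o100", hf_embs)
for doc, score in db.similarity_search_with_score("example question about coverage", k=2):
    print(round(score, 3), doc.metadata.get("source"))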
downld_models_local.py ADDED
@@ -0,0 +1,11 @@
+ from huggingface_hub import hf_hub_download
+
+ repo_id = "TheBloke/zephyr-7B-beta-GGUF"
+ model_name = "zephyr-7b-beta.Q4_K_M.gguf"
+
+ local_dir = "./"  # "/mnt/ssd1/MH/AMINE/NLPBANK/localchatbot"
+
+ hf_hub_download(repo_id=repo_id,
+                 filename=model_name,
+                 local_dir=local_dir,
+                 local_dir_use_symlinks=False)
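
Note: the downloaded GGUF file is the one app2.py loads. A minimal sketch of using it directly with llama-cpp-python, mirroring the parameters of load_llm() in app2.py; the prompt text is only an example:

from llama_cpp import Llama

# Load the downloaded model from the current directory and run one completion.
llm = Llama(model_path="./zephyr-7b-beta.Q4_K_M.gguf", n_ctx=3000, n_gpu_layers=80)
out = llm("<|system|>\nYou are a helpful assistant.</s>\n<|user|>\nHello</s>\n<|assistant|>\n",
          max_tokens=64, temperature=0.1)
print(out["choices"][0]["text"])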
faiss_MH_c2000_o100/index.faiss ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8560a800adb7152772c1dd3041e45e0c6842a89b3f8d4c8a5629596de225d519
+ size 11461677
faiss_MH_c2000_o100/index.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:93516c28a21f6317fe230a2f92a21e6de8f6fb628a65297cd8fca8c569c95501
+ size 6453479
stt.py ADDED
@@ -0,0 +1,46 @@
+ from flask import Flask, request
+ import json
+ import base64
+ from faster_whisper import WhisperModel
+ import tempfile
+ import os
+
+ def base64_to_wav(base64_data, save_path):
+     wav_data = base64.b64decode(base64_data)
+     with open(save_path, 'wb') as file:
+         file.write(wav_data)
+
+ app = Flask(__name__)
+
+ GPU = True
+ wspr = WhisperModel("small", device="cuda" if GPU else "cpu", compute_type="int8")
+
+ @app.route('/api', methods=['GET', 'POST'])
+ def STT():
+     if request.method == 'POST':
+         result = {}
+         # The client (stt() in app2.py) sends {"data": "<base64 audio>"} as JSON
+         audio_data = request.get_json()["data"]
+
+         # Create a unique filename in the temporary directory
+         temp_dir = tempfile.gettempdir()
+         temp_file = tempfile.NamedTemporaryFile(suffix='.mp3', dir=temp_dir, delete=False)
+         save_path = temp_file.name
+         # save_path = "temp.wav"
+         base64_to_wav(audio_data, save_path)
+
+         segments, info = wspr.transcribe(save_path, beam_size=5, language="fr")
+         texts = [el.text.strip() for el in segments]
+         transcription = " ".join(texts)  # join all segments, not only the first
+         result["transcription"] = transcription
+         if os.path.exists(save_path):
+             os.remove(save_path)
+         return result
+     elif request.method == 'GET':
+         # Simple health check
+         return "API is working correctly!"
+
+ if __name__ == '__main__':
+     app.run(host='0.0.0.0', port=5566, debug=True)
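
Note: the GET branch gives a quick way to verify the service is up; a one-line check from Python, assuming the default host and port:

import requests

# Expect "API is working correctly!" if the Flask service is running on port 5566.
print(requests.get("http://0.0.0.0:5566/api").text)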