add new files
Browse files- Dockerfile +14 -0
- app2.py +231 -0
- build_db.py +96 -0
- downld_models_local.py +11 -0
- faiss_MH_c2000_o100/index.faiss +3 -0
- faiss_MH_c2000_o100/index.pkl +3 -0
- stt.py +46 -0
Dockerfile
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04
|
2 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
3 |
+
# Install necessary dependencies
|
4 |
+
RUN apt-get update && apt install -y python3-pip ffmpeg
|
5 |
+
|
6 |
+
# Set the working directory
|
7 |
+
WORKDIR /app
|
8 |
+
|
9 |
+
# Copy the app code and requirements filed
|
10 |
+
COPY stt.py /app
|
11 |
+
|
12 |
+
# Install the app dependencies
|
13 |
+
#RUN pip3 install --no-cache-dir -r requirements.txt
|
14 |
+
run pip3 install faster-whisper==0.6.0 flask==2.3.2
|
app2.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llama_cpp import Llama
|
2 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
3 |
+
from langchain.vectorstores import FAISS, Chroma
|
4 |
+
from faster_whisper import WhisperModel
|
5 |
+
import os
|
6 |
+
import gradio as gr
|
7 |
+
import torch
|
8 |
+
import base64
|
9 |
+
import json
|
10 |
+
import chromadb
|
11 |
+
import requests
|
12 |
+
import gc, torch
|
13 |
+
|
14 |
+
GPU = False if torch.cuda.device_count()==0 else True
|
15 |
+
n_threads = os.cpu_count()//2
|
16 |
+
global llm
|
17 |
+
|
18 |
+
def load_llm(model_name):
|
19 |
+
try:
|
20 |
+
del llm
|
21 |
+
except:
|
22 |
+
pass
|
23 |
+
torch.cuda.empty_cache()
|
24 |
+
gc.collect()
|
25 |
+
llm = Llama(model_path=model_name,
|
26 |
+
n_threads=11, n_gpu_layers=80, n_ctx=3000)
|
27 |
+
return llm
|
28 |
+
|
29 |
+
def load_faiss_db():
|
30 |
+
new_db = FAISS.load_local("faiss_MH_c2000_o100", hf_embs)
|
31 |
+
return new_db
|
32 |
+
|
33 |
+
def load_chroma_db():
|
34 |
+
ABS_PATH = os.getcwd()#os.path.dirname(os.path.abspath(__file__))
|
35 |
+
DB_DIR = os.path.join(ABS_PATH, "chroma_MH_c1000_o0")
|
36 |
+
print("DB_DIR", DB_DIR)
|
37 |
+
client_settings = chromadb.config.Settings(
|
38 |
+
chroma_db_impl="duckdb+parquet",
|
39 |
+
persist_directory=DB_DIR,
|
40 |
+
anonymized_telemetry=False
|
41 |
+
)
|
42 |
+
vectorstore = Chroma(
|
43 |
+
collection_name="langchain_store",
|
44 |
+
embedding_function=hf_embs,
|
45 |
+
client_settings=client_settings,
|
46 |
+
persist_directory=DB_DIR,
|
47 |
+
)
|
48 |
+
return vectorstore
|
49 |
+
|
50 |
+
def init_prompt_tempalate(context, question):
|
51 |
+
prompt_template = f"""<s>[INST]
|
52 |
+
As a health insurance assistant, use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
53 |
+
{context}
|
54 |
+
Question: {question}
|
55 |
+
Concise answer in French:
|
56 |
+
[/INST]"""
|
57 |
+
prompt_template = f"""As a health insurance assistant, use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
58 |
+
{context}
|
59 |
+
Question: {question}
|
60 |
+
Concise answer in French:"""
|
61 |
+
prompt_template = f"""Answer the question based only on the following context:
|
62 |
+
{context}
|
63 |
+
|
64 |
+
Question: {question}
|
65 |
+
|
66 |
+
Answer in the following language: French
|
67 |
+
"""
|
68 |
+
|
69 |
+
prompt_template = f"""<|system|>
|
70 |
+
Answer the question based only on the following context:
|
71 |
+
{context}</s>
|
72 |
+
<|user|>
|
73 |
+
{question}</s>
|
74 |
+
<|assistant|>
|
75 |
+
"""
|
76 |
+
return prompt_template
|
77 |
+
|
78 |
+
def wav_to_base64(file_path):
|
79 |
+
base64_data = base64.b64encode(open(file_path, "rb").read()).decode("utf-8")
|
80 |
+
return base64_data
|
81 |
+
|
82 |
+
def search_llm(question, max_tokens=10, temp=0, k_chunks=1, top_k=40,
|
83 |
+
top_p=0.95):
|
84 |
+
results = {}
|
85 |
+
context = ""
|
86 |
+
|
87 |
+
new_db = new_db_faiss
|
88 |
+
# if db_type=="faiss":
|
89 |
+
# new_db = new_db_faiss
|
90 |
+
# else:
|
91 |
+
# new_db = new_db_chroma
|
92 |
+
docs = new_db.similarity_search_with_score(question,
|
93 |
+
k=int(k_chunks))
|
94 |
+
contexts = [el[0].page_content for el in docs]
|
95 |
+
scores = [el[1] for el in docs]
|
96 |
+
context = "\n".join(contexts)
|
97 |
+
score = sum(scores) / len(scores)
|
98 |
+
score = round(score, 3)
|
99 |
+
url = docs[0][0].metadata
|
100 |
+
|
101 |
+
prompt_template = init_prompt_tempalate(context, question)
|
102 |
+
|
103 |
+
output = llm(prompt_template,
|
104 |
+
max_tokens=int(max_tokens),
|
105 |
+
stop=["Question:", "\n"],
|
106 |
+
echo=True,
|
107 |
+
temperature=temp,
|
108 |
+
top_k=int(top_k),
|
109 |
+
top_p=top_p)
|
110 |
+
# first_reponse = output["choices"][0]["text"].split("answer in French:")[-1].strip()
|
111 |
+
first_reponse = output["choices"][0]["text"].split("<|assistant|>")[-1].strip()
|
112 |
+
results["Response"] = first_reponse
|
113 |
+
# results["prompt_template"] = prompt_template
|
114 |
+
results["context"] = context
|
115 |
+
results["source"] = url
|
116 |
+
results["context_score"] = score
|
117 |
+
return results["Response"], results["source"], results["context"], results["context_score"]
|
118 |
+
|
119 |
+
def stt(path):
|
120 |
+
injson = {}
|
121 |
+
injson["data"] = wav_to_base64(path)
|
122 |
+
results = requests.post(url="http://0.0.0.0:5566/api",
|
123 |
+
json=injson,
|
124 |
+
verify=False)
|
125 |
+
transcription = results.json()["transcription"]
|
126 |
+
query = transcription
|
127 |
+
query = transcription if "?" in transcription else transcription + "?"
|
128 |
+
return query
|
129 |
+
|
130 |
+
def STT_LLM(path, max_tokens, temp, k_chunks, top_k, top_p, db_type):
|
131 |
+
"""
|
132 |
+
"""
|
133 |
+
query = stt(path)
|
134 |
+
Response, url, context, contextScore = search_llm(query, max_tokens, temp, k_chunks, top_k, top_p)
|
135 |
+
return query, Response, url["source"], context, str(contextScore)
|
136 |
+
|
137 |
+
def LLM(content, max_tokens, temp, k_chunks, top_k, top_p, db_type):
|
138 |
+
Response, url, context, contextScore = search_llm(content, max_tokens, temp, k_chunks, top_k,
|
139 |
+
top_p)
|
140 |
+
url = url["source"]
|
141 |
+
return Response, url, context, str(contextScore)
|
142 |
+
|
143 |
+
|
144 |
+
embs_name = "sentence-transformers/all-mpnet-base-v2"
|
145 |
+
hf_embs = HuggingFaceEmbeddings(model_name=embs_name,
|
146 |
+
model_kwargs={"device": "cuda"})
|
147 |
+
new_db_chroma = load_faiss_db()
|
148 |
+
new_db_faiss = load_chroma_db()
|
149 |
+
### Load models
|
150 |
+
#stt
|
151 |
+
wspr = WhisperModel("small", device="cuda" if GPU else "cpu", compute_type="int8")
|
152 |
+
#llm
|
153 |
+
model_name = "mistral-7b-instruct-v0.1.Q4_K_M.gguf"
|
154 |
+
model_name = "zephyr-7b-beta.Q4_K_M.gguf"
|
155 |
+
|
156 |
+
llm = load_llm(model_name)
|
157 |
+
|
158 |
+
demo = gr.Blocks()
|
159 |
+
with demo:
|
160 |
+
with gr.Tab(model_name):
|
161 |
+
with gr.Row():
|
162 |
+
with gr.Column():
|
163 |
+
with gr.Box():
|
164 |
+
content = gr.Text(label="Posez votre question")
|
165 |
+
audio_path = gr.Audio(source="microphone",
|
166 |
+
format="mp3",
|
167 |
+
type="filepath",
|
168 |
+
label="Posez votre question (Whisper-small)")
|
169 |
+
with gr.Row():
|
170 |
+
max_tokens = gr.Number(label="Max_tokens", value=100, maximum=1000, minimum=1)
|
171 |
+
temp = gr.Number(label="Temperature", value=0.1, maximum=1.0, minimum=0.0, step=0.1)
|
172 |
+
k_chunks = gr.Number(label="k_chunks", value=2, maximum=5, minimum=1)
|
173 |
+
top_k = gr.Number(label="top_k", value=100, maximum=1000, minimum=1)
|
174 |
+
top_p = gr.Number(label="top_p", value=0.95, maximum=1.0, minimum=0.0)
|
175 |
+
# with gr.Box():
|
176 |
+
# db_type = gr.Dropdown(choices=["faiss", "chromadb"], label="Vector DB", value="faiss")
|
177 |
+
# # llm_name = gr.Dropdown(choices=["vicuna-7b-v1.3.ggmlv3.q4_1.bin",
|
178 |
+
# # "vicuna-7b-v1.3.ggmlv3.q5_1.bin"],
|
179 |
+
# # label="llm", value="vicuna-7b-v1.3.ggmlv3.q4_1.bin")
|
180 |
+
# b3 = gr.Button("update model")
|
181 |
+
# # b3.click(load_llm, inputs=llm_name, outputs=None)
|
182 |
+
with gr.Column():
|
183 |
+
# transcription = gr.Text(label="transcription")
|
184 |
+
Response = gr.Text(label="Réponse")
|
185 |
+
url = gr.Text(label="url source")
|
186 |
+
context = gr.Text(label="contexte (chunks)")
|
187 |
+
contextScore = gr.Text(label="contexte score (L2 distance)")
|
188 |
+
with gr.Box():
|
189 |
+
b2 = gr.Button("reconnaissace vocale")
|
190 |
+
b1 = gr.Button("search llm")
|
191 |
+
b1.click(LLM, inputs=[content, max_tokens, temp, k_chunks, top_k, top_p], #db_type
|
192 |
+
outputs=[Response, url, context, contextScore])
|
193 |
+
b2.click(stt, inputs=audio_path, outputs=content)
|
194 |
+
|
195 |
+
# with gr.Tab("gptq"):
|
196 |
+
# with gr.Row():
|
197 |
+
# with gr.Column():
|
198 |
+
# with gr.Box():
|
199 |
+
# content = gr.Text(label="Posez votre question")
|
200 |
+
# audio_path = gr.Audio(source="microphone",
|
201 |
+
# format="mp3",
|
202 |
+
# type="filepath",
|
203 |
+
# label="Posez votre question (Whisper-small)")
|
204 |
+
# with gr.Row():
|
205 |
+
# max_tokens = gr.Number(label="Max_tokens", value=100, maximum=1000, minimum=1)
|
206 |
+
# temp = gr.Number(label="Temperature", value=0.1, maximum=1.0, minimum=0.0)
|
207 |
+
# k_chunks = gr.Number(label="k_chunks", value=2, maximum=3, minimum=1)
|
208 |
+
# top_k = gr.Number(label="top_k", value=100, maximum=1000, minimum=1)
|
209 |
+
# top_p = gr.Number(label="top_p", value=0.95, maximum=1.0, minimum=0.0)
|
210 |
+
# with gr.Box():
|
211 |
+
# db_type = gr.Dropdown(choices=["faiss", "chromadb"], label="Vector DB", value="faiss")
|
212 |
+
# llm_name = gr.Dropdown(choices=["llama-2-7b.ggmlv3.q4_1.bin",
|
213 |
+
# "vicuna-7b-v1.3.ggmlv3.q4_1.bin"],
|
214 |
+
# label="llm", value="llama-2-7b.ggmlv3.q4_1.bin")
|
215 |
+
# b3 = gr.Button("update model")
|
216 |
+
# # b3.click(stt, inputs=llm_name, outputs=None)
|
217 |
+
# with gr.Column():
|
218 |
+
# # transcription = gr.Text(label="transcription")
|
219 |
+
# Response = gr.Text(label="Réponse")
|
220 |
+
# url = gr.Text(label="url source")
|
221 |
+
# context = gr.Text(label="contexte (chunks)")
|
222 |
+
# contextScore = gr.Text(label="contexte score (L2 distance)")
|
223 |
+
# with gr.Box():
|
224 |
+
# b2 = gr.Button("reconnaissace vocale")
|
225 |
+
# b1 = gr.Button("search llm")
|
226 |
+
# b1.click(LLM, inputs=[content, max_tokens, temp, k_chunks, top_k, top_p, db_type],
|
227 |
+
# outputs=[Response, url, context, contextScore])
|
228 |
+
# b2.click(stt, inputs=audio_path, outputs=content)
|
229 |
+
|
230 |
+
if __name__ == "__main__":
|
231 |
+
demo.launch(share=True, enable_queue=True, show_api=True)
|
build_db.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import xml.etree.ElementTree as ET
|
3 |
+
from langchain.document_loaders import UnstructuredURLLoader
|
4 |
+
from langchain.vectorstores import Chroma, FAISS
|
5 |
+
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
|
6 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
7 |
+
import chromadb
|
8 |
+
import logging
|
9 |
+
import os
|
10 |
+
|
11 |
+
# works only with sitemap.xml url ONLY
|
12 |
+
|
13 |
+
# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
14 |
+
|
15 |
+
#Params to edit
|
16 |
+
global Chunk_size, Chunk_overlap
|
17 |
+
Chunk_size = 2000
|
18 |
+
Chunk_overlap = 100
|
19 |
+
Sitemap_url = "https://www.malakoffhumanis.com/sitemap.xml"
|
20 |
+
|
21 |
+
def langchain_web_scraper(sitemap_url, chunk_size=1000, chunk_overlap=100):
|
22 |
+
"""
|
23 |
+
"""
|
24 |
+
# Fetch the sitemap.xml file
|
25 |
+
response = requests.get(sitemap_url)
|
26 |
+
tree = ET.fromstring(response.content)
|
27 |
+
# Extract URLs from sitemap
|
28 |
+
urls = []
|
29 |
+
for url in tree.findall("{http://www.sitemaps.org/schemas/sitemap/0.9}url"):
|
30 |
+
loc = url.find("{http://www.sitemaps.org/schemas/sitemap/0.9}loc").text
|
31 |
+
# if "" in loc:
|
32 |
+
urls.append(loc)
|
33 |
+
print("len(urls)", len(urls))
|
34 |
+
# scraping
|
35 |
+
loaders = UnstructuredURLLoader(urls=urls)
|
36 |
+
data = loaders.load()
|
37 |
+
|
38 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
|
39 |
+
chunk_overlap=chunk_overlap
|
40 |
+
# separators=[" ", "\n"]
|
41 |
+
)
|
42 |
+
|
43 |
+
documents = text_splitter.split_documents(data)
|
44 |
+
return documents
|
45 |
+
|
46 |
+
def store_vdb_faiss(documents=[], hf_embs=None, save_path="faiss_MHCOM"):
|
47 |
+
"""
|
48 |
+
"""
|
49 |
+
db = FAISS.from_documents(documents, hf_embs)
|
50 |
+
db.save_local(save_path)
|
51 |
+
|
52 |
+
def store_vdb_chroma(documents=[], hf_embs=None, save_path="chroma_MHCOM"):
|
53 |
+
"""
|
54 |
+
"""
|
55 |
+
ABS_PATH = os.path.dirname(os.path.abspath(__file__))
|
56 |
+
DB_DIR = os.path.join(ABS_PATH, save_path)
|
57 |
+
|
58 |
+
client_settings = chromadb.config.Settings(
|
59 |
+
chroma_db_impl="duckdb+parquet",
|
60 |
+
persist_directory=DB_DIR,
|
61 |
+
anonymized_telemetry=False
|
62 |
+
)
|
63 |
+
vectorstore = Chroma(
|
64 |
+
collection_name="langchain_store",
|
65 |
+
embedding_function=hf_embs,
|
66 |
+
client_settings=client_settings,
|
67 |
+
persist_directory=DB_DIR,
|
68 |
+
)
|
69 |
+
vectorstore.add_documents(documents=documents, embedding=hf_embs)
|
70 |
+
vectorstore.persist()
|
71 |
+
|
72 |
+
|
73 |
+
def main():
|
74 |
+
print("scrapping website")
|
75 |
+
documents = langchain_web_scraper(sitemap_url=Sitemap_url,
|
76 |
+
chunk_size=Chunk_size,
|
77 |
+
chunk_overlap=Chunk_overlap)
|
78 |
+
#store in vector DB FAISS
|
79 |
+
print("load embeddings")
|
80 |
+
embeddings_model_name = "sentence-transformers/all-mpnet-base-v2"
|
81 |
+
hf_embs = HuggingFaceEmbeddings(model_name=embeddings_model_name,
|
82 |
+
model_kwargs={"device": "cuda"})
|
83 |
+
|
84 |
+
print("storing chunks in vector db")
|
85 |
+
store_vdb_faiss(documents=documents,
|
86 |
+
hf_embs=hf_embs,
|
87 |
+
save_path="faiss_MH_c{}_o{}".format(str(Chunk_size),
|
88 |
+
str(Chunk_overlap)))
|
89 |
+
|
90 |
+
# store_vdb_chroma(documents=documents,
|
91 |
+
# hf_embs=hf_embs,
|
92 |
+
# save_path="chroma_MH_c{}_o{}".format(str(Chunk_size),
|
93 |
+
# str(Chunk_overlap)))
|
94 |
+
|
95 |
+
if __name__ == '__main__':
|
96 |
+
main()
|
downld_models_local.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import hf_hub_download
|
2 |
+
|
3 |
+
repo_id = "TheBloke/zephyr-7B-beta-GGUF"
|
4 |
+
model_name = "zephyr-7b-beta.Q4_K_M.gguf"
|
5 |
+
|
6 |
+
local_dir="./"#"/mnt/ssd1/MH/AMINE/NLPBANK/localchatbot"
|
7 |
+
|
8 |
+
hf_hub_download(repo_id=repo_id,
|
9 |
+
filename=model_name,
|
10 |
+
local_dir=local_dir,
|
11 |
+
local_dir_use_symlinks=False)
|
faiss_MH_c2000_o100/index.faiss
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8560a800adb7152772c1dd3041e45e0c6842a89b3f8d4c8a5629596de225d519
|
3 |
+
size 11461677
|
faiss_MH_c2000_o100/index.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:93516c28a21f6317fe230a2f92a21e6de8f6fb628a65297cd8fca8c569c95501
|
3 |
+
size 6453479
|
stt.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flask import Flask, request
|
2 |
+
import json
|
3 |
+
import base64
|
4 |
+
from faster_whisper import WhisperModel
|
5 |
+
import tempfile
|
6 |
+
import os
|
7 |
+
|
8 |
+
def base64_to_wav(base64_data, save_path):
|
9 |
+
wav_data = base64.b64decode(base64_data)
|
10 |
+
with open(save_path, 'wb') as file:
|
11 |
+
file.write(wav_data)
|
12 |
+
|
13 |
+
app = Flask(__name__)
|
14 |
+
|
15 |
+
GPU = True
|
16 |
+
wspr = WhisperModel("small", device="cuda" if GPU else "cpu", compute_type="int8")
|
17 |
+
|
18 |
+
@app.route('/api', methods=['GET' ,'POST'])
|
19 |
+
def STT():
|
20 |
+
if request.method == 'POST':
|
21 |
+
result = {}
|
22 |
+
audio_data = request.data
|
23 |
+
|
24 |
+
# Create a unique filename in the temporary directory
|
25 |
+
temp_dir = tempfile.gettempdir()
|
26 |
+
temp_file = tempfile.NamedTemporaryFile(suffix='.mp3', dir=temp_dir, delete=False)
|
27 |
+
save_path = temp_file.name
|
28 |
+
# save_path = "temp.wav"
|
29 |
+
|
30 |
+
# save_path = "temp.wav"
|
31 |
+
base64_to_wav(audio_data, save_path)
|
32 |
+
|
33 |
+
segments, info = wspr.transcribe(save_path, beam_size=5, language="fr")
|
34 |
+
texts = [el.text.strip() for el in segments]
|
35 |
+
transcription = texts[0]
|
36 |
+
query = transcription
|
37 |
+
result["transcription"] = query
|
38 |
+
if os.path.exists(save_path):
|
39 |
+
os.remove(save_path)
|
40 |
+
return result
|
41 |
+
elif request.method == 'GET':
|
42 |
+
# Add your test logic here
|
43 |
+
return "API is working correctly!"
|
44 |
+
|
45 |
+
if __name__ == '__main__':
|
46 |
+
app.run(host='0.0.0.0', port=5566, debug=True)
|