app file
app.py ADDED
@@ -0,0 +1,572 @@
# Ref https://github.com/ezzcodeezzlife/dalle2-in-python
# Ref https://towardsdatascience.com/speech-to-text-with-openais-whisper-53d5cea9005e
# Ref https://python.plainenglish.io/creating-an-awesome-web-app-with-python-and-streamlit-728fe100cf7
import logging
import logging.handlers
import queue
import threading
import time
import urllib.request
from collections import deque
from pathlib import Path
from typing import List
# import whisper
import av
import numpy as np
import pydub
import streamlit as st
from tqdm import tqdm
from streamlit_webrtc import WebRtcMode, webrtc_streamer

from dalle2 import Dalle2
from PIL import Image

HERE = Path(__file__).parent

logger = logging.getLogger(__name__)

# Initialize the OpenAI API with your API key
import openai
from openai import OpenAI
api_key = "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"
client = OpenAI(api_key=api_key)

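# A safer alternative (sketch; assumes an OPENAI_API_KEY environment variable is set)
# would be to read the key from the environment instead of hard-coding it:
#
#   import os
#   client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
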
prompt = """I am a doctor, I would like you to check my prescription:
medical history: Hypertension, Type 2 Diabetes, and Asthma.
symptoms: Persistent cough, fever, and fatigue.
My prescription: Lisinopril 10mg daily, Metformin 500mg twice daily, and Albuterol as needed for asthma attacks.
Drug contexts:
- Lisinopril: Ingredients: ACE inhibitor. Adverse effects: Dizziness, dry cough, elevated blood potassium levels.
- Metformin: Ingredients: Oral antihyperglycemic agent. Adverse effects: Stomach upset, diarrhea, low blood sugar.
- Albuterol: Ingredients: Bronchodilator. Adverse effects: Tremors, nervousness, increased heart rate.

Please answer the following questions in concise point form, taking into account the provided drug context:
- Possible interactions between prescribed drugs?
- Adverse effect of given drugs that are specifically related to patient’s pre-existing conditions and medical history?

At the end of your answer, evaluate the level of dangerousness of this treatment, based on interactions and adverse effects. Dangerousness is categorized as: LOW, MEDIUM, HIGH
Your answer should look like this:
`
* interactions:
- <interaction 1>
- <interaction 2>
- ...

* adverse effects:
- <adverse effect 1>
- <adverse effect 2>
- ...`

* dangerousness: <LOW / MEDIUM / HIGH>

Note that you don't have to include any interactions or adverse effect, only those that are necessary.
"""
def get_drug_info_string(drug_names):
    # Build a single string with one "drug: info" line per drug
    drug_info_string = ""
    for drug in drug_names:
        info = search_openfda_drug(drug)
        drug_info_string += drug + ": " + str(trim_openfda_response(info)) + "\r\n"
    return drug_info_string

import requests

def trim_openfda_response(json_response):
    """Trim the openFDA JSON response to include only specific fields.

    Parameters:
    - json_response (dict): The raw JSON response from the openFDA API.

    Returns:
    - dict: A trimmed version of the JSON response.
    """

    # List of desired fields
    desired_fields = [
        "spl_product_data_elements",
        "boxed_warning",
        "contraindications",
        "drug_interactions",
        "adverse_reactions",
        "warnings"
    ]

    trimmed_response = {}

    # Check that results are present in the response (also guards against a None
    # response or an empty results list from the search helper)
    if json_response and json_response.get("results"):
        for field in desired_fields:
            if field in json_response["results"][0]:
                trimmed_response[field] = json_response["results"][0][field]

    return trimmed_response

def search_openfda_drug(drug_name):
    """Search for a drug in the openFDA database.

    Parameters:
    - drug_name (str): The name of the drug to search for.

    Returns:
    - dict: The JSON response from the openFDA API containing drug information, or None if there's an error.
    """

    base_url = "https://api.fda.gov/drug/label.json"
    query = f"?search=openfda.generic_name:{drug_name}&limit=1"

    try:
        response = requests.get(base_url + query)

        # Check for successful request
        if response.status_code == 200:
            return response.json()

    except requests.RequestException:
        # If any request-related exception occurs, simply return None
        print(f"Error encountered searching for drug {drug_name}.")

    return None

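# Minimal sketch of how the two openFDA helpers above chain together
# (drug name taken from the sample prompt; assumes network access to api.fda.gov):
#
#   label = search_openfda_drug("lisinopril")    # raw openFDA label JSON, or None
#   safety = trim_openfda_response(label)        # only the warning/interaction fields
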
def ask_gpt(question, model="gpt-3.5-turbo"):
    """
    Query an OpenAI chat model (GPT-3.5 Turbo by default) with a given question.

    Parameters:
    - question (str): The input question or prompt for the model.
    - model (str): The chat model to use.

    Returns:
    - str: The model's response.
    """

    response = client.chat.completions.create(
        model=model,
        # model="gpt-4-vision-preview",
        messages=[
            {"role": "system", "content": "You are a knowledgeable medical database designed to provide concise and direct answers to medical questions."},
            {"role": "user", "content": question}
        ]
    )

    # print(response.choices[0])
    return response.choices[0].message.content

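# Sketch of a direct call (assumes the OpenAI client above is configured with a valid key):
#
#   answer = ask_gpt("List the common adverse effects of metformin.")
#   st.write(answer)
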
def parse_gpt(question):
    """
    Query the GPT-3.5 Turbo model with a given question.

    Parameters:
    - question (str): The input question or prompt for the model.

    Returns:
    - str: The model's response.
    """

    # 1 parse text to replace critical information in the prompt
    # 2 send the parsed text to the OpenAI API
    # Uses the v1 client created above (the legacy openai.ChatCompletion call is not
    # available alongside it).
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are a knowledgeable medical database designed to provide concise and direct answers to medical questions."},
            {"role": "user", "content": question}
        ]
    )

    return response.choices[0].message.content

# This code is based on https://github.com/streamlit/demo-self-driving/blob/230245391f2dda0cb464008195a470751c01770b/streamlit_app.py#L48  # noqa: E501
def download_file(url, download_to: Path, expected_size=None):
    # Don't download the file twice.
    # (If possible, verify the download using the file length.)
    if download_to.exists():
        if expected_size:
            if download_to.stat().st_size == expected_size:
                return
        else:
            st.info(f"{url} is already downloaded.")
            if not st.button("Download again?"):
                return

    download_to.parent.mkdir(parents=True, exist_ok=True)

    # These are handles to two visual elements to animate.
    weights_warning, progress_bar = None, None
    try:
        weights_warning = st.warning("Downloading %s..." % url)
        progress_bar = st.progress(0)
        with open(download_to, "wb") as output_file:
            with urllib.request.urlopen(url) as response:
                length = int(response.info()["Content-Length"])
                counter = 0.0
                MEGABYTES = 2.0 ** 20.0
                while True:
                    data = response.read(8192)
                    if not data:
                        break
                    counter += len(data)
                    output_file.write(data)

                    # We perform animation by overwriting the elements.
                    weights_warning.warning(
                        "Downloading %s... (%6.2f/%6.2f MB)"
                        % (url, counter / MEGABYTES, length / MEGABYTES)
                    )
                    progress_bar.progress(min(counter / length, 1.0))
    # Finally, we remove these visual elements by calling .empty().
    finally:
        if weights_warning is not None:
            weights_warning.empty()
        if progress_bar is not None:
            progress_bar.empty()


def main():
    st.header("openFDA Medical Records Evaluation")
    st.markdown(
        """
This demo app is using [DeepSpeech](https://github.com/mozilla/DeepSpeech),
an open speech-to-text engine.

A pre-trained model released with
[v0.9.3](https://github.com/mozilla/DeepSpeech/releases/tag/v0.9.3),
trained on American English is being served.
"""
    )

    # https://github.com/mozilla/DeepSpeech/releases/tag/v0.9.3
    MODEL_URL = "https://github.com/mozilla/DeepSpeech/releases/download/v0.9.3/deepspeech-0.9.3-models.pbmm"  # noqa
    LANG_MODEL_URL = "https://github.com/mozilla/DeepSpeech/releases/download/v0.9.3/deepspeech-0.9.3-models.scorer"  # noqa
    MODEL_LOCAL_PATH = HERE / "models/deepspeech-0.9.3-models.pbmm"
    LANG_MODEL_LOCAL_PATH = HERE / "models/deepspeech-0.9.3-models.scorer"

    # download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=188915987)
    # download_file(LANG_MODEL_URL, LANG_MODEL_LOCAL_PATH, expected_size=953363776)

    lm_alpha = 0.931289039105002
    lm_beta = 1.1834137581510284
    beam = 100

    medical_text_page = "Medical text evaluation"  # summarize notes and identify risk from notes (useful for change in doctors)
    voice_to_text_page = "Voice to medical text"  # use voice to text and identify risk, could be use in case
    image_to_text_page = "Image to medical text"  # use image to text and identify risk, could be use in case
    all_in_one_page = "All modalities"  # use all modalities
    sound_only_page = "Sound only (sendonly)"
    with_video_page = "With video (sendrecv)"
    text_only_page = "Text only for DALLE2"
    app_mode = st.selectbox(
        "Choose the app mode",
        # [sound_only_page, with_video_page, text_only_page, medical_text_page]
        [medical_text_page, voice_to_text_page, image_to_text_page, all_in_one_page]
    )

    if app_mode == sound_only_page:
        app_sst(
            str(MODEL_LOCAL_PATH), str(LANG_MODEL_LOCAL_PATH), lm_alpha, lm_beta, beam
        )
    elif app_mode == with_video_page:
        app_sst_with_video(
            str(MODEL_LOCAL_PATH), str(LANG_MODEL_LOCAL_PATH), lm_alpha, lm_beta, beam
        )
    elif app_mode == medical_text_page:
        form = st.form(key='my-form')
        text = form.text_input('Medical text description')
        submit = form.form_submit_button('Submit')

        st.write('Press submit to evaluate medical notes')

        if submit:
            # res = parse_gpt(text + "Organize the answers in 3 parts, first is pre-existing conditions, second is symptoms, third is prescriptions. Sample output for drugs should be the end of the answer as DRUG_NAMES: <drug 1>, <drug 2>, <drug 3>...")
            parsed_notes = ask_gpt(f"""
Please parse the following medical note in point form, without losing any important information:
`{text}`

your answer should look like:
`**Patient's medical history**:
- <point 1>
- <point 2>
- ...

**Patient's symptoms**:
- <point 1>
- <point 2>
- ...

**Prescription**:
- ...

DRUGS: <drug 1>, <drug 2>, ...
`
Please be reminded to give the generic names for the drugs
""")
            st.write(parsed_notes)
            # Extract the drugs portion from the notes
            drug_line = [line for line in parsed_notes.split("\n") if line.startswith("DRUGS:")][0]

            # Strip the "DRUGS: " prefix and split the drugs by ", "
            drugs = drug_line.replace("DRUGS: ", "").strip().split(", ")
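            # The extraction above relies on the model ending its answer with a line of the
            # form (illustrative, drugs from the sample prompt): "DRUGS: Lisinopril, Metformin, Albuterol"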

            # Go to FDA
            drug_info_string = get_drug_info_string(drugs)
            # st.write(drug_info_string)

            risk = ask_gpt(f"""I am a doctor, I would like you to check my prescription:
{parsed_notes}

Drug contexts:
{drug_info_string}

Please answer the following questions in concise point form, taking into account the provided drug context:
- Possible interactions between prescribed drugs?
- Adverse effect of given drugs, only answer those that are specifically related to patient’s pre-existing conditions and symptoms?

At the end of your answer, evaluate the level of dangerousness of this treatment, based on interactions and adverse effects that are specific to the patient. Dangerousness is categorized as: LOW, MEDIUM, HIGH
Your answer should look like this (you should include the * where specified):
`
* **INTERACTIONS**:
- <interaction 1>
- <interaction 2>
- ...

* **ADVERSE EFFECTS**:
- <adverse effect 1>
- <adverse effect 2>
- ...`

* **DANGEROUSNESS**: <LOW / MEDIUM / HIGH>

Note that you don't have to include any interactions or adverse effect, only those that are necessary.
""", model='gpt-3.5-turbo-16k')
            # st.write(res)
            st.write(risk)

    elif app_mode == text_only_page:
        form = st.form(key='my-form')
        text = form.text_input('Image description')
        submit = form.form_submit_button('Submit')

        st.write('Press submit to generate image')

        if submit:
            app_sst_dalle2(text)

        # form = st.form(key='my_form')
        # text = form.text_input(label='Image Description')
        # submit_button = form.form_submit_button(label='Submit')
        # if submit_button:
        #     app_sst_dalle2(form.text)

        # text = st.text_input('Image description')
        # if st.form_submit_button('Generate') == True:
        #     app_sst_dalle2(text)


def app_sst_dalle2(text):
    dalle = Dalle2("sess-TotC46rSs5pbqdXTRy75cr81ynLJALwa2b3rdxeh")
    # generations = dalle.generate(text)
    file_paths = dalle.generate_and_download(text)
    print(file_paths)
    # generations = dalle.generate_amount(text, 8)  # Every generation has batch size 4 -> amount % 4 == 0 works best
    for file in file_paths:
        image = Image.open(file)
        st.image(image, caption=text)

def app_sst(model_path: str, lm_path: str, lm_alpha: float, lm_beta: float, beam: int):
    webrtc_ctx = webrtc_streamer(
        key="speech-to-text",
        mode=WebRtcMode.SENDONLY,
        audio_receiver_size=1024,
        rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
        media_stream_constraints={"video": False, "audio": True},
    )

    status_indicator = st.empty()

    if not webrtc_ctx.state.playing:
        return

    status_indicator.write("Loading...")
    text_output = st.empty()
    stream = None

    while True:
        if webrtc_ctx.audio_receiver:
            if stream is None:
                from deepspeech import Model
                # https://github.com/openai/whisper
                # model = whisper.load_model("large")

                model = Model(model_path)
                model.enableExternalScorer(lm_path)
                model.setScorerAlphaBeta(lm_alpha, lm_beta)
                model.setBeamWidth(beam)

                stream = model.createStream()

                status_indicator.write("Model loaded.")

            sound_chunk = pydub.AudioSegment.empty()
            try:
                audio_frames = webrtc_ctx.audio_receiver.get_frames(timeout=1)
            except queue.Empty:
                time.sleep(0.1)
                status_indicator.write("No frame arrived.")
                continue

            status_indicator.write("Running. Say something!")

            for audio_frame in audio_frames:
                sound = pydub.AudioSegment(
                    data=audio_frame.to_ndarray().tobytes(),
                    sample_width=audio_frame.format.bytes,
                    frame_rate=audio_frame.sample_rate,
                    channels=len(audio_frame.layout.channels),
                )
                sound_chunk += sound

            if len(sound_chunk) > 0:
                sound_chunk = sound_chunk.set_channels(1).set_frame_rate(
                    model.sampleRate()
                )
                buffer = np.array(sound_chunk.get_array_of_samples())
                stream.feedAudioContent(buffer)
                text = stream.intermediateDecode()
                text_output.markdown(f"**Text:** {text}")
        else:
            status_indicator.write("AudioReceiver is not set. Abort.")
            break


def app_sst_with_video(
    model_path: str, lm_path: str, lm_alpha: float, lm_beta: float, beam: int
):
    frames_deque_lock = threading.Lock()
    frames_deque: deque = deque([])

    async def queued_audio_frames_callback(
        frames: List[av.AudioFrame],
    ) -> List[av.AudioFrame]:
        with frames_deque_lock:
            frames_deque.extend(frames)

        # Return empty frames to be silent.
        new_frames = []
        for frame in frames:
            input_array = frame.to_ndarray()
            new_frame = av.AudioFrame.from_ndarray(
                np.zeros(input_array.shape, dtype=input_array.dtype),
                layout=frame.layout.name,
            )
            new_frame.sample_rate = frame.sample_rate
            new_frames.append(new_frame)

        return new_frames

    webrtc_ctx = webrtc_streamer(
        key="speech-to-text-w-video",
        mode=WebRtcMode.SENDRECV,
        queued_audio_frames_callback=queued_audio_frames_callback,
        rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
        media_stream_constraints={"video": True, "audio": True},
    )

    status_indicator = st.empty()

    if not webrtc_ctx.state.playing:
        return

    status_indicator.write("Loading...")
    text_output = st.empty()
    stream = None

    while True:
        if webrtc_ctx.state.playing:
            if stream is None:
                from deepspeech import Model

                model = Model(model_path)
                model.enableExternalScorer(lm_path)
                model.setScorerAlphaBeta(lm_alpha, lm_beta)
                model.setBeamWidth(beam)

                stream = model.createStream()

                status_indicator.write("Model loaded.")

            sound_chunk = pydub.AudioSegment.empty()

            audio_frames = []
            with frames_deque_lock:
                while len(frames_deque) > 0:
                    frame = frames_deque.popleft()
                    audio_frames.append(frame)

            if len(audio_frames) == 0:
                time.sleep(0.1)
                status_indicator.write("No frame arrived.")
                continue

            status_indicator.write("Running. Say something!")

            for audio_frame in audio_frames:
                sound = pydub.AudioSegment(
                    data=audio_frame.to_ndarray().tobytes(),
                    sample_width=audio_frame.format.bytes,
                    frame_rate=audio_frame.sample_rate,
                    channels=len(audio_frame.layout.channels),
                )
                sound_chunk += sound

            if len(sound_chunk) > 0:
                sound_chunk = sound_chunk.set_channels(1).set_frame_rate(
                    model.sampleRate()
                )
                buffer = np.array(sound_chunk.get_array_of_samples())
                stream.feedAudioContent(buffer)
                text = stream.intermediateDecode()
                text_output.markdown(f"**Text:** {text}")
        else:
            status_indicator.write("Stopped.")
            break

# a raccoon astronaut with the cosmos reflecting on the glass of his helmet dreaming of the stars
def add_bg_from_url():
    st.markdown(
        f"""
        <style>
        .stApp {{
            background-image: url("https://i.redd.it/zung2u9zryb91.png");
            background-attachment: fixed;
            background-size: cover
        }}
        </style>
        """,
        unsafe_allow_html=True
    )

# add_bg_from_url()


if __name__ == "__main__":
    import os

    DEBUG = os.environ.get("DEBUG", "false").lower() not in ["false", "no", "0"]

    logging.basicConfig(
        format="[%(asctime)s] %(levelname)7s from %(name)s in %(pathname)s:%(lineno)d: "
        "%(message)s",
        force=True,
    )

    logger.setLevel(level=logging.DEBUG if DEBUG else logging.INFO)

    st_webrtc_logger = logging.getLogger("streamlit_webrtc")
    st_webrtc_logger.setLevel(logging.DEBUG)

    fsevents_logger = logging.getLogger("fsevents")
    fsevents_logger.setLevel(logging.WARNING)

    main()