# Hugging Face Space -- status shown when this page was captured: Runtime error
import gradio as gr
import librosa
import torch
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from transformers import AutoTokenizer, pipeline, logging
from transformers import SpeechT5Processor, SpeechT5ForSpeechToText
# Audio
# Speech-to-text: the processor and the model are both loaded from the same
# SpeechT5 ASR checkpoint, once, at import time.
checkpoint = "microsoft/speecht5_asr"
audio_processor = SpeechT5Processor.from_pretrained(checkpoint)
audio_model = SpeechT5ForSpeechToText.from_pretrained(checkpoint)
def process_audio(sampling_rate, waveform):
    """Normalize raw audio into a mono, 16 kHz tensor of at most 30 seconds.

    Args:
        sampling_rate: Sample rate of ``waveform`` in Hz.
        waveform: Array of audio samples. It is scaled by 1/32768, i.e. it is
            assumed to hold int16 PCM values. An array with more than one
            dimension is treated as multi-channel and is transposed before
            being mixed down to mono.

    Returns:
        A ``torch.Tensor`` with the processed samples.
    """
    target_rate = 16000
    max_seconds = 30

    # int16 spans [-32768, 32767], so dividing by 2**15 maps it onto
    # [-1.0, 1.0). The divisor used to be 32678 (transposed digits), which
    # scaled every sample slightly wrong and let values exceed 1.0.
    waveform = waveform / 32768.0

    # Mix multi-channel audio down to mono; the transpose puts the channel
    # axis first, which is the layout passed to librosa.to_mono.
    if waveform.ndim > 1:
        waveform = librosa.to_mono(waveform.T)

    # Resample to 16 kHz if necessary.
    if sampling_rate != target_rate:
        waveform = librosa.resample(
            waveform, orig_sr=sampling_rate, target_sr=target_rate
        )

    # Keep only the first 30 seconds.
    waveform = waveform[: target_rate * max_seconds]

    return torch.tensor(waveform)
def audio_to_text(audio, mic_audio=None):
    """Transcribe an audio clip to text with the SpeechT5 ASR model.

    Args:
        audio: ``(sample_rate, frames)`` tuple, where frames may also be
            shaped ``(frames, channels)``; or ``None``.
        mic_audio: Optional tuple of the same form. When it is not ``None``
            it is used instead of ``audio``.

    Returns:
        The first decoded transcription, or the placeholder string
        ``"(please provide audio)"`` when both inputs are ``None``.
    """
    # The microphone recording takes precedence over the other clip.
    clip = mic_audio if mic_audio is not None else audio
    if clip is None:
        return "(please provide audio)"

    sampling_rate, samples = clip
    speech = process_audio(sampling_rate, samples)

    features = audio_processor(audio=speech, sampling_rate=16000, return_tensors="pt")
    token_ids = audio_model.generate(**features, max_length=400)
    decoded = audio_processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0]
# Text Generation
# Fetch the quantized Llama-2 chat weights from the Hugging Face Hub and load
# them through llama-cpp-python. Both callables must be imported at the top of
# the file (huggingface_hub.hf_hub_download, llama_cpp.Llama); without those
# imports this section raises NameError as soon as the module is loaded.
model_path = hf_hub_download(
    repo_id="TheBloke/Llama-2-7B-Chat-GGML",
    filename="llama-2-7b-chat.ggmlv3.q4_0.bin",
)
llm2 = Llama(model_path=model_path)
def generate(text):
    """Ask the Llama-2 chat model for a reply to ``text``.

    The user text is wrapped in an ``[INST] <<SYS>> ... <</SYS>> ... [/INST]``
    prompt together with a fixed system message, and the stripped text of the
    first completion choice is returned.
    """
    system_message = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. Give short, simple and direct answers"

    # Assembled with explicit newlines so source indentation can never end up
    # inside the prompt.
    prompt = "\n".join(
        [
            "[INST] <<SYS>>",
            system_message,
            "<</SYS>>",
            f"{text} [/INST]",
        ]
    )

    sampling = {
        "top_k": 50,
        "top_p": 0.7,
        "temperature": 0.7,
        "repeat_penalty": 1.5,
    }
    completion = llm2.create_completion(prompt=prompt, **sampling)
    first_choice = completion['choices'][0]
    return first_choice['text'].strip()
def audio_text_generate(audio):
    """Transcribe ``audio`` and generate a chat reply to the transcription.

    Args:
        audio: ``(sample_rate, frames)`` tuple as passed on to
            ``audio_to_text``, or ``None`` when nothing was recorded.

    Returns:
        A ``(audio_text, generated_text)`` tuple. When ``audio`` is ``None``,
        ``audio_text`` is whatever ``audio_to_text`` returns for missing input
        and ``generated_text`` is an empty string.
    """
    audio_text = audio_to_text(audio)

    # Without a recording, audio_text is only a placeholder message; do not
    # send it to the language model as if the user had said it.
    if audio is None:
        return audio_text, ""

    generated_text = generate(audio_text)
    return audio_text, generated_text
# Gradio UI: one microphone input, two text boxes (the transcription and the
# model's reply), wired to audio_text_generate.
demo = gr.Interface(
    fn=audio_text_generate,
    inputs=gr.Audio(source="microphone"),
    outputs=[
        gr.Text(label="Audio Text"),
        gr.Text(label="Generated Text"),
    ],
)
# examples=["https://samplelib.com/lib/preview/mp3/sample-3s.mp3"], cache_examples=True)
demo.launch()