# Author: saitejad
# Last commit: 8569d86 "Update app.py"
import gradio as gr
import librosa
import torch
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from transformers import SpeechT5Processor, SpeechT5ForSpeechToText
# Audio: speech-to-text with the SpeechT5 ASR checkpoint.
# The processor turns raw audio into model inputs and decodes generated ids back
# into text; it and the model are loaded from the same checkpoint name.
checkpoint = "microsoft/speecht5_asr"
audio_processor, audio_model = (
    SpeechT5Processor.from_pretrained(checkpoint),
    SpeechT5ForSpeechToText.from_pretrained(checkpoint),
)
def process_audio(sampling_rate, waveform, *, target_sr=16000, max_seconds=30):
    """Convert a raw audio buffer into the 1-D float tensor fed to the ASR model.

    Args:
        sampling_rate: Sample rate of ``waveform`` in Hz.
        waveform: NumPy array of samples shaped ``(frames,)`` or
            ``(frames, channels)``. Integer PCM is scaled to [-1.0, 1.0);
            floating-point input is assumed to already be in that range and is
            not rescaled.
        target_sr: Sample rate to resample to (16 kHz is what ``audio_to_text``
            tells the processor it is receiving).
        max_seconds: Samples beyond this duration (at ``target_sr``) are dropped.

    Returns:
        A 1-D ``torch.Tensor`` of mono samples at ``target_sr``.
    """
    kind = waveform.dtype.kind
    if kind in "iu":
        # Full-scale magnitude of the integer type: 2**15 = 32768 for int16,
        # 2**31 for int32, 128 for uint8.
        full_scale = float(2 ** (8 * waveform.dtype.itemsize - 1))
        # Unsigned PCM (e.g. 8-bit WAV) is centred on full_scale, not on 0.
        offset = full_scale if kind == "u" else 0.0
        # Cast first so integer arithmetic never wraps and no low-precision
        # (float16) intermediate is produced by legacy NumPy promotion rules.
        waveform = (waveform.astype("float32") - offset) / full_scale
    # Input layout is (frames, channels); librosa.to_mono expects
    # (channels, frames), hence the transpose.
    if waveform.ndim > 1:
        waveform = librosa.to_mono(waveform.T)
    if sampling_rate != target_sr:
        waveform = librosa.resample(waveform, orig_sr=sampling_rate, target_sr=target_sr)
    # Keep at most max_seconds of audio.
    waveform = waveform[: int(target_sr * max_seconds)]
    return torch.tensor(waveform)
def audio_to_text(audio, mic_audio=None):
    """Transcribe speech to text with the SpeechT5 ASR model.

    Args:
        audio: ``(sample_rate, frames)`` or ``(sample_rate, (frames, channels))``
            tuple, or ``None``.
        mic_audio: Same shape as ``audio``; used in preference to ``audio``
            when it is not ``None``.

    Returns:
        The decoded transcription, or the placeholder string
        ``"(please provide audio)"`` when neither input is given.
    """
    source = mic_audio if mic_audio is not None else audio
    if source is None:
        return "(please provide audio)"

    rate, samples = source
    samples = process_audio(rate, samples)
    features = audio_processor(audio=samples, sampling_rate=16000, return_tensors="pt")
    token_ids = audio_model.generate(**features, max_length=400)
    return audio_processor.batch_decode(token_ids, skip_special_tokens=True)[0]
# Text Generation: Llama-2-7B-Chat, q4_0-quantised GGML weights from the Hugging Face Hub.
# NOTE(review): recent llama-cpp-python releases load only GGUF files; this GGML v3
# file presumably relies on an older pinned llama-cpp-python — confirm requirements.
model_path = hf_hub_download(
    filename="llama-2-7b-chat.ggmlv3.q4_0.bin",
    repo_id="TheBloke/Llama-2-7B-Chat-GGML",
)
llm2 = Llama(model_path=model_path)
def generate(text, max_tokens=256):
    """Answer ``text`` with the Llama-2 chat model held in ``llm2``.

    Args:
        text: The user's question (in this app, the speech transcription).
        max_tokens: Upper bound on generated tokens. It is passed explicitly
            because ``Llama.create_completion`` otherwise defaults to 16 tokens,
            which cut replies off mid-sentence. The bound is also clamped to the
            context window left after the prompt.

    Returns:
        The generated reply with surrounding whitespace stripped.
    """
    system_message = (
        "You are a helpful, respectful and honest assistant. Always answer as "
        "helpfully as possible, while being safe. Your answers should not include "
        "any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
        "content. Please ensure that your responses are socially unbiased and "
        "positive in nature. If a question does not make any sense, or is not "
        "factually coherent, explain why instead of answering something not "
        "correct. If you don't know the answer to a question, please don't share "
        "false information. Give short, simple and direct answers"
    )
    # Llama-2 chat format: the system prompt sits in <<SYS>> inside a single [INST] turn.
    prompt_template = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n{text} [/INST]"
    # Never request more tokens than the context window can hold after the prompt.
    prompt_len = len(llm2.tokenize(prompt_template.encode("utf-8")))
    budget = max(1, min(max_tokens, llm2.n_ctx() - prompt_len))
    chat_compl = llm2.create_completion(
        prompt=prompt_template,
        max_tokens=budget,
        top_k=50,
        top_p=0.7,
        temperature=0.7,
        repeat_penalty=1.5,
    )
    return chat_compl["choices"][0]["text"].strip()
def audio_text_generate(audio):
    """Transcribe recorded audio and, when there is speech, answer it with the LLM.

    Args:
        audio: The value passed in by the ``gr.Audio`` input, forwarded to
            ``audio_to_text``; ``None`` when nothing was recorded.

    Returns:
        ``(transcription, reply)``. ``reply`` is ``""`` when there was no audio
        or the transcription is blank. This keeps the LLM from being prompted
        with the "(please provide audio)" placeholder or an empty question.
    """
    audio_text = audio_to_text(audio)
    if audio is None or not audio_text.strip():
        return audio_text, ""
    return audio_text, generate(audio_text)
# UI: record from the microphone, then show the transcription and the LLM reply.
demo = gr.Interface(
    fn=audio_text_generate,
    # NOTE(review): ``source=`` is the Gradio 3.x argument; Gradio 4 replaced it with
    # ``sources=["microphone"]`` — keep this in sync with the pinned gradio version.
    inputs=gr.Audio(source="microphone"),
    outputs=[gr.Text(label="Audio Text"), gr.Text(label="Generated Text")],
)

# Start the web server only when executed as a script (e.g. ``python app.py``),
# so importing this module does not block in ``launch()``.
if __name__ == "__main__":
    demo.launch()