File size: 3,942 Bytes
db67a3a 411d537 de1029e 411d537 de1029e 411d537 db67a3a cf0cbe3 08447dc 4e9992d 10e5029 265de56 db67a3a de1029e db67a3a cf0cbe3 db67a3a 411d537 923c281 db67a3a 4e4e302 db67a3a cf0cbe3 4e4e302 cf0cbe3 4e4e302 411d537 db67a3a cf0cbe3 411d537 de1029e db67a3a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
# Voice-driven medical Q&A demo: records speech, transcribes it with Google
# Speech Recognition, and answers with a fine-tuned causal LM (plus a
# GPT-3.5 comparison path).
import speech_recognition as sr
import gradio as gr
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig,BitsAndBytesConfig
import torch
import os
from openai import OpenAI

# OpenAI client for the GPT-3.5 comparison button; key comes from the
# environment (None when unset — requests then fail at call time, not here).
key = os.environ.get('OPENAI_API_KEY')
client = OpenAI(api_key=key)
# Hugging Face model id of the fine-tuned checkpoint used on GPU.
Medical_finetunned_model = "truongghieu/deci-finetuned_Prj2"
# NOTE(review): appears unused in this file — candidate for removal.
answer_text = "This is an answer"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 4-bit NF4 quantization config so the model fits in modest GPU memory.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="float16", bnb_4bit_use_double_quant=True
)
tokenizer = AutoTokenizer.from_pretrained(Medical_finetunned_model, trust_remote_code=True)
if torch.cuda.is_available():
    # GPU path: quantized fine-tuned model.
    model = AutoModelForCausalLM.from_pretrained(Medical_finetunned_model, trust_remote_code=True, quantization_config=bnb_config)
else:
    # CPU fallback: a different, non-quantized checkpoint (4-bit bitsandbytes
    # loading requires CUDA).
    model = AutoModelForCausalLM.from_pretrained("truongghieu/deci-finetuned", trust_remote_code=True)
def generate_text(question, penalty_alpha=0.6, do_sample=True, top_k=5,
                  temperature=0.5, repetition_penalty=1.0, max_new_tokens=30):
    """Answer *question* with the fine-tuned medical model.

    Parameters mirror the UI sliders, in the same positional order the
    Gradio ``click`` handler supplies them (defaults match the slider
    defaults, so the signature stays backward compatible).

    Returns the decoded model output, or a prompt-for-input message when
    the question is empty.
    """
    # Guard against None / empty / whitespace-only transcriptions from ASR.
    if not question or not question.strip():
        return "Please input text"
    generation_config = GenerationConfig(
        penalty_alpha=penalty_alpha,
        do_sample=do_sample,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        max_new_tokens=max_new_tokens,
        # Use EOS as pad so generation doesn't warn about a missing pad token.
        pad_token_id=tokenizer.eos_token_id,
    )
    # Prompt template the checkpoint was fine-tuned with.
    input_text = f'###Human : {question}'
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
    output_ids = model.generate(input_ids, generation_config=generation_config)
    output_text = tokenizer.decode(output_ids[0])
    return output_text
def gpt_generate(question, penalty_alpha=0.6, do_sample=True, top_k=5,
                 temperature=0.5, repetition_penalty=1.0, max_new_tokens=30):
    """Answer *question* with OpenAI gpt-3.5-turbo for side-by-side comparison.

    Accepts the same positional arguments as ``generate_text`` so both
    buttons can share the slider inputs; ``penalty_alpha``, ``do_sample``,
    ``top_k`` and ``repetition_penalty`` have no Chat Completions
    equivalent here and are intentionally ignored.

    Returns the assistant message content, or a prompt-for-input message
    when the question is empty.
    """
    # Avoid sending an empty prompt to the API; mirrors generate_text.
    if not question or not question.strip():
        return "Please input text"
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": question}],
        temperature=temperature,
        max_tokens=max_new_tokens,
    )
    return response.choices[0].message.content
def recognize_speech(audio_data):
    """Transcribe a Gradio numpy audio tuple with Google Web Speech.

    Args:
        audio_data: ``(sample_rate, samples)`` tuple from
            ``gr.Audio(type="numpy")``, or ``None`` when nothing was recorded.

    Returns:
        The recognized text, or a human-readable error message.
    """
    if audio_data is None:
        return "Please record some audio first."
    sample_rate, samples = audio_data
    samples = np.asarray(samples)
    # Down-mix stereo (n, channels) to mono; AudioData expects one channel.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    # sample_width=2 promises 16-bit PCM, so normalize the dtype to int16:
    # floats are assumed to be in [-1, 1] and are rescaled.
    if samples.dtype != np.int16:
        if np.issubdtype(samples.dtype, np.floating):
            samples = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)
        else:
            samples = samples.astype(np.int16)
    audio = sr.AudioData(samples.tobytes(), sample_rate=sample_rate, sample_width=2)
    recognizer = sr.Recognizer()
    try:
        text = recognizer.recognize_google(audio)
        return text
    except sr.UnknownValueError:
        return "Speech Recognition could not understand audio."
    except sr.RequestError as e:
        return f"Could not request results from Google Speech Recognition service; {e}"
# --- Gradio UI: wire the recognizer and both answer generators together ---
with gr.Blocks() as demo:
    # Row 1: record audio and show the transcription.
    with gr.Row():
        inp = gr.Audio(type="numpy")
        out_text_predict = gr.Textbox(label="Recognized Speech")
    button = gr.Button("Recognize Speech" , size="lg")
    button.click(recognize_speech, inp, out_text_predict)
    # Generation hyper-parameter controls, shared by both answer buttons.
    with gr.Row():
        with gr.Row():
            penalty_alpha_slider = gr.Slider(minimum=0, maximum=1, step=0.1, label="penalty alpha",value=0.6)
            do_sample_checkbox = gr.Checkbox(label="do sample",value=True)
            top_k_slider = gr.Slider(minimum=0, maximum=10, step=1, label="top k", value=5)
        with gr.Row():
            temperature_slider = gr.Slider(minimum=0, maximum=1, step=0.1, label="temperature",value=0.5)
            repetition_penalty_slider = gr.Slider(minimum=0, maximum=2, step=0.1, label="repetition penalty",value=1.0)
            max_new_tokens_slider = gr.Slider(minimum=0, maximum=200, step=1, label="max new tokens",value=30)
    # Answer from the fine-tuned local model.
    with gr.Row():
        out_answer = gr.Textbox(label="Answer")
        button_answer = gr.Button("Answer")
    button_answer.click(generate_text, [out_text_predict, penalty_alpha_slider, do_sample_checkbox, top_k_slider, temperature_slider, repetition_penalty_slider, max_new_tokens_slider], out_answer)
    # Comparison answer from GPT-3.5 (same input list; the handler only
    # reads the prompt, temperature and max-token values).
    with gr.Row():
        gpt_output = gr.Textbox(label="GPT-3.5 Turbo Output")
        button_gpt = gr.Button("GPT-3.5 Answer")
    button_gpt.click(gpt_generate,[out_text_predict, penalty_alpha_slider, do_sample_checkbox, top_k_slider, temperature_slider, repetition_penalty_slider, max_new_tokens_slider],gpt_output)
demo.launch()