truongghieu commited on
Commit
cf0cbe3
1 Parent(s): 9c2b4b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -17
app.py CHANGED
@@ -5,7 +5,6 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig,B
5
  import torch
6
 
7
  Medical_finetunned_model = "truongghieu/deci-finetuned_Prj2"
8
- question_text = "This is a question"
9
  answer_text = "This is an answer"
10
 
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -19,17 +18,24 @@ else:
19
  model = AutoModelForCausalLM.from_pretrained("truongghieu/deci-finetuned", trust_remote_code=True)
20
 
21
 
22
- generation_config = GenerationConfig(
23
- penalty_alpha=0.6,
24
- do_sample=True,
25
- top_k=3,
26
- temperature=0.5,
27
- repetition_penalty=1.2,
28
- max_new_tokens=50,
29
- pad_token_id=tokenizer.eos_token_id
30
- )
31
 
32
- def generate_text(text):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  input_text = f'###Human: \"{text}\"'
34
  input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
35
  output_ids = model.generate(input_ids, generation_config=generation_config)
@@ -45,8 +51,7 @@ def recognize_speech(audio_data):
45
  recognizer = sr.Recognizer()
46
  try:
47
  text = recognizer.recognize_google(audio_data)
48
- question_text = text
49
- return f"Recognized Speech: {text}"
50
 
51
  except sr.UnknownValueError:
52
  return "Speech Recognition could not understand audio."
@@ -65,13 +70,22 @@ def recognize_speech(audio_data):
65
 
66
  with gr.Blocks() as demo:
67
  with gr.Row():
68
- gr.Label("Speech Recognition")
69
  inp = gr.Audio(type="numpy")
70
  out_text_predict = gr.Textbox(label="Recognized Speech")
71
- button = gr.Button("Recognize Speech")
72
  button.click(recognize_speech, inp, out_text_predict)
 
 
 
 
 
 
 
 
 
73
  with gr.Row():
74
  out_answer = gr.Textbox(label="Answer")
75
- button_answer = gr.Button("Generate Answer")
76
- button_answer.click(generate_text, out_text_predict, out_answer)
77
  demo.launch()
 
5
  import torch
6
 
7
  Medical_finetunned_model = "truongghieu/deci-finetuned_Prj2"
 
8
  answer_text = "This is an answer"
9
 
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
18
  model = AutoModelForCausalLM.from_pretrained("truongghieu/deci-finetuned", trust_remote_code=True)
19
 
20
 
 
 
 
 
 
 
 
 
 
21
 
22
+ def generate_text(*args):
23
+ if args[0] == "":
24
+ return "Please input text"
25
+ args[1] = 0.6 if args[1] == 0 else args[1]
26
+ args[3] = 5 if args[3] == 0 else args[3]
27
+ args[4] = 0.5 if args[4] == 0 else args[4]
28
+ args[5] = 50 if args[5] == 0 else args[5]
29
+
30
+ generation_config = GenerationConfig(
31
+ penalty_alpha=args[1],
32
+ do_sample=args[2],
33
+ top_k=args[3],
34
+ temperature=args[4],
35
+ repetition_penalty=args[5],
36
+ max_new_tokens=args[6],
37
+ pad_token_id=tokenizer.eos_token_id
38
+ )
39
  input_text = f'###Human: \"{text}\"'
40
  input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
41
  output_ids = model.generate(input_ids, generation_config=generation_config)
 
51
  recognizer = sr.Recognizer()
52
  try:
53
  text = recognizer.recognize_google(audio_data)
54
+ return text
 
55
 
56
  except sr.UnknownValueError:
57
  return "Speech Recognition could not understand audio."
 
70
 
71
  with gr.Blocks() as demo:
72
  with gr.Row():
73
+ inp_text = gr.Textbox(label="Input Text")
74
  inp = gr.Audio(type="numpy")
75
  out_text_predict = gr.Textbox(label="Recognized Speech")
76
+ button = gr.Button("Recognize Speech" , size="sm")
77
  button.click(recognize_speech, inp, out_text_predict)
78
+ with gr.Row():
79
+ with gr.Row():
80
+ penalty_alpha_slider = gr.Slider(minimum=0, maximum=1, step=0.1, label="penalty alpha")
81
+ do_sample_checkbox = gr.Checkbox(label="do sample")
82
+ top_k_slider = gr.Slider(minimum=0, maximum=10, step=1, label="top k")
83
+ with gr.Row():
84
+ temperature_slider = gr.Slider(minimum=0, maximum=1, step=0.1, label="temperature")
85
+ repetition_penalty_slider = gr.Slider(minimum=0, maximum=2, step=0.1, label="repetition penalty")
86
+ max_new_tokens_slider = gr.Slider(minimum=0, maximum=100, step=1, label="max new tokens")
87
  with gr.Row():
88
  out_answer = gr.Textbox(label="Answer")
89
+ button_answer = gr.Button("Answer")
90
+ button_answer.click(generate_text, [out_text_predict, penalty_alpha_slider, do_sample_checkbox, top_k_slider, temperature_slider, repetition_penalty_slider, max_new_tokens_slider], out_answer)
91
  demo.launch()