ColeGuion committed on
Commit 23c0953 (1 parent: 95f281d)

Update app.py

Files changed (1)
  1. app.py +35 -12
app.py CHANGED
@@ -1,5 +1,5 @@
 import gradio as gr
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
 
 # Load the model and tokenizer
 model = AutoModelForSeq2SeqLM.from_pretrained("vennify/t5-base-grammar-correction")
@@ -39,19 +39,43 @@ def correct_text(text, max_length, min_length, max_new_tokens, min_new_tokens, n
     corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
     yield corrected_text
 
+def correct_text2(text, genConfig):
+    inputs = tokenizer.encode("grammar: " + text, return_tensors="pt")
+    outputs = model.generate(inputs, **genConfig.to_dict())
+
+    corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    yield corrected_text
+
+def respond(text, max_length, min_length, max_new_tokens, min_new_tokens, num_beams, temperature, top_p):
+    config = GenerationConfig(
+        max_length=max_length,
+        min_length=min_length,
+        num_beams=num_beams,
+        temperature=temperature,
+        top_p=top_p,
+        early_stopping=True,
+        do_sample=True
+    )
+
+    # Add max/min new tokens if they are there
+    if max_new_tokens > 0:
+        config.max_new_tokens = max_new_tokens
+    if min_new_tokens > 0:
+        config.min_new_tokens = min_new_tokens
+
+    corrected = correct_text2(text, config)
+    yield corrected
+
+
 
 def update_prompt(prompt):
     return prompt
 
 # Create the Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown(
-        """
-        # Grammar Correction App
-        """)
-    prompt_box = gr.Textbox(lines=2, placeholder="Enter your prompt here...")
+    gr.Markdown("""# Grammar Correction App""")
+    prompt_box = gr.Textbox(placeholder="Enter your prompt here...")
     output_box = gr.Textbox()
-    submitBtn = gr.Button("Submit")
 
     # Sample prompts
     with gr.Row():
@@ -62,7 +86,8 @@ with gr.Blocks() as demo:
     samp1.click(update_prompt, samp1, prompt_box)
     samp2.click(update_prompt, samp2, prompt_box)
     samp3.click(update_prompt, samp3, prompt_box)
-
+
+    submitBtn = gr.Button("Submit")
 
     with gr.Accordion("Generation Parameters:", open=False):
         max_length = gr.Slider(minimum=1, maximum=256, value=80, step=1, label="Max Length")
@@ -74,8 +99,6 @@ with gr.Blocks() as demo:
         top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
 
 
-    submitBtn.click(correct_text, [prompt_box, max_length, min_length, max_tokens, min_tokens, num_beams, temperature, top_p], output_box)
-
-
+    submitBtn.click(respond, [prompt_box, max_length, min_length, max_tokens, min_tokens, num_beams, temperature, top_p], output_box)
 
-demo.launch()
+demo.launch()
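
For reference, a minimal standalone sketch of the generation path this commit introduces (correct_text2 driven by a GenerationConfig assembled in respond) could look roughly like the following, run outside Gradio. It assumes the same vennify/t5-base-grammar-correction checkpoint and "grammar: " task prefix, uses illustrative parameter values, and passes the config through generate()'s documented generation_config argument rather than unpacking to_dict():

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

    model = AutoModelForSeq2SeqLM.from_pretrained("vennify/t5-base-grammar-correction")
    tokenizer = AutoTokenizer.from_pretrained("vennify/t5-base-grammar-correction")

    # Roughly the config that respond() builds from the slider values (numbers are illustrative).
    config = GenerationConfig(
        max_length=80,
        min_length=1,
        num_beams=2,
        temperature=1.0,
        top_p=0.95,
        early_stopping=True,
        do_sample=True,
    )

    # Same "grammar: " prefix that correct_text2() prepends before generation.
    inputs = tokenizer.encode("grammar: He go to school yesterday.", return_tensors="pt")
    outputs = model.generate(inputs, generation_config=config)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))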