# Grammar Correction App — Gradio demo for the
# pszemraj/flan-t5-large-grammar-synthesis model.
# (Hugging Face Spaces page chrome and commit-hash gutter removed.)
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
# Pretrained grammar-correction checkpoint (T5-based seq2seq model).
_CHECKPOINT = "pszemraj/flan-t5-large-grammar-synthesis"

# Load the model and its matching tokenizer once at module import time,
# so every request reuses the same in-memory instances.
model = AutoModelForSeq2SeqLM.from_pretrained(_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
def correct_text(text, genConfig):
    """Run the grammar-correction model on *text* and return the decoded output.

    Parameters
    ----------
    text : str
        Raw input sentence(s) to correct.
    genConfig : transformers.GenerationConfig
        Decoding parameters; forwarded to ``model.generate`` as keyword args.

    Returns
    -------
    str
        The model's corrected text with special tokens stripped.
    """
    # Tokenize via __call__ to also obtain the attention mask (the original
    # prepended an empty string to `text`, which was a no-op, and passed no
    # mask, which makes transformers warn and guess padding).
    encoded = tokenizer(text, return_tensors="pt")
    outputs = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        **genConfig.to_dict(),
    )
    # outputs[0]: the single generated sequence for this single input.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def respond(text, max_new_tokens, min_new_tokens, num_beams, num_beam_groups, temperature, top_k, top_p, no_repeat_ngram_size, guidance_scale, do_sample: bool):
    """Gradio handler: assemble a GenerationConfig from the UI controls
    and stream back the corrected text for *text*.

    Yields the corrected string once (generator form lets Gradio treat
    the handler as a streaming event).
    """
    options = {
        "max_new_tokens": max_new_tokens,
        "min_new_tokens": min_new_tokens,
        "num_beams": num_beams,
        "num_beam_groups": num_beam_groups,
        "temperature": float(temperature),
        "top_k": top_k,
        "top_p": float(top_p),
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "early_stopping": True,
        "do_sample": do_sample,
    }
    # Guidance scale is only attached when positive (the UI slider's
    # minimum of 0.1 keeps it positive in practice).
    if guidance_scale > 0:
        options["guidance_scale"] = float(guidance_scale)
    yield correct_text(text, GenerationConfig(**options))
def update_prompt(prompt):
    """Identity passthrough used by the sample buttons: Gradio feeds the
    clicked button's label in as *prompt*, and the return value is written
    into the prompt textbox unchanged."""
    return prompt
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("""# Grammar Correction App""")
    prompt_box = gr.Textbox(placeholder="Enter your prompt here...")
    output_box = gr.Textbox()
    # Sample prompts: clicking a button copies its label into prompt_box
    # (a gr.Button passed as an input component yields its label string).
    with gr.Row():
        samp1 = gr.Button("we shood buy an car")
        samp2 = gr.Button("she is more taller")
        samp3 = gr.Button("John and i saw a sheep over their.")
    samp1.click(update_prompt, samp1, prompt_box)
    samp2.click(update_prompt, samp2, prompt_box)
    samp3.click(update_prompt, samp3, prompt_box)
    submitBtn = gr.Button("Submit")
    # Decoding knobs — forwarded positionally to respond(), which builds a
    # transformers.GenerationConfig from them.
    with gr.Accordion("Generation Parameters:", open=False):
        max_tokens = gr.Slider(minimum=1, maximum=256, value=50, step=1, label="Max New Tokens")
        min_tokens = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Min New Tokens")
        num_beams = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Num Beams")
        beam_groups = gr.Slider(minimum=1, maximum=20, value=1, step=1, label="Num Beams Groups")
        temperature = gr.Slider(minimum=0.1, maximum=100.0, value=0.7, step=0.1, label="Temperature")
        top_k = gr.Slider(minimum=0, maximum=200, value=50, step=1, label="Top-k")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.05, label="Top-p (nucleus sampling)")
        guideScale = gr.Slider(minimum=0.1, maximum=50.0, value=1.0, step=0.1, label="Guidance Scale")
        no_repeat_ngram_size = gr.Slider(0, 20, value=0, step=1, label="Limit N-grams of given Size")
        do_sample = gr.Checkbox(value=True, label="Do Sampling")
    # Input order must match respond()'s parameter order exactly.
    submitBtn.click(respond, [prompt_box, max_tokens, min_tokens, num_beams, beam_groups, temperature, top_k, top_p, no_repeat_ngram_size, guideScale, do_sample], output_box)

# Fixed: removed the stray trailing "|" (page-scrape residue) that made
# this line a syntax error.
demo.launch()