Spaces:
Runtime error
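# Note: this app targets the Gradio 3.x API (gr.inputs.*, component .style() and
# launch(enable_queue=...)); those interfaces were removed in later Gradio releases.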
import gradio as gr
import os
from mtranslate import translate
import requests

HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")
text_generator_api = 'https://cahya-indonesian-whisperer.hf.space/api/text-generator/v1'
text_generator_api_auth_token = os.getenv("TEXT_GENERATOR_API_AUTH_TOKEN", "")
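# The generation itself runs in a separate Space reached over its REST API; requests
# are authenticated with a bearer token read from the environment (presumably
# provisioned as a Space secret).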

def get_answer(user_input, decoding_method, num_beams, top_k, top_p, temperature, repetition_penalty, penalty_alpha):
    print(user_input, decoding_method, top_k, top_p, temperature, repetition_penalty, penalty_alpha)
    headers = {'Authorization': 'Bearer ' + text_generator_api_auth_token}
    data = {
        "model_name": "bloomz-1b1-instruct",
        "text": user_input,
        "min_length": len(user_input) + 10,
        "max_length": len(user_input) + 200,
        "decoding_method": decoding_method,
        "num_beams": num_beams,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "seed": -1,
        "repetition_penalty": repetition_penalty,
        "penalty_alpha": penalty_alpha
    }
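    # POST the payload as form data; the endpoint is expected to return JSON with the
    # generated text under a "generated_text" key (as consumed below).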
    r = requests.post(text_generator_api, headers=headers, data=data)
    if r.status_code == 200:
        result = r.json()
        answer = result["generated_text"]
        user_input_en = translate(user_input, "en", "auto")
        answer_en = translate(answer, "en", "auto")
        return [(f"{user_input}\n", None), (answer, "")], \
            [(f"{user_input_en}\n", None), (answer_en, "")]
    else:
        # Both outputs are HighlightedText components, so report the error in that format
        error = [("Error: " + r.text, None)]
        return error, error
css = """ | |
#answer_id span {white-space: pre-line} | |
#answer_id span.label {display: none} | |
#answer_en span {white-space: pre-line} | |
#answer_en span.label {display: none} | |
""" | |
with gr.Blocks(css=css) as demo:
    with gr.Row():
        gr.Markdown("""## Bloomz-1b7-Instruct
        We fine-tuned the BigScience model Bloomz-1b7 on a cross-lingual instruction dataset. Some of the supported
        languages are English, Indonesian, Vietnamese, Hindi, Spanish, French, and Chinese.
        """)
    with gr.Row():
        with gr.Column():
            user_input = gr.inputs.Textbox(placeholder="",
                                           label="Ask me something",
                                           default="Will we ever cure cancer? Please answer in Chinese.")
            decoding_method = gr.inputs.Dropdown(["Beam Search", "Sampling"],
                                                 default="Sampling", label="Decoding Method")
            num_beams = gr.inputs.Slider(label="Number of beams for beam search",
                                         default=1, minimum=1, maximum=10, step=1)
            top_k = gr.inputs.Slider(label="Top K",
                                     default=30, maximum=50, minimum=1, step=1)
            top_p = gr.inputs.Slider(label="Top P", default=0.9, step=0.05, minimum=0.1, maximum=1.0)
            temperature = gr.inputs.Slider(label="Temperature", default=0.5, step=0.05, minimum=0.1, maximum=1.0)
            repetition_penalty = gr.inputs.Slider(label="Repetition Penalty", default=1.1, step=0.05, minimum=1.0, maximum=2.0)
            penalty_alpha = gr.inputs.Slider(label="The penalty alpha for contrastive search",
                                             default=0.5, step=0.05, minimum=0.05, maximum=1.0)
            with gr.Row():
                button_generate_story = gr.Button("Submit")
        with gr.Column():
            # generated_answer = gr.Textbox()
            generated_answer = gr.HighlightedText(
                elem_id="answer_id",
                label="Generated Text",
                combine_adjacent=True,
            ).style(color_map={"": "blue", "-": "green"})
            generated_answer_en = gr.HighlightedText(
                elem_id="answer_en",
                label="Translation",
                combine_adjacent=True,
            ).style(color_map={"": "blue", "-": "green"})
    with gr.Row():
        gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=cahya_bloomz-1b1-instruct)")
    button_generate_story.click(get_answer,
                                inputs=[user_input, decoding_method, num_beams, top_k, top_p, temperature,
                                        repetition_penalty, penalty_alpha],
                                outputs=[generated_answer, generated_answer_en])

demo.launch(enable_queue=False)