# Spaces: Runtime error
import json
import logging
import random

import fasttext
import requests
import streamlit as st
from mtranslate import translate

import SessionState
from prompts import PROMPT_LIST
# HTTP headers for the Inference API requests; starts empty and is only
# populated if an access token is configured.
headers = {}

# Image shown at the top of the sidebar.
LOGO = "huggingwayang.png"

# Display name (as shown in the model selector) -> Inference API endpoint.
MODELS = {
    display_name: {"url": endpoint}
    for display_name, endpoint in (
        ("GPT-2 Small",
         "https://api-inference.huggingface.co/models/flax-community/gpt2-small-indonesian"),
        ("GPT-2 Medium",
         "https://api-inference.huggingface.co/models/flax-community/gpt2-medium-indonesian"),
        ("GPT-2 Small finetuned on Indonesian academic journals",
         "https://api-inference.huggingface.co/models/Galuh/id-journal-gpt2"),
    )
}
def get_image(text: str, timeout: float = 30.0) -> str:
    """Ask the wikisearch service for an image URL matching *text*.

    This is best effort: any network, decoding or response-shape problem is
    logged and reported as "no image".

    Args:
        text: Query text sent to the image search service.
        timeout: Seconds to wait for the service before giving up, so a slow
            or unreachable host cannot hang the app indefinitely.

    Returns:
        The image URL from the response's ``"url"`` field, or ``""`` when no
        usable URL could be obtained.
    """
    url = "https://wikisearch.uncool.ai/get_image/"
    payload = {
        "text": text,
        "image_width": 400,
    }
    try:
        # Intentionally no ``headers=headers``: the module-level headers may
        # carry the Hugging Face bearer token, which must not be sent to this
        # third-party host.
        response = requests.request("POST", url, data=json.dumps(payload), timeout=timeout)
        image_url = json.loads(response.content.decode("utf-8"))["url"]
    except (requests.RequestException, ValueError, KeyError, TypeError) as err:
        # ValueError covers JSON/UTF-8 decoding errors; KeyError/TypeError
        # cover a body that is not a dict with a "url" key.
        logging.getLogger(__name__).warning("Image lookup failed: %s", err)
        return ""
    # Guard against e.g. {"url": null} so callers always get a string.
    return image_url if isinstance(image_url, str) else ""
def query(payload, model_name, timeout=60.0):
    """POST *payload* as JSON to the Inference API endpoint of *model_name*.

    Args:
        payload: JSON-serialisable request body (inputs/parameters/options).
        model_name: Key into ``MODELS`` selecting the endpoint URL.
        timeout: Seconds to wait for the API before giving up.

    Returns:
        The decoded JSON response. On a failed request or a body that is not
        JSON, returns ``{"error": "<description>"}`` instead of raising. This is
        the same shape the UI already renders for API-side errors.

    Raises:
        KeyError: if *model_name* is not a key of ``MODELS``.
    """
    data = json.dumps(payload)
    url = MODELS[model_name]["url"]
    try:
        response = requests.request("POST", url, headers=headers, data=data, timeout=timeout)
    except requests.RequestException as err:
        return {"error": f"Request to the inference API failed: {err}"}
    try:
        return json.loads(response.content.decode("utf-8"))
    except ValueError:
        # e.g. an HTML error page from a gateway; also covers bad UTF-8.
        return {"error": f"Unexpected non-JSON response from the inference API (HTTP {response.status_code})"}
def process(text: str,
            model_name: str,
            max_len: int,
            temp: float,
            top_k: int,
            top_p: float):
    """Request a text continuation of *text* from the selected model.

    Builds the generation request (sampling settings, a fixed repetition
    penalty of 2.0, and API-side caching enabled) and forwards it to
    ``query``.

    Args:
        text: Prompt to continue.
        model_name: Key into ``MODELS`` choosing the endpoint.
        max_len: Sent as ``max_new_tokens``.
        temp: Sampling temperature.
        top_k: Top-k filtering value.
        top_p: Nucleus (top-p) filtering value.

    Returns:
        Whatever ``query`` returns for the request.
    """
    generation_parameters = dict(
        max_new_tokens=max_len,
        top_k=top_k,
        top_p=top_p,
        temperature=temp,
        repetition_penalty=2.0,
    )
    request_body = {
        "inputs": text,
        "parameters": generation_parameters,
        "options": {"use_cache": True},
    }
    return query(request_body, model_name)
st.set_page_config(page_title="Indonesian GPT-2 Demo")
st.title("Indonesian GPT-2")

# Optional Inference API token; without it requests are sent with empty headers.
try:
    token = st.secrets["flax_community_token"]
    headers = {"Authorization": f"Bearer {token}"}
except (FileNotFoundError, KeyError):
    # FileNotFoundError: no secrets file at all.
    # KeyError: a secrets file exists but does not define this key.
    print("Token is not found")

# fastText language-identification model, used below to detect the prompt language.
ft_model = fasttext.load_model('lid.176.ftz')
# Sidebar
st.sidebar.image(LOGO)
st.sidebar.subheader("Configurable parameters")

# Lower/upper bounds stop obviously invalid generation settings (negative
# lengths or top-k, top_p outside [0, 1]) from being sent to the API.
max_len = st.sidebar.number_input(
    "Maximum length",
    min_value=1,
    value=100,
    help="The maximum number of new tokens to generate."
)
temp = st.sidebar.slider(
    "Temperature",
    value=1.0,
    min_value=0.0,
    max_value=100.0,
    help="The value used to modulate the next token probabilities."
)
top_k = st.sidebar.number_input(
    "Top k",
    min_value=0,
    value=50,
    help="The number of highest probability vocabulary tokens to keep for top-k-filtering."
)
top_p = st.sidebar.number_input(
    "Top p",
    min_value=0.0,
    max_value=1.0,
    value=1.0,
    help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation."
)
# Introductory description shown above the model selector.
st.markdown(
    """
    This demo uses the [small](https://huggingface.co/flax-community/gpt2-small-indonesian) and
    [medium](https://huggingface.co/flax-community/gpt2-medium-indonesian) Indonesian GPT-2 models
    trained on the Indonesian [OSCAR](https://huggingface.co/datasets/oscar), [mC4](https://huggingface.co/datasets/mc4)
    and [Wikipedia](https://huggingface.co/datasets/wikipedia) datasets. We created them as part of the
    [Hugging Face JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).
    You can also select a [GPT-2 Small model fine-tuned on Indonesian academic journals](https://huggingface.co/Galuh/id-journal-gpt2).
    The demo supports "multi language" ;-), so feel free to try a prompt in your language. We are also experimenting with
    sentence-based image search: Wikipedia passages are encoded with DistilBERT, and the encoded sentence is searched
    among the encoded passages using Facebook's Faiss.
    """
)
# Which group of example prompts in PROMPT_LIST each model offers.
MODEL_PROMPT_GROUPS = {
    "GPT-2 Small": "GPT-2",
    "GPT-2 Medium": "GPT-2",
    "GPT-2 Small finetuned on Indonesian academic journals": "Indonesian Journals",
}

# Offer exactly the models configured in MODELS (same order), so the selector
# and the endpoint table cannot drift apart.
model_name = st.selectbox('Model', list(MODELS))
# A model without an explicit mapping falls back to the generic GPT-2 prompts
# instead of failing with an undefined prompt group.
prompt_group_name = MODEL_PROMPT_GROUPS.get(model_name, "GPT-2")

session_state = SessionState.get(prompt=None, prompt_box=None, text=None)
ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name].keys()) + ["Custom"]
prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS) - 1)

# Update prompt: switching to a different prompt discards the previous
# example text and the text-area contents.
if session_state.prompt is not None and prompt != session_state.prompt:
    session_state.prompt_box = None
    session_state.text = None
session_state.prompt = prompt

# Update prompt box: "Custom" gets a placeholder; otherwise draw a random
# example for the chosen prompt once and keep it until the prompt changes.
if session_state.prompt == "Custom":
    session_state.prompt_box = "Enter your text here"
elif session_state.prompt is not None and session_state.prompt_box is None:
    session_state.prompt_box = random.choice(PROMPT_LIST[prompt_group_name][session_state.prompt])

session_state.text = st.text_area("Enter text", session_state.prompt_box)
if st.button("Run"):
    with st.spinner(text="Getting results..."):
        # fastText's predict() works on a single line, so newlines are
        # flattened; Indonesian counts as detected if it is among the top 3.
        lang_predictions, _ = ft_model.predict(session_state.text.replace("\n", " "), k=3)
        if "__label__id" in lang_predictions:
            lang = "id"
            text = session_state.text
        else:
            # Translate a non-Indonesian prompt into Indonesian before generation.
            lang = lang_predictions[0].replace("__label__", "")
            text = translate(session_state.text, "id", lang)

        st.subheader("Result")
        result = process(text=text,
                         model_name=model_name,
                         max_len=int(max_len),
                         temp=temp,
                         top_k=int(top_k),
                         top_p=float(top_p))

        if isinstance(result, dict) and "error" in result:
            error = result["error"]
            if isinstance(error, list):
                for message in error:
                    st.write(f'{message}')
            else:
                # Build one message so the retry hint stays next to the error;
                # st.write has no print-style `end` argument.
                message = f'{error}.'
                if "estimated_time" in result:
                    message += f' Please try it again in about {result["estimated_time"]:.0f} seconds'
                st.write(message)
        else:
            # Success responses are expected as [{"generated_text": ...}, ...].
            try:
                generated_text = result[0]["generated_text"]
            except (IndexError, KeyError, TypeError):
                generated_text = None

            if not isinstance(generated_text, str):
                st.write("Unexpected response from the inference API.")
            else:
                # Markdown only renders a hard line break for two trailing spaces.
                st.write(generated_text.replace("\n", "  \n"))
                st.text("Translation")
                translation = translate(generated_text, "en", "id")
                # Indonesian prompts are shown the English translation; other
                # prompts get the output translated back to their language
                # (reusing `translation` when that language is English).
                if lang in ("id", "en"):
                    shown_translation = translation
                else:
                    shown_translation = translate(generated_text, lang, "id")
                st.write(shown_translation.replace("\n", "  \n"))

                image_cat = "https://media.giphy.com/media/vFKqnCdLPNOKc/giphy.gif"
                image = get_image(translation.replace("\"", "'"))
                if image:
                    st.image(image, width=400)
                else:
                    # display cat image if no image found
                    st.image(image_cat, width=400)

    # Reset state so the prompt state is re-initialised on the next rerun.
    session_state.prompt = None
    session_state.prompt_box = None
    session_state.text = None