Spaces:
Runtime error
Runtime error
File size: 4,209 Bytes
c5d0e38 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
Hugging Face's logo
Hugging Face
Search models, datasets, users...
Models
Datasets
Resources
Solutions
Pricing
Space:
Flax Community's picture
flax-community
/
papuGaPT2 Copied
Runtime error
App
Files and versions
Settings
papuGaPT2
/
app.py
miwojc's picture
miwojc
Update app.py
d4fb97b
2 minutes ago
raw
history
blame
edit
3,870 Bytes
import json
import random
import requests
from mtranslate import translate
import streamlit as st
# Hosted Inference API endpoint that serves the papuGaPT2 checkpoint.
MODEL_URL = "https://api-inference.huggingface.co/models/flax-community/papuGaPT2"

# (label shown in the "Prompt" selectbox, text used to pre-fill the editor)
_PROMPTS = (
    ("Najsmaczniejszy owoc to...", "Najsmaczniejszy owoc to "),
    ("Cześć, mam na imię...", "Cześć, mam na imię "),
    ("Największym polskim poetą był...", "Największym polskim poetą był "),
)

# Label -> list of candidate pre-fill texts; the page picks one at random.
PROMPT_LIST = {label: [prefill] for label, prefill in _PROMPTS}
def query(payload, model_url, timeout=60.0):
    """POST *payload* as JSON text to *model_url* and return the decoded reply.

    Args:
        payload: JSON-serialisable request body.
        model_url: Full URL of the model endpoint.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely on a stalled server.

    Returns:
        The decoded JSON body of the response. If the request itself fails,
        or the body is not valid UTF-8 JSON, a dict ``{"error": "<message>"}``
        is returned instead, so the caller's existing ``"error"`` handling
        covers these cases too.
    """
    print("model url:", model_url)
    try:
        response = requests.post(
            model_url,
            headers={},
            data=json.dumps(payload),
            timeout=timeout,
        )
    except requests.RequestException as exc:
        return {"error": f"Request to {model_url} failed: {exc}"}

    try:
        return json.loads(response.content.decode("utf-8"))
    except ValueError:
        # JSONDecodeError and UnicodeDecodeError both subclass ValueError;
        # report the HTTP status rather than crashing on a non-JSON body.
        return {
            "error": f"Unexpected non-JSON response (HTTP {response.status_code})"
        }
def process(
    text: str,
    model_name: str,
    max_len: int,
    temp: float,
    top_k: int,
    top_p: float,
    do_sample: bool = True,
):
    """Ask the hosted model to continue *text* and return its reply.

    Args:
        text: Prompt to continue.
        model_name: Endpoint URL forwarded to ``query`` (this script passes
            ``MODEL_URL``).
        max_len: Sent as ``max_new_tokens``.
        temp: Sent as ``temperature``.
        top_k: Sent as ``top_k``.
        top_p: Sent as ``top_p``.
        do_sample: Sent as ``do_sample``; ``False`` requests greedy decoding.
            The default ``True`` is intended to match the Inference API's
            own default (NOTE(review): confirm against the API docs).

    Returns:
        Whatever ``query`` returns. The page below handles either a dict
        with an ``"error"`` key or a list whose first item holds
        ``"generated_text"``.
    """
    payload = {
        "inputs": text,
        "parameters": {
            "max_new_tokens": max_len,
            "top_k": top_k,
            "top_p": top_p,
            "temperature": temp,
            # Fixed value; not exposed in the sidebar.
            "repetition_penalty": 2.0,
            "do_sample": do_sample,
        },
        "options": {
            "use_cache": True,
        },
    }
    return query(payload, model_name)


# Page
st.set_page_config(page_title="papuGaPT2 (Polish GPT-2) Demo")
st.title("papuGaPT2 (Polish GPT-2)")

# Sidebar
st.sidebar.subheader("Configurable parameters")
max_len = st.sidebar.number_input(
    "Maximum length",
    value=100,
    min_value=1,
    help="The maximum length of the sequence to be generated.",
)
temp = st.sidebar.slider(
    "Temperature",
    value=1.0,
    min_value=0.1,
    max_value=100.0,
    help="The value used to module the next token probabilities.",
)
top_k = st.sidebar.number_input(
    "Top k",
    value=10,
    min_value=0,
    help="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
)
top_p = st.sidebar.number_input(
    "Top p",
    value=0.95,
    min_value=0.0,
    max_value=1.0,
    help=" If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
)
do_sample = st.sidebar.selectbox(
    "Sampling?",
    (True, False),
    help="Whether or not to use sampling; use greedy decoding otherwise.",
)

# Body
st.markdown(
    """
papuGaPT2 (Polish GPT-2) model trained from scratch on OSCAR dataset.
The models were trained with Jax and Flax using TPUs as part of the [Flax/Jax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104) organised by HuggingFace.
"""
)
model_name = MODEL_URL
ALL_PROMPTS = list(PROMPT_LIST.keys()) + ["Custom"]
prompt = st.selectbox("Prompt", ALL_PROMPTS, index=len(ALL_PROMPTS) - 1)
if prompt == "Custom":
    prompt_box = "Enter your text here"
else:
    prompt_box = random.choice(PROMPT_LIST[prompt])
text = st.text_area("Enter text", prompt_box)

if st.button("Run"):
    with st.spinner(text="Getting results..."):
        st.subheader("Result")
        print(f"maxlen:{max_len}, temp:{temp}, top_k:{top_k}, top_p:{top_p}")
        result = process(
            text=text,
            model_name=model_name,
            max_len=int(max_len),
            temp=temp,
            top_k=int(top_k),
            top_p=float(top_p),
            do_sample=do_sample,
        )
        print("result:", result)
        if isinstance(result, dict) and "error" in result:
            error = result["error"]
            if isinstance(error, str):
                st.write(f"{error}.")
                if "estimated_time" in result:
                    st.write(
                        f'Please try again in about {result["estimated_time"]:.0f} seconds.'
                    )
            elif isinstance(error, list):
                for message in error:
                    st.write(f"{message}")
        else:
            try:
                generated = result[0]["generated_text"]
            except (KeyError, IndexError, TypeError):
                st.write(f"Unexpected response from the model: {result}")
            else:
                # Markdown renders a hard line break only with two trailing spaces.
                st.write(generated.replace("\n", "  \n"))
                st.text("English translation")
                # mtranslate.translate(text, to_language, from_language): the
                # model writes Polish, so translate from "pl" (not "es") to "en".
                st.write(translate(generated, "en", "pl").replace("\n", "  \n"))
|