miwojc commited on
Commit
c5d0e38
1 Parent(s): d4fb97b

Create app_bak.py

Browse files
Files changed (1) hide show
  1. app_bak.py +147 -0
app_bak.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Hugging Face's logo
2
+ Hugging Face
3
+ Search models, datasets, users...
4
+ Models
5
+ Datasets
6
+ Resources
7
+ Solutions
8
+ Pricing
9
+
10
+ Space:
11
+ Flax Community's picture
12
+ flax-community
13
+ /
14
+ papuGaPT2 Copied
15
+ Runtime error
16
+ App
17
+ Files and versions
18
+ Settings
19
+ papuGaPT2
20
+ /
21
+ app.py
22
+ miwojc's picture
23
+ miwojc
24
+ Update app.py
25
+ d4fb97b
26
+ 2 minutes ago
27
+ raw
28
+ history
29
+ blame
30
+ edit
31
+ 3,870 Bytes
32
+ import json
33
+ import random
34
+ import requests
35
+ from mtranslate import translate
36
+ import streamlit as st
37
+ MODEL_URL = "https://api-inference.huggingface.co/models/flax-community/papuGaPT2"
38
+ PROMPT_LIST = {
39
+ "Najsmaczniejszy owoc to...": ["Najsmaczniejszy owoc to "],
40
+ "Cześć, mam na imię...": ["Cześć, mam na imię "],
41
+ "Największym polskim poetą był...": ["Największym polskim poetą był "],
42
+ }
43
+ def query(payload, model_url):
44
+ data = json.dumps(payload)
45
+ print("model url:", model_url)
46
+ response = requests.request(
47
+ "POST", model_url, headers={}, data=data
48
+ )
49
+ return json.loads(response.content.decode("utf-8"))
50
+ def process(
51
+ text: str, model_name: str, max_len: int, temp: float, top_k: int, top_p: float
52
+ ):
53
+ payload = {
54
+ "inputs": text,
55
+ "parameters": {
56
+ "max_new_tokens": max_len,
57
+ "top_k": top_k,
58
+ "top_p": top_p,
59
+ "temperature": temp,
60
+ "repetition_penalty": 2.0,
61
+ },
62
+ "options": {
63
+ "use_cache": True,
64
+ },
65
+ }
66
+ return query(payload, model_name)
67
+ # Page
68
+ st.set_page_config(page_title="papuGaPT2 (Polish GPT-2) Demo")
69
+ st.title("papuGaPT2 (Polish GPT-2")
70
+ # Sidebar
71
+ st.sidebar.subheader("Configurable parameters")
72
+ max_len = st.sidebar.number_input(
73
+ "Maximum length",
74
+ value=100,
75
+ help="The maximum length of the sequence to be generated.",
76
+ )
77
+ temp = st.sidebar.slider(
78
+ "Temperature",
79
+ value=1.0,
80
+ min_value=0.1,
81
+ max_value=100.0,
82
+ help="The value used to module the next token probabilities.",
83
+ )
84
+ top_k = st.sidebar.number_input(
85
+ "Top k",
86
+ value=10,
87
+ help="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
88
+ )
89
+ top_p = st.sidebar.number_input(
90
+ "Top p",
91
+ value=0.95,
92
+ help=" If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
93
+ )
94
+ do_sample = st.sidebar.selectbox(
95
+ "Sampling?",
96
+ (True, False),
97
+ help="Whether or not to use sampling; use greedy decoding otherwise.",
98
+ )
99
+ # Body
100
+ st.markdown(
101
+ """
102
+ papuGaPT2 (Polish GPT-2) model trained from scratch on OSCAR dataset.
103
+
104
+ The models were trained with Jax and Flax using TPUs as part of the [Flax/Jax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104) organised by HuggingFace.
105
+ """
106
+ )
107
+ model_name = MODEL_URL
108
+ ALL_PROMPTS = list(PROMPT_LIST.keys()) + ["Custom"]
109
+ prompt = st.selectbox("Prompt", ALL_PROMPTS, index=len(ALL_PROMPTS) - 1)
110
+ if prompt == "Custom":
111
+ prompt_box = "Enter your text here"
112
+ else:
113
+ prompt_box = random.choice(PROMPT_LIST[prompt])
114
+ text = st.text_area("Enter text", prompt_box)
115
+ if st.button("Run"):
116
+ with st.spinner(text="Getting results..."):
117
+ st.subheader("Result")
118
+ print(f"maxlen:{max_len}, temp:{temp}, top_k:{top_k}, top_p:{top_p}")
119
+ result = process(
120
+ text=text,
121
+ model_name=model_name,
122
+ max_len=int(max_len),
123
+ temp=temp,
124
+ top_k=int(top_k),
125
+ top_p=float(top_p),
126
+ )
127
+ print("result:", result)
128
+ if "error" in result:
129
+ if type(result["error"]) is str:
130
+ st.write(f'{result["error"]}.', end=" ")
131
+ if "estimated_time" in result:
132
+ st.write(
133
+ f'Please try again in about {result["estimated_time"]:.0f} seconds.'
134
+ )
135
+ else:
136
+ if type(result["error"]) is list:
137
+ for error in result["error"]:
138
+ st.write(f"{error}")
139
+ else:
140
+ result = result[0]["generated_text"]
141
+ st.write(result.replace("\
142
+ ", " \
143
+ "))
144
+ st.text("English translation")
145
+ st.write(translate(result, "en", "es").replace("\
146
+ ", " \
147
+ "))