sakares committed on
Commit
c3d7797
·
1 Parent(s): 1450336

init app template with streamlit. credit to GPT2 Indonesian https://huggingface.co/spaces/flax-community/gpt2-indonesian

Browse files
Files changed (5) hide show
  1. README.md +4 -4
  2. app.py +118 -0
  3. prompts.py +20 -0
  4. requirements.txt +4 -0
  5. start.sh +10 -0
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Gpt2 Thai
3
- emoji: 💩
4
- colorFrom: purple
5
- colorTo: blue
6
  sdk: streamlit
7
  app_file: app.py
8
  pinned: false
 
1
  ---
2
+ title: GPT2 Thai
3
+ emoji: 🐘
4
+ colorFrom: indigo
5
+ colorTo: indigo
6
  sdk: streamlit
7
  app_file: app.py
8
  pinned: false
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import requests
3
+ from mtranslate import translate
4
+ from prompts import PROMPT_LIST
5
+ import streamlit as st
6
+ import random
7
+
8
# Extra HTTP headers sent with every request issued by query(); empty here.
headers = dict()

# Display name of each selectable model -> its endpoint configuration.
# query() reads the "url" entry of the chosen model.
MODELS = {
    "GPT-2 Base": dict(
        url="https://api-inference.huggingface.co/models/flax-community/gpt2-base-thai",
    ),
}
14
+
15
+
16
def query(payload, model_name):
    """POST ``payload`` as JSON to the endpoint configured for ``model_name``.

    The endpoint URL comes from ``MODELS[model_name]["url"]`` and the request
    carries the module-level ``headers``. The response body is decoded as
    UTF-8, parsed as JSON, and returned.
    """
    body = json.dumps(payload)
    endpoint = MODELS[model_name]["url"]
    print("model url:", endpoint)
    response = requests.post(endpoint, headers=headers, data=body)
    raw_text = response.content.decode("utf-8")
    return json.loads(raw_text)
21
+
22
+
23
def process(text: str,
            model_name: str,
            max_len: int,
            temp: float,
            top_k: int,
            top_p: float,
            do_sample: bool = True):
    """Build a text-generation request and send it through ``query``.

    Args:
        text: Prompt to send as the request's ``inputs``.
        model_name: Name of the model, passed on to ``query``.
        max_len: Sent as the ``max_new_tokens`` parameter.
        temp: Sent as the ``temperature`` parameter.
        top_k: Sent as the ``top_k`` parameter.
        top_p: Sent as the ``top_p`` parameter.
        do_sample: Sent as the ``do_sample`` parameter. Optional, so existing
            callers that omit it keep working.

    Returns:
        Whatever ``query`` returns for the request.
    """
    payload = {
        "inputs": text,
        "parameters": {
            "max_new_tokens": max_len,
            "top_k": top_k,
            "top_p": top_p,
            "temperature": temp,
            "repetition_penalty": 2.0,
            "do_sample": do_sample,
        },
        "options": {
            "use_cache": True,
        },
    }
    return query(payload, model_name)


st.set_page_config(page_title="Thai GPT-2 Demo")

st.title("Thai GPT-2")

st.sidebar.subheader("Configurable parameters")

max_len = st.sidebar.text_input(
    "Maximum length",
    value=100,
    help="The maximum length of the sequence to be generated."
)

temp = st.sidebar.slider(
    "Temperature",
    value=1.0,
    min_value=0.1,
    max_value=100.0,
    help="The value used to module the next token probabilities."
)

top_k = st.sidebar.text_input(
    "Top k",
    value=50,
    help="The number of highest probability vocabulary tokens to keep for top-k-filtering."
)

top_p = st.sidebar.text_input(
    "Top p",
    value=0.95,
    help=" If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation."
)

do_sample = st.sidebar.selectbox('Sampling?', (True, False), help="Whether or not to use sampling; use greedy decoding otherwise.")

st.markdown(
    """Thai GPT-2 demo. Part of the [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/)."""
)

# Offer exactly the models that are configured in MODELS, so the selected
# name is always a valid key when query() looks up the endpoint URL.
model_name = st.selectbox('Model', list(MODELS))

ALL_PROMPTS = list(PROMPT_LIST.keys())+["Custom"]
prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS)-1)

if prompt == "Custom":
    prompt_box = "Enter your text here"
else:
    prompt_box = random.choice(PROMPT_LIST[prompt])

text = st.text_area("Enter text", prompt_box)

if st.button("Run"):
    # The numeric settings arrive from free-text inputs; reject bad values
    # with a message instead of crashing the app with a ValueError.
    try:
        max_new_tokens = int(max_len)
        top_k_value = int(top_k)
        top_p_value = float(top_p)
    except ValueError:
        st.error("Maximum length and Top k must be integers, and Top p must be a number.")
    else:
        with st.spinner(text="Getting results..."):
            st.subheader("Result")
            print(f"maxlen:{max_len}, temp:{temp}, top_k:{top_k}, top_p:{top_p}")
            result = process(text=text,
                             model_name=model_name,
                             max_len=max_new_tokens,
                             temp=temp,
                             top_k=top_k_value,
                             top_p=top_p_value,
                             do_sample=do_sample)

            print("result:", result)
            if "error" in result:
                error = result["error"]
                if isinstance(error, str):
                    message = f'{error}.'
                    # Only mention a retry delay when the response carries one.
                    if "estimated_time" in result:
                        message += f' Please try it again in about {result["estimated_time"]:.0f} seconds'
                    st.write(message)
                elif isinstance(error, list):
                    for item in error:
                        st.write(f'{item}')
                else:
                    st.write(error)
            else:
                result = result[0]["generated_text"]
                st.write(result.replace("\n", "  \n"))
                st.text("English translation")
                # The generated text is Thai, so translate from "th".
                st.write(translate(result, "en", "th").replace("\n", "  \n"))
prompts.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Sample Thai prompts grouped by category.
# Keys are the category labels (Thai with an English gloss in parentheses);
# each value is the list of starter texts available for that category.
PROMPT_LIST = {
    # Everyday conversation openers.
    "บทสนทนาทั่วไป (conversation)": [
        "สวัสดีตอนเช้า",
        "สบายดีไหม",
        "ขอบคุณ",
    ],
    # Opening lines for short stories.
    "เรื่องสั้น (short story)": [
        "เธอกับฉัน เราพบกันโดยบังเอิญที่ร้านกาแฟแห่งหนึ่ง",
        "บ่ายสี่โมงแล้ว แสงอาทิตย์เริ่มจะอ่อนลงบ้าง",
        "เธอใช้มือปาดน้ำตาที่ไหลลงมาตามใบหน้าเธอ",
    ],
    # History-themed sentences.
    "ประวัติศาสตร์ (history)": [
        "การปฏิวัติอุตสาหกรรมครั้งแรกซึ่งเริ่มในคริสต์ศตวรรษที่ 18",
        "แนวคิดเรื่องเครื่องจักรที่คิดได้และสิ่งมีชีวิตเทียมนั้นมีมาตั้งแต่สมัยกรีกโบราณ",
        "ช่วงต้นคริสต์ศตวรรษที่ 21 ปัญญาประดิษฐ์ประสบความสำเร็จอย่างสูง",
    ],
    # Song-lyric fragments.
    "เนื้อเพลง (lyrics)": [
        "โจ๊ะโจ๊ะ...มันทำลายสมองคน",
        "รักของเธอมีจริงหรือเปล่า",
        "ก็รู้ว่าฉันไม่มีความหมาย",
    ],
}
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ streamlit
2
+ requests==2.24.0
3
+ requests-toolbelt==0.9.1
4
+ mtranslate
start.sh ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch the Streamlit demo. With DEBUG=true the app is started under nodemon;
# otherwise Streamlit is run directly.
set -e

# The Streamlit entry point is app.py (there is no main.py).
if [ "$DEBUG" = true ] ; then
  echo 'Debugging - ON'
  nodemon --exec streamlit run app.py
else
  echo 'Debugging - OFF'
  streamlit run app.py
fi