rasyosef committed on
Commit
c0a205f
1 Parent(s): b226335

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from threading import Thread
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer
4
+
5
# Hugging Face Hub identifier of the checkpoint this demo serves.
model_id = "rasyosef/gpt2-small-amharic-128-v3"

# The tokenizer and model are loaded once at import time and shared by
# every request handled by the app.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Text-generation pipeline wrapping the objects loaded above; the pad/eos
# token ids are forwarded explicitly from the tokenizer.
gpt2_am = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)
17
+
18
def generate(prompt):
    """Stream a sampled continuation of ``prompt`` from the ``gpt2_am`` pipeline.

    This is a generator. While the pipeline runs on a background thread,
    each yielded value is the whitespace-stripped text accumulated from the
    streamer so far (the streamer is created with ``skip_prompt=False``).

    Args:
        prompt: The text to continue.

    Yields:
        The accumulated text after each chunk received from the streamer,
        or, when the prompt tokenizes to 128 tokens or more, a single value
        consisting of the prompt followed by an explanatory message.
    """
    # Token budget shared by the prompt and its continuation. Kept in one
    # place so the length check, the new-token count and the user-facing
    # message cannot drift apart.
    context_length = 128

    prompt_length = len(tokenizer.tokenize(prompt))
    if prompt_length >= context_length:
        yield prompt + f"\n\nPrompt is too long. It needs to be less than {context_length} tokens."
        return

    # prompt_length < context_length is guaranteed here, so the remaining
    # budget is always at least 1 and needs no clamping.
    max_new_tokens = context_length - prompt_length

    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        skip_prompt=False,
        skip_special_tokens=True,
        timeout=300.0,
    )

    # Run the pipeline on a worker thread so this generator can consume the
    # streamer while generation is still in progress.
    worker = Thread(
        target=gpt2_am,
        kwargs={
            "text_inputs": prompt,
            "max_new_tokens": max_new_tokens,
            "temperature": 0.8,
            "do_sample": True,
            "top_k": 8,
            "top_p": 0.8,
            "repetition_penalty": 1.25,
            "streamer": streamer,
        },
    )
    worker.start()

    generated_text = ""
    for chunk in streamer:
        generated_text += chunk
        yield generated_text.strip()

    # The stream is exhausted; wait for the pipeline call itself to return
    # so the worker thread does not outlive this request. This line is only
    # reached when the consumer iterates the generator to completion.
    worker.join()
44
+
45
# Gradio UI: a single textbox serves as both the prompt input and the
# streamed output, with Generate / Clear buttons and a few example prompts.
with gr.Blocks() as demo:
    gr.Markdown("""
    # GPT2 Amharic
    This is a demo for a smaller version of the gpt2 decoder transformer model pretrained for 1.5 days on `290 million` tokens of **Amharic** text. The context size of `gpt2-small-amharic` is 128 tokens.
    """)

    prompt = gr.Textbox(
        label="Prompt",
        placeholder="Enter prompt here",
        lines=4,
        interactive=True,
    )

    # The two buttons sit side by side, one per column.
    with gr.Row():
        with gr.Column():
            gen = gr.Button(value="Generate")
        with gr.Column():
            btn = gr.ClearButton(components=[prompt])

    # The handler's yielded text is written back into the prompt box.
    gen.click(fn=generate, inputs=[prompt], outputs=[prompt])

    examples = gr.Examples(
        examples=[
            "የ አዲስ አበባ",
            "በ ኢንግሊዝ ፕሪምየር ሊግ",
            "ፕሬዚዳንት ዶናልድ ትራምፕ",
        ],
        inputs=[prompt],
    )

# Enable the request queue, then start the server.
demo.queue()
demo.launch(debug=True)