tayaee committed on
Commit f46275c
1 parent: 07374db

Create app.py

Files changed (1)
  app.py  +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
+ import os
+ import re
+ import time
+
+ import streamlit as st
+
+ model_id = "google/codegemma-7b-it"
+
+
+ def strip_bos_eos(text_tagged):
+     m = re.match(r".*?(?<=<bos>)(.*)(?=<eos>).*?", text_tagged, flags=re.DOTALL)
+     text_stripped = m.group(1) if m else text_tagged
+     return text_stripped
+
+
+ @st.cache_resource
+ def load_models():
+     from dotenv import load_dotenv
+     from transformers import GemmaTokenizer, AutoModelForCausalLM
+     load_dotenv()
+     _token = os.environ["HF_TOKEN"]
+     _tokenizer = GemmaTokenizer.from_pretrained(model_id)
+     _model = AutoModelForCausalLM.from_pretrained(model_id)
+     return _token, _tokenizer, _model
+
+
+ def process(_input_text):
+     _token, _tokenizer, _model = load_models()
+     input_ids = _tokenizer(_input_text, return_tensors="pt")
+     _outputs = _model.generate(**input_ids, max_new_tokens=4092)
+     _output_text = strip_bos_eos(_tokenizer.decode(_outputs[0]))
+     return _output_text
+
+
+ if __name__ == '__main__':
+     load_models()
+     st.title(model_id)
+     input_text = st.text_input("Prompt")
+     if st.button("Submit"):
+         output_text = process(input_text)
+         st.write(output_text)
+
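
The app is launched the usual Streamlit way, `streamlit run app.py`, with HF_TOKEN available in the environment or a .env file (the committed code reads it via python-dotenv). As a quick, self-contained sanity check of the strip_bos_eos helper, separate from the committed file, something like the following sketch works; the sample strings are illustrative only, not actual model output:

    import re

    def strip_bos_eos(text_tagged):
        # Keep only the text between the <bos> and <eos> special tokens, if both are present;
        # otherwise return the input unchanged.
        m = re.match(r".*?(?<=<bos>)(.*)(?=<eos>).*?", text_tagged, flags=re.DOTALL)
        return m.group(1) if m else text_tagged

    print(strip_bos_eos("<bos>print('hello')<eos>"))  # -> print('hello')
    print(strip_bos_eos("no special tokens here"))    # -> returned unchanged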