linh-truong committed on
Commit
e276af2
1 Parent(s): 506c323
Files changed (4)
  1. .gitignore +2 -0
  2. app.py +51 -0
  3. requirements.txt +2 -0
  4. src/model.py +28 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ *test*
+ __pycache__
app.py ADDED
@@ -0,0 +1,51 @@
+ import streamlit as st
+
+ mapper = {
+     "wikilingua": "64b53c9e04b2bfaeb2f0b38cb7712bcbd4755c3f",
+     "vietnews": "d8d516ad2c112b429155c3e6077182ae5fe5b33d"
+ }
+
+ if "model" not in st.session_state:
+     from src.model import Model
+     st.session_state.task = "wikilingua"
+     model = Model(revision=mapper["wikilingua"])
+     st.session_state.model = model
+
+
+ st.set_page_config(page_title="ViT5 Reproduce", layout="wide")
+ hide_menu_style = """
+     <style>
+     footer {visibility: hidden;}
+     </style>
+     """
+ st.markdown(hide_menu_style, unsafe_allow_html=True)
+
+ with st.sidebar:
+     task = st.selectbox(label="Task", options=["wikilingua", "vietnews"])
+     if task != st.session_state.task:
+         from src.model import Model
+         st.session_state.task = task
+         st.session_state.model = Model(revision=mapper[task])
+
+ left, middle, right = st.columns([4, 1, 4])
+
+ left_container = left.container(border=True)
+ left_container.write("**Input**")
+ left_container.divider()
+ text = left_container.text_area(label="Input", height=512, label_visibility="hidden", max_chars=4096*5)
+
+ summary_button = middle.button("Summarize ➩", type="primary", use_container_width=True)
+
+ right_container = right.container(border=True)
+ right_container.markdown("**Output**")
+ right_container.divider()
+
+ if summary_button:
+
+     output = st.session_state.model.inference(text=text)
+
+
+     st.session_state["output"] = output
+
+ if "output" in st.session_state:
+     right_container.text_area(label="Output", value=st.session_state["output"], height=512, label_visibility="hidden")
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ streamlit==1.35.0
+ transformers==4.41.0
src/model.py ADDED
@@ -0,0 +1,28 @@
+
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+ class Model:
+     def __init__(self, revision) -> None:
+
+         self.tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
+         self.model = AutoModelForSeq2SeqLM.from_pretrained("truong-xuan-linh/vit5-reproduce", revision=revision)
+
+     def preprocess_function(self, text):
+         inputs = self.tokenizer(
+             text, max_length=1024, truncation=True, padding=True, return_tensors="pt"
+         )
+         return inputs
+
+     def inference(self, text):
+         max_target_length = 256
+         inputs = self.preprocess_function(text)
+         outputs = self.model.generate(
+             input_ids=inputs['input_ids'],
+             max_length=max_target_length,
+             attention_mask=inputs['attention_mask'],
+         )
+
+         with self.tokenizer.as_target_tokenizer():
+             outputs = [self.tokenizer.decode(out, clean_up_tokenization_spaces=False, skip_special_tokens=True) for out in outputs]
+
+         return outputs[0]
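
For a quick check outside the Streamlit UI, here is a minimal sketch (not part of this commit) that exercises the new Model class in src/model.py directly. It reuses the wikilingua revision hash from the mapper in app.py, and it assumes the pinned packages plus a PyTorch install are available and that it is run from the repository root; the sample sentence is only a placeholder.

# Sketch: smoke-test src/model.py from the repo root (not part of this commit).
from src.model import Model

# Revision hash copied from mapper["wikilingua"] in app.py.
model = Model(revision="64b53c9e04b2bfaeb2f0b38cb7712bcbd4755c3f")

# Any Vietnamese passage works; the tokenizer truncates inputs to 1024 tokens.
sample = "Hà Nội là thủ đô của Việt Nam."
print(model.inference(text=sample))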