Spaces:
Runtime error
Runtime error
Stefan Dumitrescu
committed on
Commit
•
19c9e19
1
Parent(s):
2f0ed55
Update
Browse files
app.py
CHANGED
@@ -84,13 +84,10 @@ with col1:
|
|
84 |
|
85 |
st.markdown("**Step 2: Adjust specific text generation parameters**")
|
86 |
|
87 |
-
tab_greedy,
|
88 |
|
89 |
with tab_greedy:
|
90 |
-
st.
|
91 |
-
|
92 |
-
with tab_beamsearch:
|
93 |
-
num_beams = st.slider("Num beams", min_value=1, max_value=30, step=5, value=5)
|
94 |
|
95 |
with tab_sampling:
|
96 |
top_p = st.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)
|
@@ -104,12 +101,13 @@ with col1:
|
|
104 |
st.markdown("**Step 3: Adjust common text generation parameters**")
|
105 |
|
106 |
no_repeat_ngrams = st.slider("No repeat n-grams", value=2, min_value=0, max_value=3)
|
107 |
-
temperature = st.slider("Temperature", value=1.0, min_value=0.
|
108 |
max_length = st.slider("Number of tokens to generate", value=50, min_value=10, max_value=256)
|
109 |
|
110 |
st.markdown("**Step 4: Select a prompt or input your own text, and click generate in the left panel**")
|
111 |
|
112 |
|
|
|
113 |
def update_prompt():
|
114 |
st.session_state['text'] = prompt
|
115 |
|
@@ -138,7 +136,7 @@ if button_greedy or button_sampling or button_typical:
|
|
138 |
if len(tokenized_text.input_ids[0]) + max_length > 512: # need to keep less words
|
139 |
keep_last = 512 - max_length
|
140 |
print(f"keep last: {keep_last}")
|
141 |
-
input_ids, attention_mask = tokenized_text.input_ids[0][
|
142 |
previous_ids = tokenized_text.input_ids[0][:keep_last]
|
143 |
st.warning(f"kept last {keep_last}")
|
144 |
else:
|
@@ -149,7 +147,9 @@ if button_greedy or button_sampling or button_typical:
|
|
149 |
output = greedy_search(model, input_ids.unsqueeze(dim=0), attention_mask.unsqueeze(dim=0), no_repeat_ngrams, length)
|
150 |
|
151 |
if previous_ids is not None:
|
152 |
-
|
|
|
|
|
153 |
else:
|
154 |
new_text = tokenizer.decode(output[0], skip_special_tokens=True)
|
155 |
|
@@ -199,8 +199,8 @@ text_element = col2.text_area('Text:', height=400, key="text")
|
|
199 |
col2.markdown("""---""")
|
200 |
col2.text("Statistics and details:")
|
201 |
if details != "":
|
202 |
-
col2.caption("
|
203 |
if tokenized_text is None:
|
204 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
|
205 |
tt = tokenizer(text_element, add_special_tokens=False, return_tensors="pt")
|
206 |
-
col2.caption(f"
|
|
|
84 |
|
85 |
st.markdown("**Step 2: Adjust specific text generation parameters**")
|
86 |
|
87 |
+
tab_greedy, tab_sampling, tab_typical = st.tabs(["Greedy", "Sampling", "Typical Sampling"])
|
88 |
|
89 |
with tab_greedy:
|
90 |
+
st.caption("Greedy decoding does not have any special parameters.")
|
|
|
|
|
|
|
91 |
|
92 |
with tab_sampling:
|
93 |
top_p = st.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)
|
|
|
101 |
st.markdown("**Step 3: Adjust common text generation parameters**")
|
102 |
|
103 |
no_repeat_ngrams = st.slider("No repeat n-grams", value=2, min_value=0, max_value=3)
|
104 |
+
temperature = st.slider("Temperature", value=1.0, min_value=0.1, max_value=1.0, step=0.1)
|
105 |
max_length = st.slider("Number of tokens to generate", value=50, min_value=10, max_value=256)
|
106 |
|
107 |
st.markdown("**Step 4: Select a prompt or input your own text, and click generate in the left panel**")
|
108 |
|
109 |
|
110 |
+
|
111 |
def update_prompt():
|
112 |
st.session_state['text'] = prompt
|
113 |
|
|
|
136 |
if len(tokenized_text.input_ids[0]) + max_length > 512: # need to keep less words
|
137 |
keep_last = 512 - max_length
|
138 |
print(f"keep last: {keep_last}")
|
139 |
+
input_ids, attention_mask = tokenized_text.input_ids[0][-keep_last:], tokenized_text.attention_mask[0][-keep_last:]
|
140 |
previous_ids = tokenized_text.input_ids[0][:keep_last]
|
141 |
st.warning(f"kept last {keep_last}")
|
142 |
else:
|
|
|
147 |
output = greedy_search(model, input_ids.unsqueeze(dim=0), attention_mask.unsqueeze(dim=0), no_repeat_ngrams, length)
|
148 |
|
149 |
if previous_ids is not None:
|
150 |
+
print(f"\nConcat prev id: "+tokenizer.decode(previous_ids, skip_special_tokens=True))
|
151 |
+
print(f"\nWith current decode: " + tokenizer.decode(output[0], skip_special_tokens=True))
|
152 |
+
new_text = tokenizer.decode(torch.cat([previous_ids, output[0]], dim=-1), skip_special_tokens=True)
|
153 |
else:
|
154 |
new_text = tokenizer.decode(output[0], skip_special_tokens=True)
|
155 |
|
|
|
199 |
col2.markdown("""---""")
|
200 |
col2.text("Statistics and details:")
|
201 |
if details != "":
|
202 |
+
col2.caption(" Generation details: " + details)
|
203 |
if tokenized_text is None:
|
204 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
|
205 |
tt = tokenizer(text_element, add_special_tokens=False, return_tensors="pt")
|
206 |
+
col2.caption(f" Text length is {len(text_element)} characters, {len(tt.input_ids[0])} tokens.")
|