Spaces:
Runtime error
Runtime error
gchhablani
commited on
Commit
•
d74c94a
1
Parent(s):
c84beda
Add other generation parameters
Browse files
app.py
CHANGED
@@ -44,11 +44,11 @@ code_to_name = {
|
|
44 |
}
|
45 |
|
46 |
@st.cache
|
47 |
-
def generate_sequence(pixel_values, lang_code, num_beams, temperature, top_p):
|
48 |
lang_code = language_mapping[lang_code]
|
49 |
-
output_ids = state.model.generate(input_ids=pixel_values, forced_bos_token_id=tokenizer.lang_code_to_id[lang_code], max_length=
|
50 |
print(output_ids)
|
51 |
-
output_sequence = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True, max_length=
|
52 |
return output_sequence
|
53 |
|
54 |
def read_markdown(path, parent="./sections/"):
|
@@ -72,6 +72,9 @@ st.write(
|
|
72 |
)
|
73 |
|
74 |
st.sidebar.title("Generation Parameters")
|
|
|
|
|
|
|
75 |
num_beams = st.sidebar.number_input(label="Number of Beams", min_value=2, max_value=10, value=4, step=1, help="Number of beams to be used in beam search.")
|
76 |
temperature = st.sidebar.select_slider(label="Temperature", options = list(np.arange(0.0,1.1, step=0.1)), value=1.0, help ="The value used to module the next token probabilities.", format_func=lambda x: f"{x:.2f}")
|
77 |
top_p = st.sidebar.select_slider(label = "Top-P", options = list(np.arange(0.0,1.1, step=0.1)),value=1.0, help="Nucleus Sampling : If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are kept for generation.", format_func=lambda x: f"{x:.2f}")
|
@@ -150,7 +153,7 @@ lang_id = new_col2.selectbox(
|
|
150 |
sequence = ['']
|
151 |
if new_col2.button("Generate Caption", help="Generate a caption in the specified language."):
|
152 |
with st.spinner("Generating Sequence..."):
|
153 |
-
sequence = generate_sequence(transformed_image, lang_id, num_beams, temperature, top_p)
|
154 |
# print(sequence)
|
155 |
|
156 |
if sequence!=['']:
|
|
|
44 |
}
|
45 |
|
46 |
@st.cache
|
47 |
+
def generate_sequence(pixel_values, lang_code, num_beams, temperature, top_p, do_sample, top_k, max_length):
|
48 |
lang_code = language_mapping[lang_code]
|
49 |
+
output_ids = state.model.generate(input_ids=pixel_values, forced_bos_token_id=tokenizer.lang_code_to_id[lang_code], max_length=max_length, num_beams=num_beams, temperature=temperature, top_p = top_p, top_k=top_k, do_sample=do_sample)
|
50 |
print(output_ids)
|
51 |
+
output_sequence = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True, max_length=max_length)
|
52 |
return output_sequence
|
53 |
|
54 |
def read_markdown(path, parent="./sections/"):
|
|
|
72 |
)
|
73 |
|
74 |
st.sidebar.title("Generation Parameters")
|
75 |
+
max_length = st.sidebar.number_input("Max Length", min_value=16, max_value=128, value=64, step=1, help="The maximum length of sequence to be generated.")
|
76 |
+
do_sample = st.sidebar.checkbox("Sample", value=False, help="Sample from the model instead of using beam search.")
|
77 |
+
top_k = st.sidebar.number_input("Top K", min_value=10, max_value=200, value=50, step=1, help="The number of highest probability vocabulary tokens to keep for top-k-filtering.")
|
78 |
num_beams = st.sidebar.number_input(label="Number of Beams", min_value=2, max_value=10, value=4, step=1, help="Number of beams to be used in beam search.")
|
79 |
temperature = st.sidebar.select_slider(label="Temperature", options = list(np.arange(0.0,1.1, step=0.1)), value=1.0, help ="The value used to module the next token probabilities.", format_func=lambda x: f"{x:.2f}")
|
80 |
top_p = st.sidebar.select_slider(label = "Top-P", options = list(np.arange(0.0,1.1, step=0.1)),value=1.0, help="Nucleus Sampling : If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are kept for generation.", format_func=lambda x: f"{x:.2f}")
|
|
|
153 |
sequence = ['']
|
154 |
if new_col2.button("Generate Caption", help="Generate a caption in the specified language."):
|
155 |
with st.spinner("Generating Sequence..."):
|
156 |
+
sequence = generate_sequence(transformed_image, lang_id, num_beams, temperature, top_p, do_sample, top_k, max_length)
|
157 |
# print(sequence)
|
158 |
|
159 |
if sequence!=['']:
|