gchhablani committed on
Commit
950c460
1 Parent(s): 7fe8d4e

Change layout and update sidebar

Browse files
Files changed (3) hide show
  1. Multilingual IC.svg +0 -0
  2. app.py +38 -31
  3. mic-logo.png +0 -0
Multilingual IC.svg ADDED
app.py CHANGED
@@ -44,9 +44,9 @@ code_to_name = {
44
  }
45
 
46
  @st.cache(persist=True)
47
- def generate_sequence(pixel_values, lang_code, num_beams):
48
  lang_code = language_mapping[lang_code]
49
- output_ids = model.generate(input_ids=pixel_values, forced_bos_token_id=tokenizer.lang_code_to_id[lang_code], max_length=64, num_beams=num_beams)
50
  print(output_ids)
51
  output_sequence = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True, max_length=64)
52
  return output_sequence
@@ -56,7 +56,7 @@ def read_markdown(path, parent="./sections/"):
56
  return f.read()
57
 
58
 
59
- checkpoints = ["./ckpt/ckpt-22499"] # TODO: Maybe add more checkpoints?
60
  dummy_data = pd.read_csv("reference.tsv", sep="\t")
61
 
62
  st.set_page_config(
@@ -70,12 +70,29 @@ st.write(
70
  "[Bhavitvya Malik](https://huggingface.co/bhavitvyamalik), [Gunjan Chhablani](https://huggingface.co/gchhablani)"
71
  )
72
 
73
- st.sidebar.title("Settings")
74
  num_beams = st.sidebar.number_input(label="Number of Beams", min_value=2, max_value=10, value=4, step=1, help="Number of beams to be used in beam search.")
 
 
75
 
76
  with st.beta_expander("Usage"):
77
  st.markdown(read_markdown("usage.md"))
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  first_index = 20
80
  # Init Session State
81
  if state.image_file is None:
@@ -87,9 +104,9 @@ if state.image_file is None:
87
  image = plt.imread(image_path)
88
  state.image = image
89
 
90
- col1, col2 = st.beta_columns([6, 4])
91
 
92
- if col2.button("Get a random example"):
93
  sample = dummy_data.sample(1).reset_index()
94
  state.image_file = sample.loc[0, "image_file"]
95
  state.caption = sample.loc[0, "caption"].strip("- ")
@@ -99,40 +116,42 @@ if col2.button("Get a random example"):
99
  image = plt.imread(image_path)
100
  state.image = image
101
 
102
- col2.write("OR")
103
 
104
- uploaded_file = col2.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])
105
- if uploaded_file is not None:
106
- state.image_file = os.path.join("images", uploaded_file.name)
107
- state.image = np.array(Image.open(uploaded_file))
108
 
109
  transformed_image = get_transformed_image(state.image)
110
 
 
111
  # Display Image
112
- col1.image(state.image, use_column_width="auto")
 
113
 
114
  # Display Reference Caption
115
- col2.write("**Reference Caption**: " + state.caption)
116
- col2.markdown(
117
  f"""**English Translation**: {state.caption if state.lang_id == "en" else translate(state.caption, 'en')}"""
118
  )
119
 
120
  # Select Language
121
  options = list(code_to_name.keys())
122
- lang_id = col2.selectbox(
123
  "Language",
124
  index=options.index(state.lang_id),
125
  options=options,
126
  format_func=lambda x: code_to_name[x],
 
127
  )
128
- # Display Top-5 Predictions
129
  with st.spinner("Loading model..."):
130
  model = load_model(checkpoints[0])
131
-
132
  sequence = ['']
133
- if col2.button("Generate Caption"):
134
  with st.spinner("Generating Sequence..."):
135
- sequence = generate_sequence(transformed_image, lang_id, num_beams)
136
  # print(sequence)
137
 
138
  if sequence!=['']:
@@ -143,15 +162,3 @@ if sequence!=['']:
143
  st.write(
144
  "**English Translation**: "+ sequence[0] if lang_id=="en" else translate(sequence[0])
145
  )
146
- st.write(read_markdown("abstract.md"))
147
- st.write(read_markdown("caveats.md"))
148
- # st.write("# Methodology")
149
- # st.image(
150
- # "./misc/Multilingual-IC.png", caption="Seq2Seq model for Image-text Captioning."
151
- # )
152
- st.markdown(read_markdown("pretraining.md"))
153
- st.write(read_markdown("challenges.md"))
154
- st.write(read_markdown("social_impact.md"))
155
- st.write(read_markdown("references.md"))
156
- # st.write(read_markdown("checkpoints.md"))
157
- st.write(read_markdown("acknowledgements.md"))
 
44
  }
45
 
46
  @st.cache(persist=True)
47
+ def generate_sequence(pixel_values, lang_code, num_beams, temperature, top_p):
48
  lang_code = language_mapping[lang_code]
49
+ output_ids = model.generate(input_ids=pixel_values, forced_bos_token_id=tokenizer.lang_code_to_id[lang_code], max_length=64, num_beams=num_beams, temperature=temperature, top_p = top_p)
50
  print(output_ids)
51
  output_sequence = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True, max_length=64)
52
  return output_sequence
 
56
  return f.read()
57
 
58
 
59
+ checkpoints = ["./ckpt/ckpt-17499"] # TODO: Maybe add more checkpoints?
60
  dummy_data = pd.read_csv("reference.tsv", sep="\t")
61
 
62
  st.set_page_config(
 
70
  "[Bhavitvya Malik](https://huggingface.co/bhavitvyamalik), [Gunjan Chhablani](https://huggingface.co/gchhablani)"
71
  )
72
 
73
+ st.sidebar.title("Generation Parameters")
74
  num_beams = st.sidebar.number_input(label="Number of Beams", min_value=2, max_value=10, value=4, step=1, help="Number of beams to be used in beam search.")
75
+ temperature = st.sidebar.select_slider(label="Temperature", options = np.arange(0.0,1.1, step=0.1), value=1.0, help ="The value used to module the next token probabilities.", format_func=lambda x: f"{x:.2f}")
76
+ top_p = st.sidebar.select_slider(label = "Top-P", options = np.arange(0.0,1.1, step=0.1),value=1.0, help="Nucleus Sampling : If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are kept for generation.", format_func=lambda x: f"{x:.2f}")
77
 
78
  with st.beta_expander("Usage"):
79
  st.markdown(read_markdown("usage.md"))
80
 
81
+ with st.beta_expander("Article"):
82
+ st.write(read_markdown("abstract.md"))
83
+ st.write(read_markdown("caveats.md"))
84
+ # st.write("# Methodology")
85
+ # st.image(
86
+ # "./misc/Multilingual-IC.png", caption="Seq2Seq model for Image-text Captioning."
87
+ # )
88
+ st.markdown(read_markdown("pretraining.md"))
89
+ st.write(read_markdown("challenges.md"))
90
+ st.write(read_markdown("social_impact.md"))
91
+ st.write(read_markdown("references.md"))
92
+ # st.write(read_markdown("checkpoints.md"))
93
+ st.write(read_markdown("acknowledgements.md"))
94
+
95
+
96
  first_index = 20
97
  # Init Session State
98
  if state.image_file is None:
 
104
  image = plt.imread(image_path)
105
  state.image = image
106
 
107
+ # col1, col2 = st.beta_columns([6, 4])
108
 
109
+ if st.button("Get a random example", help="Get a random example from one of the seeded examples."):
110
  sample = dummy_data.sample(1).reset_index()
111
  state.image_file = sample.loc[0, "image_file"]
112
  state.caption = sample.loc[0, "caption"].strip("- ")
 
116
  image = plt.imread(image_path)
117
  state.image = image
118
 
119
+ # col2.write("OR")
120
 
121
+ # uploaded_file = col2.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])
122
+ # if uploaded_file is not None:
123
+ # state.image_file = os.path.join("images", uploaded_file.name)
124
+ # state.image = np.array(Image.open(uploaded_file))
125
 
126
  transformed_image = get_transformed_image(state.image)
127
 
128
+ new_col1, new_col2 = st.beta_columns([5,5])
129
  # Display Image
130
+ new_col1.image(state.image, use_column_width="always")
131
+
132
 
133
  # Display Reference Caption
134
+ new_col2.write("**Reference Caption**: " + state.caption)
135
+ new_col2.markdown(
136
  f"""**English Translation**: {state.caption if state.lang_id == "en" else translate(state.caption, 'en')}"""
137
  )
138
 
139
  # Select Language
140
  options = list(code_to_name.keys())
141
+ lang_id = new_col2.selectbox(
142
  "Language",
143
  index=options.index(state.lang_id),
144
  options=options,
145
  format_func=lambda x: code_to_name[x],
146
+ help="The language in which caption is to be generated."
147
  )
148
+
149
  with st.spinner("Loading model..."):
150
  model = load_model(checkpoints[0])
 
151
  sequence = ['']
152
+ if new_col2.button("Generate Caption", help="Generate a caption in the specified language."):
153
  with st.spinner("Generating Sequence..."):
154
+ sequence = generate_sequence(transformed_image, lang_id, num_beams, temperature, top_p)
155
  # print(sequence)
156
 
157
  if sequence!=['']:
 
162
  st.write(
163
  "**English Translation**: "+ sequence[0] if lang_id=="en" else translate(sequence[0])
164
  )
 
 
 
 
 
 
 
 
 
 
 
 
mic-logo.png ADDED