themanas021 committed on
Commit
2fbbb50
•
1 Parent(s): 365a9d8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -0
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Streamlit demo that captions an image with a ViT-GPT2 encoder-decoder model.

The image comes from one of three places, in priority order: a file uploaded
through the sidebar form, a sample image on local disk, or the COCO 2017
validation set fetched over HTTP.
"""

import io

import requests
import streamlit as st

# Upper bounds (in pixels) for the preview shown on the page.
MAX_DISPLAY_WIDTH = 512
MAX_DISPLAY_HEIGHT = 384

# Seconds to wait for the COCO image server before giving up.
DOWNLOAD_TIMEOUT = 10


def _fit_within(width, height, max_width=MAX_DISPLAY_WIDTH, max_height=MAX_DISPLAY_HEIGHT):
    """Return ``(width, height)`` scaled down to fit the bounds, keeping aspect ratio.

    Sizes already inside the bounds are returned unchanged. Each dimension is
    kept at 1 pixel or more so the result is always a valid image size.
    """
    if height > max_height:
        width = max(1, int(width / height * max_height))
        height = max_height
    if width > max_width:
        # Scale the height from the *old* width before clamping the width;
        # doing it the other way round leaves the height unscaled.
        height = max(1, int(height / width * max_width))
        width = max_width
    return width, height


# Designing the interface
st.title("🖼️ Image Captioning Demo 📝")
st.write("[Yih-Dar SHIEH](https://huggingface.co/ydshieh)")

st.sidebar.markdown(
    """
An image captioning model by combining ViT model with GPT2 model.
The encoder (ViT) and decoder (GPT2) are combined using Hugging Face transformers' [Vision-To-Text Encoder-Decoder
framework](https://huggingface.co/transformers/master/model_doc/visionencoderdecoder.html).
The pretrained weights of both models are loaded, with a set of randomly initialized cross-attention weights.
The model is trained on the COCO 2017 dataset for about 6900 steps (batch_size=256).
[Follow-up work of [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).]\n
"""
)

with st.spinner('Loading and compiling ViT-GPT2 model ...'):
    # NOTE(review): `os`, `Image`, `sample_dir`, `sample_image_ids`,
    # `get_random_image_id` and `predict` are used below but not defined in
    # this file, so they are expected to come from this star import. Explicit
    # imports would be clearer once the contents of `model` are confirmed.
    from model import *

random_image_id = get_random_image_id()

st.sidebar.title("Select a sample image")
sample_image_id = st.sidebar.selectbox(
    "Please choose a sample image",
    sample_image_ids,
)

if st.sidebar.button("Random COCO 2017 (val) images"):
    random_image_id = get_random_image_id()
    # "None" is the sentinel meaning "no sample selected, use the random id".
    sample_image_id = "None"

bytes_data = None
with st.sidebar.form("file-uploader-form", clear_on_submit=True):
    uploaded_file = st.file_uploader("Choose a file")
    submitted = st.form_submit_button("Upload")
    if submitted and uploaded_file is not None:
        bytes_data = io.BytesIO(uploaded_file.getvalue())

if bytes_data is None and submitted:
    st.write("No file is selected to upload")
else:
    image_id = random_image_id
    if sample_image_id != "None":
        # Explicit check instead of `assert`, which is stripped under `-O`.
        if not isinstance(sample_image_id, int):
            raise TypeError(
                f"sample image id must be an int, got {type(sample_image_id).__name__}"
            )
        image_id = sample_image_id

    image_stem = str(image_id).zfill(12)
    sample_name = f"COCO_val2017_{image_stem}.jpg"
    sample_path = os.path.join(sample_dir, sample_name)
    url = f"http://images.cocodataset.org/val2017/{image_stem}.jpg"

    try:
        if bytes_data is not None:
            image = Image.open(bytes_data)
        elif os.path.isfile(sample_path):
            image = Image.open(sample_path)
        else:
            # Bounded wait and explicit status check: a slow or failing server
            # must not hang the app or feed an HTML error page to the decoder.
            response = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content))
    except (OSError, requests.RequestException) as err:
        # Unreadable/unsupported image data surfaces as an OSError subclass.
        st.error(f"Could not load the image: {err}")
        st.stop()

    try:
        # One resize straight to the final preview size. This always yields a
        # separate image object, so closing the preview never touches `image`.
        resized = image.resize(size=_fit_within(*image.size))

        if bytes_data is None:
            st.markdown(f"[{image_stem}.jpg]({url})")
        st.image(resized, caption='\n\nSelected Image')
        resized.close()

        # For newline
        st.sidebar.write('\n')

        with st.spinner('Generating image caption ...'):
            caption = predict(image)
            st.header('Predicted caption:\n\n')
            st.subheader(caption)

        st.sidebar.header("ViT-GPT2 predicts: ")
        st.sidebar.write(f"{caption}")
    finally:
        # Release the image even if rendering or prediction fails.
        image.close()