merve HF staff commited on
Commit
7e02e28
β€’
1 Parent(s): a56a44f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -0
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gradio_client import Client
3
+ import os
4
+ import json
5
+
6
+
7
+ def generate_answer_pix2struct_base(image_path, question):
8
+ try:
9
+ client = Client("https://merve-pix2struct.hf.space/")
10
+ return client.predict(
11
+ image_path,
12
+ question,
13
+ fn_index=1
14
+ )
15
+ except Exception:
16
+ gr.Warning("The Pix2Struct Large Space is currently unavailable. Please try again later.")
17
+ return ""
18
+
19
+
20
+ def generate_answer(image_path, question, model_name, space_id):
21
+ try:
22
+ client = Client(f"https://{model_name}.hf.space/")
23
+ result = client.predict(image_path, question, api_name="/predict")
24
+ if result.endswith(".json"):
25
+ with open(result, "rb") as json_file:
26
+ output = json.loads(json_file.read())
27
+ if model_name == "TusharGoel-LayoutLM-DocVQA":
28
+ return output["label"]
29
+ else:
30
+ return output["answer"]
31
+ else:
32
+ return result
33
+ except Exception:
34
+ gr.Warning(f"The {model_name} Space is currently unavailable. Please try again later.")
35
+ return ""
36
+
37
+
38
+ def generate_answers(image_path, question):
39
+ answer_p2s_base = generate_answer_pix2struct_base(image_path, question)
40
+
41
+ answer_p2s_large = generate_answer(image_path, question, model_name = "akdeniz27-pix2struct-DocVQA", space_id = "Pix2Struct Large")
42
+
43
+ answer_layoutlm = generate_answer(image_path, question, model_name = "TusharGoel-LayoutLM-DocVQA", space_id = "LayoutLM DocVQA")
44
+
45
+ answer_donut = generate_answer(image_path, question, model_name = "nielsr-donut-docvqa", space_id = "Donut DocVQA")
46
+
47
+ return answer_p2s_base, answer_p2s_large, answer_layoutlm, answer_donut
48
+
49
+ examples = [["docvqa_example.png", "How many items are sold?"], ["document-question-answering-input.png", "What is the objective?"]]
50
+
51
+ title = "# Interactive demo: comparing document question answering (VQA) models"
52
+
53
+ css = """
54
+ #mkd {
55
+ height: 500px;
56
+ overflow: auto;
57
+ border: 1px solid #ccc;
58
+ }
59
+ """
60
+
61
+ with gr.Blocks(css=css) as demo:
62
+ gr.HTML("<h1><center>Compare Document Question Answering Models πŸ“„<center><h1>")
63
+ gr.HTML("<h3><center>Document question answering is the task of answering questions from documents in visual form. πŸ“”πŸ“•</h3>")
64
+ gr.HTML("<h3><center>To try this Space, simply upload documents and questions. </h3>")
65
+ gr.HTML("<h3><center>If prompted to wait and try again, please try again. This Space uses other Spaces as APIs, so it might take time to get those Spaces up and running if they're stopped. </h3>")
66
+
67
+ with gr.Row():
68
+ with gr.Column():
69
+ input_image = gr.Image(label = "Input Document", type="filepath")
70
+ question = gr.Textbox(label = "question")
71
+ run_button = gr.Button("Answer")
72
+ with gr.Column():
73
+ out_p2s_base = gr.Textbox(label="Answer generated by Pix2Struct Base")
74
+ out_p2s_large = gr.Textbox(label="Answer generated by Pix2Struct Large")
75
+ out_layoutlm = gr.Textbox(label="Answer generated by LayoutLM")
76
+ out_donut = gr.Textbox(label="Answer generated by Donut")
77
+
78
+
79
+ outputs = [
80
+ out_p2s_base,
81
+ out_p2s_large,
82
+ out_layoutlm,
83
+ out_donut,
84
+ ]
85
+
86
+ gr.Examples(
87
+ examples = [["docvqa_example.png", "How many items are sold?"],
88
+ ["document-question-answering-input.png", "What is the objective?"]],
89
+ inputs=[input_image, question],
90
+ outputs=outputs,
91
+ fn=generate_answers,
92
+ cache_examples=True
93
+ )
94
+
95
+
96
+
97
+ run_button.click(
98
+ fn=generate_answers,
99
+ inputs=[input_image,question],
100
+ outputs=outputs
101
+ )
102
+
103
+ if __name__ == "__main__":
104
+ demo.queue().launch(debug=True)