JaMe76 committed on
Commit
317c295
1 Parent(s): a30356a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +201 -45
app.py CHANGED
@@ -1,25 +1,23 @@
1
  import os
2
  os.system('pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.9/index.html')
3
 
 
4
  import deepdoctection as dd
5
-
6
- from deepdoctection.extern.model import ModelProfile
7
- from deepdoctection.analyzer.dd import build_analyzer, _auto_select_lib_and_device, _maybe_copy_config_to_cache
8
- from deepdoctection.utils.metacfg import set_config_by_yaml
9
- from deepdoctection.dataflow import DataFromList
10
 
11
  import gradio as gr
12
 
13
- _DD_ONE = "deepdoctection/configs/conf_dd_one.yaml"
14
- _TESSERACT = "deepdoctection/configs/conf_tesseract.yaml"
15
 
16
- dd.ModelCatalog.register("layout/model_final_inf_only.pt",ModelProfile(
 
 
 
17
  name="layout/model_final_inf_only.pt",
18
  description="Detectron2 layout detection model trained on private datasets",
19
  config="dd/d2/layout/CASCADE_RCNN_R_50_FPN_GN.yaml",
20
  size=[274632215],
21
  tp_model=False,
22
- hf_repo_id=os.environ.get("HF_REPO"),
23
  hf_model_name="model_final_inf_only.pt",
24
  hf_config_file=["Base-RCNN-FPN.yaml", "CASCADE_RCNN_R_50_FPN_GN.yaml"],
25
  categories={"1": dd.LayoutType.text,
@@ -28,53 +26,211 @@ dd.ModelCatalog.register("layout/model_final_inf_only.pt",ModelProfile(
28
  "4": dd.LayoutType.table,
29
  "5": dd.LayoutType.figure},
30
  ))
31
-
32
- def get_space_dd_analyzer():
33
- # get a dd analyzer with a special layout model
34
- lib, device = _auto_select_lib_and_device()
35
- dd_one_config_path = _maybe_copy_config_to_cache(_DD_ONE)
36
- _maybe_copy_config_to_cache(_TESSERACT)
37
 
38
- # Set up of the configuration and logging
39
- cfg = set_config_by_yaml(dd_one_config_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  cfg.freeze(freezed=False)
42
- cfg.LIB = lib
43
- cfg.DEVICE = device
44
- cfg.TAB = True
45
- cfg.TAB_REF = True
46
- cfg.OCR = True
47
- cfg.LANG = None
48
- cfg.WEIGHTS.D2LAYOUT = "layout/model_final_inf_only.pt"
49
  cfg.freeze()
50
-
51
- return build_analyzer(cfg)
52
-
53
-
54
 
55
- def analyze_image(img):
56
- # creating an image object and passing to the analyzer by using dataflows
57
- image = dd.Image(file_name="input.png", location="")
58
- image.image = img[:,:,::-1]
59
 
60
- df = DataFromList(lst=[image])
 
 
61
 
62
- analyzer = get_space_dd_analyzer()
 
63
 
64
- df = analyzer.analyze(dataset_dataflow=df)
65
- df.reset_state()
66
- dp = next(iter(df))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  out = dp.as_dict()
68
  out.pop("image")
69
-
70
- return dp.viz(show_table_structure=False), out
71
 
72
- inputs = [gr.inputs.Image(type='numpy', label="Original Image")]
73
- outputs = [gr.outputs.Image(type="numpy", label="Output Image"), gr.JSON()]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- title = "Deepdoctection - A Document AI Package"
76
- description = "Demonstration of layout analysis and output of a document page. This demo uses the deepdoctection analyzer with Tesseract's OCR engine. Models detect text, titles, tables, figures and lists as well as table cells. Based on the layout it determines reading order and generates an JSON output."
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- examples = [['sample_1.jpg'],['sample_2.png']]
79
 
80
- gr.Interface(analyze_image, inputs, outputs, title=title, description=description, examples=examples).launch()
 
1
  import os
2
  os.system('pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.9/index.html')
3
 
4
+ from os import getcwd, path, environ
5
  import deepdoctection as dd
6
+ from deepdoctection.dataflow.serialize import DataFromList
 
 
 
 
7
 
8
  import gradio as gr
9
 
 
 
10
 
11
+ _DD_ONE = "conf_dd_one.yaml"
12
+ _DETECTIONS = ["table", "ocr"]
13
+
14
+ dd.ModelCatalog.register("layout/model_final_inf_only.pt",dd.ModelProfile(
15
  name="layout/model_final_inf_only.pt",
16
  description="Detectron2 layout detection model trained on private datasets",
17
  config="dd/d2/layout/CASCADE_RCNN_R_50_FPN_GN.yaml",
18
  size=[274632215],
19
  tp_model=False,
20
+ hf_repo_id=environ.get("HF_REPO"),
21
  hf_model_name="model_final_inf_only.pt",
22
  hf_config_file=["Base-RCNN-FPN.yaml", "CASCADE_RCNN_R_50_FPN_GN.yaml"],
23
  categories={"1": dd.LayoutType.text,
 
26
  "4": dd.LayoutType.table,
27
  "5": dd.LayoutType.figure},
28
  ))
 
 
 
 
 
 
29
 
30
# Set up the configuration. The models are created once at module level so
# that they are not re-loaded every time the input changes.
cfg = dd.set_config_by_yaml(path.join(getcwd(), _DD_ONE))
cfg.freeze(freezed=False)
cfg.DEVICE = "cpu"
cfg.freeze()

# Explicit checks are used instead of ``assert`` because assertions are
# stripped when Python runs with -O, which would defer the failure to the
# detector constructors.

# layout detector
layout_config_path = dd.ModelCatalog.get_full_path_configs(cfg.CONFIG.D2LAYOUT)
layout_weights_path = dd.ModelDownloadManager.maybe_download_weights_and_configs(cfg.WEIGHTS.D2LAYOUT)
categories_layout = dd.ModelCatalog.get_profile(cfg.WEIGHTS.D2LAYOUT).categories
if categories_layout is None:
    raise RuntimeError("No categories registered for the layout model")
if layout_weights_path is None:
    raise RuntimeError("Weights for the layout model could not be resolved")
d_layout = dd.D2FrcnnDetector(layout_config_path, layout_weights_path, categories_layout, device=cfg.DEVICE)

# cell detector
cell_config_path = dd.ModelCatalog.get_full_path_configs(cfg.CONFIG.D2CELL)
cell_weights_path = dd.ModelDownloadManager.maybe_download_weights_and_configs(cfg.WEIGHTS.D2CELL)
categories_cell = dd.ModelCatalog.get_profile(cfg.WEIGHTS.D2CELL).categories
if categories_cell is None:
    raise RuntimeError("No categories registered for the cell model")
if cell_weights_path is None:
    raise RuntimeError("Weights for the cell model could not be resolved")
d_cell = dd.D2FrcnnDetector(cell_config_path, cell_weights_path, categories_cell, device=cfg.DEVICE)

# row/column detector
item_config_path = dd.ModelCatalog.get_full_path_configs(cfg.CONFIG.D2ITEM)
item_weights_path = dd.ModelDownloadManager.maybe_download_weights_and_configs(cfg.WEIGHTS.D2ITEM)
categories_item = dd.ModelCatalog.get_profile(cfg.WEIGHTS.D2ITEM).categories
if categories_item is None:
    raise RuntimeError("No categories registered for the row/column model")
if item_weights_path is None:
    raise RuntimeError("Weights for the row/column model could not be resolved")
d_item = dd.D2FrcnnDetector(item_config_path, item_weights_path, categories_item, device=cfg.DEVICE)

# word detector
det = dd.DoctrTextlineDetector()

# text recognizer
rec = dd.DoctrTextRecognizer()
64
+
65
+
66
def build_gradio_analyzer(table, table_ref, ocr):
    """Build the Detectron2/DocTr pipeline for the requested extractions.

    The pipeline is assembled from the module-level detectors
    (``d_layout``, ``d_cell``, ``d_item``, ``det``, ``rec``); only the
    service wiring is created per call.

    NOTE(review): the nesting below was reconstructed from a flattened diff
    view. Refinement is assumed to belong to the table branch, and word
    matching / text ordering to the OCR branch -- confirm against the
    repository version of this file.

    Args:
        table: If true, add cell and row/column detection and table
            segmentation.
        table_ref: If true (and ``table`` is true), add the table
            segmentation refinement step.
        ocr: If true, add text line detection, text recognition, word
            matching and text ordering.

    Returns:
        A ``dd.DoctectionPipe`` built from the selected components.
    """
    # The shared config is still updated so its state matches the request,
    # exactly as before.
    cfg.freeze(freezed=False)
    cfg.TAB = table
    cfg.TAB_REF = table_ref
    cfg.OCR = ocr
    cfg.freeze()

    # Branch on the arguments rather than on the shared ``cfg``: the config
    # object is module-level state, so reading the flags back from it could
    # pick up values written by another request handled at the same time.
    pipe_component_list = [dd.ImageLayoutService(d_layout, to_image=True, crop_image=True)]

    if table:
        cell = dd.SubImageLayoutService(d_cell, dd.LayoutType.table, {1: 6}, True)
        pipe_component_list.append(cell)

        item = dd.SubImageLayoutService(d_item, dd.LayoutType.table, {1: 7, 2: 8}, True)
        pipe_component_list.append(item)

        # Pick the IoU or IoA thresholds depending on the assignment rule.
        seg = cfg.SEGMENTATION
        use_iou = seg.ASSIGNMENT_RULE in ["iou"]
        threshold_rows = seg.IOU_THRESHOLD_ROWS if use_iou else seg.IOA_THRESHOLD_ROWS
        threshold_cols = seg.IOU_THRESHOLD_COLS if use_iou else seg.IOA_THRESHOLD_COLS
        table_segmentation = dd.TableSegmentationService(
            seg.ASSIGNMENT_RULE,
            threshold_rows,
            threshold_cols,
            seg.FULL_TABLE_TILING,
            seg.REMOVE_IOU_THRESHOLD_ROWS,
            seg.REMOVE_IOU_THRESHOLD_COLS,
        )
        pipe_component_list.append(table_segmentation)

        if table_ref:
            pipe_component_list.append(dd.TableSegmentationRefinementService())

    if ocr:
        d_layout_text = dd.ImageLayoutService(det, to_image=True, crop_image=True)
        pipe_component_list.append(d_layout_text)

        d_text = dd.TextExtractionService(rec, extract_from_roi="WORD")
        pipe_component_list.append(d_text)

        matching = cfg.WORD_MATCHING
        match = dd.MatchingService(
            parent_categories=matching.PARENTAL_CATEGORIES,
            child_categories=dd.LayoutType.word,
            matching_rule=matching.RULE,
            threshold=matching.IOU_THRESHOLD if matching.RULE in ["iou"] else matching.IOA_THRESHOLD,
        )
        pipe_component_list.append(match)

        order = dd.TextOrderService(
            text_container=dd.LayoutType.word,
            floating_text_block_names=[dd.LayoutType.title, dd.LayoutType.text, dd.LayoutType.list],
            text_block_names=[
                dd.LayoutType.title,
                dd.LayoutType.text,
                dd.LayoutType.list,
                dd.LayoutType.cell,
                dd.CellType.header,
                dd.CellType.body,
            ],
        )
        pipe_component_list.append(order)

    return dd.DoctectionPipe(pipeline_component_list=pipe_component_list)
137
+
138
+
139
def prepare_output(dp, add_table, add_ocr):
    """Turn an analyzed page into the four values shown in the UI.

    Args:
        dp: The analyzed page. It must provide ``as_dict()``, ``items``,
            ``tables`` and ``viz(show_table_structure=...)``.
        add_table: Whether table extraction was requested; if not, no HTML
            is returned.
        add_ocr: Whether OCR was requested; only then are the layout items
            ordered by their ``reading_order``.

    Returns:
        A tuple ``(visualization, text, html, page_dict)`` where ``text``
        has one ``"\\n <layout_type>: <text>"`` entry per layout item,
        ``html`` is the HTML of the first table (or ``None``) and
        ``page_dict`` is ``dp.as_dict()`` without its ``"image"`` entry.
    """
    out = dp.as_dict()
    # The image payload is not wanted in the JSON view; tolerate its absence.
    out.pop("image", None)

    layout_items = dp.items
    if add_ocr:
        # Sort a copy so the page's own item list is left untouched.
        layout_items = sorted(layout_items, key=lambda x: x.reading_order)
    # Built unconditionally so the value always exists for the return below.
    layout_items_str = "".join(f"\n {item.layout_type}: {item.text}" for item in layout_items)

    html = None
    if add_table:
        # Only the first detected table is displayed.
        html = next((table.html for table in dp.tables), None)

    return dp.viz(show_table_structure=False), layout_items_str, html, out
159
+
160
+
161
def analyze_image(img, pdf, attributes):
    """Run the analyzer on an uploaded image or PDF and prepare the UI output.

    Args:
        img: Image as a numpy array, or ``None`` if no image was uploaded.
            When given, it takes precedence over ``pdf``.
        pdf: Uploaded file object exposing ``.name`` (the path handed to the
            analyzer), or a falsy value if no PDF was uploaded.
        attributes: Selected entries of ``_DETECTIONS``; ``None`` is treated
            as an empty selection.

    Returns:
        The tuple produced by ``prepare_output`` for the first page.

    Raises:
        ValueError: If neither an image nor a PDF was provided, or if the
            analyzer produced no page.
    """
    attributes = attributes or []
    add_table = _DETECTIONS[0] in attributes
    add_ocr = _DETECTIONS[1] in attributes

    # Table refinement is switched on together with table detection.
    analyzer = build_gradio_analyzer(add_table, add_table, add_ocr)

    if img is not None:
        # Wrap the array in an image object and feed it through a dataflow.
        image = dd.Image(file_name="input.png", location="")
        # Reverse the last axis -- presumably RGB -> BGR; confirm against
        # what dd.Image expects.
        image.image = img[:, :, ::-1]
        df = analyzer.analyze(dataset_dataflow=DataFromList(lst=[image]))
    elif pdf:
        df = analyzer.analyze(path=pdf.name, max_datapoints=3)
    else:
        raise ValueError("No input provided: upload an image or a PDF file.")

    df.reset_state()
    # Only the first page is shown; report an empty result explicitly
    # instead of leaking StopIteration to the caller.
    dp = next(iter(df), None)
    if dp is None:
        raise ValueError("The analyzer did not return any page for the given input.")

    return prepare_output(dp, add_table, add_ocr)
186
+
187
+
188
+ demo = gr.Blocks(css="scrollbar.css")
189
+
190
+ with demo:
191
+ with gr.Box():
192
+ gr.Markdown("<h1><center>deepdoctection - A Document AI Package</center></h1>")
193
+ gr.Markdown("<strong>deep</strong>doctection is a Python library that orchestrates document extraction"
194
+ " and document layout analysis tasks using deep learning models. It does not implement models"
195
+ " but enables you to build pipelines using highly acknowledged libraries for object detection,"
196
+ " OCR and selected NLP tasks and provides an integrated frameworks for fine-tuning, evaluating"
197
+ " and running models.\n This pipeline consists of a stack of models powered by <strong>Detectron2"
198
+ "</strong> for layout analysis and table recognition and <strong>DocTr</strong> for OCR.")
199
+ with gr.Box():
200
+ gr.Markdown("<h2><center>Upload a document and choose setting</center></h2>")
201
+ with gr.Row():
202
+ with gr.Column():
203
+ with gr.Tab("Image upload"):
204
+ with gr.Column():
205
+ inputs = gr.Image(type='numpy', label="Original Image")
206
+ with gr.Tab("PDF upload (only first image will be processed)"):
207
+ with gr.Column():
208
+ inputs_pdf = gr.File(label="PDF")
209
+ with gr.Column():
210
+ gr.Examples(
211
+ examples=[path.join(getcwd(), "sample_1.jpg"), path.join(getcwd(), "sample_2.png")],
212
+ inputs = inputs)
213
+
214
+ with gr.Row():
215
+ tok_input = gr.CheckboxGroup(
216
+ _DETECTIONS, value=_DETECTIONS, label="Additional extractions", interactive=True)
217
+ with gr.Row():
218
+ btn = gr.Button("Run model", variant="primary")
219
 
220
+ with gr.Box():
221
+ with gr.Row():
222
+ with gr.Column():
223
+ gr.Markdown("<h2><center>Text output</center></h2>")
224
+ gr.Markdown("Will only show contiguous text from text blocks, titles and lists")
225
+ image_text = gr.Textbox()
226
+ gr.Markdown("<h2><center>First table</center></h2>")
227
+ html = gr.HTML()
228
+ gr.Markdown("<h2><center>JSON output</center></h2>")
229
+ json = gr.JSON()
230
+ with gr.Column():
231
+ gr.Markdown("<h2><center>Layout detection</center></h2>")
232
+ image_output = gr.Image(type="numpy", label="Output Image")
233
 
234
+ btn.click(fn=analyze_image, inputs=[inputs, inputs_pdf, tok_input], outputs=[image_output, image_text, html, json])
235
 
236
+ demo.launch()