deepdoctection / app.py
JaMe76's picture
update space
86bfe12
raw
history blame
No virus
12.7 kB
import os
import time
import importlib.metadata
from os import getcwd, path, environ
import deepdoctection as dd
from deepdoctection.dataflow.serialize import DataFromList
from dd_addons.extern import PdfTextDetector, PostProcessor, get_xsl_path
from dd_addons.pipe.conn import PostProcessorService
import gradio as gr
from botocore.config import Config
from dotenv import load_dotenv
load_dotenv()


def check_additional_requirements():
    """Install runtime dependencies that are not pinned through the Space's requirements.

    - detectron2 is installed from the prebuilt cu102 / torch1.9 wheel index when
      ``dd.detectron2_available()`` reports it as missing.
    - gradio is pinned to 3.4.1: an installed different version is uninstalled first.
    - the shell command stored in the ``DD_ADDONS`` environment variable is executed;
      presumably it installs the private ``dd_addons`` package — TODO confirm.

    Failing install commands are not raised (same as the previous ``os.system`` calls);
    a missing ``DD_ADDONS`` variable raises ``KeyError``.

    NOTE(review): this runs after ``deepdoctection``, ``dd_addons`` and ``gradio`` have
    already been imported at the top of the file, so anything (re)installed here is not
    seen by those already-imported modules until the process restarts.
    """
    import importlib.util  # ``import importlib.metadata`` does not guarantee ``importlib.util`` is loaded
    import subprocess
    import sys

    # Run pip through the current interpreter so packages land in this environment.
    pip = [sys.executable, "-m", "pip"]
    if not dd.detectron2_available():
        subprocess.run(
            pip
            + [
                "install",
                "detectron2",
                "-f",
                "https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.9/index.html",
            ],
            check=False,
        )
    # work around: https://discuss.huggingface.co/t/how-to-install-a-specific-version-of-gradio-in-spaces/13552
    if importlib.util.find_spec("gradio") is not None:
        if importlib.metadata.version("gradio") != "3.4.1":
            subprocess.run(pip + ["uninstall", "-y", "gradio"], check=False)
            subprocess.run(pip + ["install", "gradio==3.4.1"], check=False)
    else:
        subprocess.run(pip + ["install", "gradio==3.4.1"], check=False)
    # DD_ADDONS holds a complete shell command, so it is deliberately run through the shell.
    os.system(os.environ["DD_ADDONS"])


check_additional_requirements()
_DD_ONE = "conf_dd_one.yaml"
_XSL_PATH = get_xsl_path()


def _register_xrf_model(name, description, config, size, hf_repo_env, categories):
    """Add one Cascade-RCNN Detectron2 model to the deepdoctection model catalog.

    The three models registered below share the weight file name, the two Detectron2
    config files and the ``D2FrcnnDetector`` wrapper; they differ only in catalog name,
    description, config path, file size, Hugging Face repository and categories.

    :param name: catalog key, also used as the profile name
    :param description: free-text description stored in the profile
    :param config: path of the model config
    :param size: weight file size, stored as a one-element list
    :param hf_repo_env: name of the environment variable holding the Hugging Face repo id;
        if the variable is unset, ``None`` is registered (no error is raised here)
    :param categories: mapping of detector category id (str) to ``dd.LayoutType``
    """
    profile = dd.ModelProfile(
        name=name,
        description=description,
        config=config,
        size=[size],
        tp_model=False,
        hf_repo_id=environ.get(hf_repo_env),
        hf_model_name="model_final_inf_only.pt",
        hf_config_file=["Base-RCNN-FPN.yaml", "CASCADE_RCNN_R_50_FPN_GN.yaml"],
        categories=categories,
        model_wrapper="D2FrcnnDetector",
    )
    dd.ModelCatalog.register(name, profile)


# layout regions: text, title, list, table, figure
_register_xrf_model(
    "xrf_layout/model_final_inf_only.pt",
    "layout_detection/morning-dragon-114",
    "xrf_dd/layout/CASCADE_RCNN_R_50_FPN_GN.yaml",
    274632215,
    "HF_REPO_LAYOUT",
    {
        "1": dd.LayoutType.text,
        "2": dd.LayoutType.title,
        "3": dd.LayoutType.list,
        "4": dd.LayoutType.table,
        "5": dd.LayoutType.figure,
    },
)
# table cells
_register_xrf_model(
    "xrf_cell/model_final_inf_only.pt",
    "cell_detection/restful-eon-6",
    "xrf_dd/cell/CASCADE_RCNN_R_50_FPN_GN.yaml",
    274583063,
    "HF_REPO_CELL",
    {"1": dd.LayoutType.cell},
)
# table rows and columns
_register_xrf_model(
    "xrf_item/model_final_inf_only.pt",
    "item_detection/firm_plasma_14",
    "xrf_dd/item/CASCADE_RCNN_R_50_FPN_GN.yaml",
    274595351,
    "HF_REPO_ITEM",
    {"1": dd.LayoutType.row, "2": dd.LayoutType.column},
)
# Set up the configuration. Models are defined globally so that they are not re-loaded every
# time the input updates.
cfg = dd.set_config_by_yaml(path.join(getcwd(), _DD_ONE))
cfg.freeze(freezed=False)
cfg.DEVICE = "cpu"
cfg.freeze()


def _load_d2_detector(config_name, weights_name):
    """Resolve a registered Detectron2 model and build a ``dd.D2FrcnnDetector`` on ``cfg.DEVICE``.

    :param config_name: catalog name of the model config (e.g. ``cfg.CONFIG.D2LAYOUT``)
    :param weights_name: catalog name of the weights (e.g. ``cfg.WEIGHTS.D2LAYOUT``)
    :return: tuple ``(config_path, weights_path, categories, detector)``
    :raises ValueError: if the model profile has no categories or no weights path was obtained
    """
    config_path = dd.ModelCatalog.get_full_path_configs(config_name)
    weights_path = dd.ModelDownloadManager.maybe_download_weights_and_configs(weights_name)
    categories = dd.ModelCatalog.get_profile(weights_name).categories
    # Explicit checks instead of ``assert`` so they still run under ``python -O``; applied to
    # every detector, not only the layout one.
    if categories is None:
        raise ValueError(f"Model profile {weights_name} has no categories")
    if weights_path is None:
        raise ValueError(f"Could not obtain weights for model {weights_name}")
    detector = dd.D2FrcnnDetector(config_path, weights_path, categories, device=cfg.DEVICE)
    return config_path, weights_path, categories, detector


# layout detector
layout_config_path, layout_weights_path, categories_layout, d_layout = _load_d2_detector(
    cfg.CONFIG.D2LAYOUT, cfg.WEIGHTS.D2LAYOUT
)
# cell detector
cell_config_path, cell_weights_path, categories_cell, d_cell = _load_d2_detector(
    cfg.CONFIG.D2CELL, cfg.WEIGHTS.D2CELL
)
# row/column detector
item_config_path, item_weights_path, categories_item, d_item = _load_d2_detector(
    cfg.CONFIG.D2ITEM, cfg.WEIGHTS.D2ITEM
)
# pdf miner
pdf_text = PdfTextDetector(_XSL_PATH)
# text detector: AWS Textract, credentials and region taken from ACCESS_KEY / SECRET_KEY / REGION
credentials_kwargs = {
    "aws_access_key_id": os.environ["ACCESS_KEY"],
    "aws_secret_access_key": os.environ["SECRET_KEY"],
    "config": Config(region_name=os.environ["REGION"]),
}
tex_text = dd.TextractOcrDetector(**credentials_kwargs)
def build_gradio_analyzer():
    """Build the deepdoctection pipeline used by the Gradio demo.

    Components, in order:
      1. layout detection (Detectron2 ``d_layout``) followed by NMS over the layout annotations,
      2. cell and row/column detection on ``dd.LayoutType.table`` sub-images, table segmentation
         and table segmentation refinement,
      3. text extraction: first ``pdf_text``, then AWS Textract (``tex_text``) with
         ``skip_if_text_extracted=True``,
      4. matching of words to layout blocks and reading-order computation,
      5. the ``dd_addons`` post-processing service.

    Side effect: switches ``TAB``, ``TAB_REF`` and ``OCR`` on in the global ``cfg``.

    :return: the assembled ``dd.DoctectionPipe``
    """
    cfg.freeze(freezed=False)
    cfg.TAB = True
    cfg.TAB_REF = True
    cfg.OCR = True
    cfg.freeze()

    pipe_component_list = []
    layout = dd.ImageLayoutService(d_layout, to_image=True, crop_image=True)
    pipe_component_list.append(layout)
    nms_service = dd.AnnotationNmsService(
        nms_pairs=cfg.LAYOUT_NMS_PAIRS.COMBINATIONS,
        thresholds=cfg.LAYOUT_NMS_PAIRS.THRESHOLDS,
    )
    pipe_component_list.append(nms_service)

    if cfg.TAB:
        # The dicts map the detectors' category ids to the pipeline's category ids
        # (presumably 6 = cell, 7 = row, 8 = column) — TODO confirm against the config.
        detect_result_generator = dd.DetectResultGenerator(categories_cell)
        cell = dd.SubImageLayoutService(d_cell, dd.LayoutType.table, {1: 6}, detect_result_generator)
        pipe_component_list.append(cell)

        detect_result_generator = dd.DetectResultGenerator(categories_item)
        item = dd.SubImageLayoutService(d_item, dd.LayoutType.table, {1: 7, 2: 8}, detect_result_generator)
        pipe_component_list.append(item)

        table_segmentation = dd.TableSegmentationService(
            cfg.SEGMENTATION.ASSIGNMENT_RULE,
            cfg.SEGMENTATION.THRESHOLD_ROWS,
            cfg.SEGMENTATION.THRESHOLD_COLS,
            cfg.SEGMENTATION.FULL_TABLE_TILING,
            cfg.SEGMENTATION.REMOVE_IOU_THRESHOLD_ROWS,
            cfg.SEGMENTATION.REMOVE_IOU_THRESHOLD_COLS,
            cfg.SEGMENTATION.STRETCH_RULE,
        )
        pipe_component_list.append(table_segmentation)

        if cfg.TAB_REF:
            table_segmentation_refinement = dd.TableSegmentationRefinementService()
            pipe_component_list.append(table_segmentation_refinement)

    if cfg.OCR:
        d_text = dd.TextExtractionService(pdf_text)
        pipe_component_list.append(d_text)
        t_text = dd.TextExtractionService(tex_text, skip_if_text_extracted=True)
        pipe_component_list.append(t_text)

        match_words = dd.MatchingService(
            parent_categories=cfg.WORD_MATCHING.PARENTAL_CATEGORIES,
            child_categories=cfg.WORD_MATCHING.CHILD_CATEGORIES,
            matching_rule=cfg.WORD_MATCHING.RULE,
            threshold=cfg.WORD_MATCHING.THRESHOLD,
            max_parent_only=cfg.WORD_MATCHING.MAX_PARENT_ONLY,
        )
        pipe_component_list.append(match_words)

    order = dd.TextOrderService(
        text_container=cfg.TEXT_ORDERING.TEXT_CONTAINER,
        floating_text_block_categories=cfg.TEXT_ORDERING.FLOATING_TEXT_BLOCK,
        text_block_categories=cfg.TEXT_ORDERING.TEXT_BLOCK,
        include_residual_text_container=cfg.TEXT_ORDERING.TEXT_CONTAINER_TO_TEXT_BLOCK,
    )
    pipe_component_list.append(order)

    # Register every component, including post-processing, BEFORE the pipe is constructed, so
    # the pipe does not depend on keeping a reference to (rather than a copy of) this list.
    post_processor = PostProcessor("deepdoctection", **credentials_kwargs)
    post_service = PostProcessorService(post_processor)
    pipe_component_list.append(post_service)

    return dd.DoctectionPipe(pipeline_component_list=pipe_component_list)
def _page_text(dp):
    """Render one page's layout blocks as text, in reading order, after a page-number header.

    Layout blocks whose ``reading_order`` is ``None`` are skipped.
    """
    layout_items = sorted(
        (layout for layout in dp.layouts if layout.reading_order is not None),
        key=lambda item: item.reading_order,
    )
    parts = [f"\n\n -------- PAGE NUMBER: {dp.page_number + 1} ------------- \n"]
    parts.extend(f"\n {item.category_name}: {item.text}" for item in layout_items)
    return "".join(parts)


def analyze_image(img, pdf, max_datapoints):
    """Run the document pipeline on an uploaded image or PDF and build the Gradio outputs.

    :param img: array from the ``gr.Image(type='numpy')`` input, or ``None``. It takes precedence
        over ``pdf``. Its last axis is reversed before processing (``img[:, :, ::-1]``),
        presumably RGB -> BGR.
    :param pdf: uploaded file object from ``gr.File`` (``.name`` is used as path), or ``None``
    :param max_datapoints: maximum number of PDF pages to process (unused for images)
    :return: ``[page visualisations, contiguous text, tables HTML or None, list of page dicts]``
    :raises ValueError: if neither an image nor a PDF was provided
    """
    # Validate first so a missing upload does not pay for building the pipeline.
    if img is None and not pdf:
        raise ValueError("Please upload an image or a PDF document.")

    analyzer = build_gradio_analyzer()
    if img is not None:
        image = dd.Image(file_name=str(time.time()).replace(".", "") + ".png", location="")
        image.image = img[:, :, ::-1]
        df = analyzer.analyze(dataset_dataflow=DataFromList(lst=[image]))
    else:
        # Slider values may be delivered as floats; the page limit has to be an int.
        page_limit = int(max_datapoints) if max_datapoints is not None else None
        df = analyzer.analyze(path=pdf.name, max_datapoints=page_limit)
    df.reset_state()

    text_parts = []
    jsonl_out = []
    dpts = []
    html_list = []
    for dp in df:
        dpts.append(dp)
        out = dp.as_dict()
        # Drop the raw image from the JSON output; tolerate its absence.
        out.pop("_image", None)
        jsonl_out.append(out)
        text_parts.append(_page_text(dp))
        html_list.extend(table.html for table in dp.tables)

    html = "<br /><br /><br />".join(html_list) if html_list else None
    return [dp.viz(show_cells=False) for dp in dpts], "".join(text_parts), html, jsonl_out
# Upper bound of the page slider; the UI texts below quote it so they cannot drift apart.
_MAX_PDF_PAGES = 8

demo = gr.Blocks(css="scrollbar.css")

with demo:
    with gr.Box():
        gr.Markdown("<h1><center>deepdoctection - A Document AI Package</center></h1>")
        gr.Markdown(
            "<strong>deep</strong>doctection is a Python library that orchestrates document extraction"
            " and document layout analysis tasks using deep learning models. It does not implement models"
            " but enables you to build pipelines using highly acknowledged libraries for object detection,"
            " OCR and selected NLP tasks and provides an integrated framework for fine-tuning, evaluating"
            " and running models.<br />"
            "This pipeline consists of a stack of models powered by <strong>Detectron2"
            "</strong> for layout analysis and table recognition. OCR will be provided as well. You can process"
            f" an image or even a PDF-document. Up to {_MAX_PDF_PAGES} pages can be processed. <br />"
        )
        gr.Markdown("[https://github.com/deepdoctection/deepdoctection](https://github.com/deepdoctection/deepdoctection)")

    with gr.Box():
        gr.Markdown("<h2><center>Upload a document and choose settings</center></h2>")
        with gr.Row():
            with gr.Column():
                with gr.Tab("Image upload"):
                    with gr.Column():
                        inputs = gr.Image(type='numpy', label="Original Image")
                with gr.Tab("PDF upload *"):
                    with gr.Column():
                        inputs_pdf = gr.File(label="PDF")
                        # analyze_image processes the image whenever one is present, ignoring the PDF.
                        gr.Markdown("<sup>* If an image is cached in the image tab, remove it first</sup>")
            with gr.Column():
                gr.Examples(
                    examples=[path.join(getcwd(), "sample_1.jpg"), path.join(getcwd(), "sample_2.png")],
                    inputs=inputs,
                )
                gr.Examples(examples=[path.join(getcwd(), "sample_3.pdf")], inputs=inputs_pdf)
        with gr.Row():
            max_imgs = gr.Slider(
                1,
                _MAX_PDF_PAGES,
                value=2,
                step=1,
                label="Number of pages in multi page PDF",
                info=f"Will stop after {_MAX_PDF_PAGES} pages",
            )
        with gr.Row():
            btn = gr.Button("Run model", variant="primary")

    with gr.Box():
        gr.Markdown("<h2><center>Outputs</center></h2>")
        with gr.Row():
            with gr.Column():
                with gr.Box():
                    gr.Markdown("<center><strong>Contiguous text</strong></center>")
                    image_text = gr.Textbox()
            with gr.Column():
                with gr.Box():
                    gr.Markdown("<center><strong>Layout detection</strong></center>")
                    gallery = gr.Gallery(
                        label="Output images", show_label=False, elem_id="gallery"
                    ).style(grid=2)
        with gr.Row():
            with gr.Box():
                gr.Markdown("<center><strong>Table</strong></center>")
                # named *_output so the stdlib module names html/json are not shadowed
                html_output = gr.HTML()
        with gr.Row():
            with gr.Box():
                gr.Markdown("<center><strong>JSON</strong></center>")
                json_output = gr.JSON()

    btn.click(
        fn=analyze_image,
        inputs=[inputs, inputs_pdf, max_imgs],
        outputs=[gallery, image_text, html_output, json_output],
    )

demo.launch()