Spaces:
Runtime error
Runtime error
#-*- coding: UTF-8 -*- | |
# Copyright 2022 the HuggingFace Team. | |
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import traceback | |
import gradio as gr | |
from paddlenlp import Taskflow | |
from paddlenlp.utils.doc_parser import DocParser | |
# Shared document helper: used below to load images for preview
# (read_image) and to draw extraction results onto them
# (write_image_with_results).
doc_parser = DocParser()

# One UIE-X information-extraction taskflow, loaded from the Hugging Face Hub.
# It is created with an empty schema because run_taskflow() installs the
# user's schema before every prediction.
task_instance = Taskflow(
    "information_extraction",
    schema="",
    model="uie-x-base",
    task_path="PaddlePaddle/uie-x-base",
    from_hf_hub=True,
)
# Clickable examples shown under the upload box; each row is
# [image file, schema string] and feeds (example_image, example_schema).
examples = [
    [
        "business_card.png",
        "Name;Title;Web Link;Email;Address",
    ],
    [
        "license.jpeg",
        "Name;DOB;ISS;EXP",
    ],
    [
        "invoice.jpeg",
        "名称;纳税人识别号;开票日期",
    ],
    [
        "custom.jpeg",
        "收发货人;进口口岸;进口日期;运输方式;征免性质;境内目的地;运输工具名称;包装种类;件数;合同协议号"
    ],
    [
        "resume.png",
        "职位;年龄;学校|时间;学校|专业",
    ],
]

# Reverse lookup used when an example is clicked: schema string -> image file.
# Derived from ``examples`` so the two tables can never drift out of sync.
example_files = {schema: image for image, schema in examples}

# OCR language for each example image. load_example_document() splits each
# value on "-" and enables layout analysis when a suffix is present; none of
# the current entries carry one.
lang_map = {
    "resume.png": "ch",
    "custom.jpeg": "ch",
    "business_card.png": "en",
    "invoice.jpeg": "ch",
    "license.jpeg": "en",
}
def dbc2sbc(s):
    """Convert full-width (DBC) characters in ``s`` to half-width (SBC) ones.

    Full-width ASCII variants U+FF01..U+FF5E are mapped to U+0021..U+007E
    (so "；" becomes ";" and "｜" becomes "|"). The ideographic space U+3000
    is mapped to an ASCII space. Every other character is returned unchanged.

    Args:
        s: input text, e.g. a schema typed with a Chinese input method.

    Returns:
        The converted string.
    """
    out = []
    for char in s:
        code = ord(char)
        if code == 0x3000:
            # Handled explicitly: 0x3000 - 0xFEE0 is not a printable ASCII
            # code, so the generic offset below cannot cover it.
            out.append(" ")
        elif 0xFF01 <= code <= 0xFF5E:
            out.append(chr(code - 0xFEE0))
        else:
            out.append(char)
    return "".join(out)
def process_path(path):
    """Load the document at ``path`` and build the matching UI updates.

    The URL "Get" button and the upload handler both wire this result to
    five components, so exactly five values are returned, in order:
    (document state, preview gallery, image-clear button, JSON output,
    URL error box).

    Args:
        path: location handed to ``doc_parser.read_image``; falsy means
            "nothing to load".

    Returns:
        On success, the path is stored as the document and the preview is
        shown. Otherwise everything is hidden/cleared. If loading raised, the
        error text is shown in the URL error box.
    """
    error = None
    if path:
        try:
            images_list = [doc_parser.read_image(path)]
        except Exception as e:  # UI boundary: report to the user, don't crash
            traceback.print_exc()
            error = str(e)
        else:
            return (
                path,
                gr.update(visible=True, value=images_list),
                gr.update(visible=True),
                gr.update(visible=False, value=None),
                gr.update(visible=False, value=None),
            )
    return (
        None,
        gr.update(visible=False, value=None),
        gr.update(visible=False),
        gr.update(visible=False, value=None),
        # Previously the error was a 6th value past the wired outputs and
        # was discarded; it now lands in the URL error box itself.
        gr.update(visible=True, value=error)
        if error is not None
        else gr.update(visible=False, value=None),
    )
def process_upload(file):
    """Handle a change of the upload box.

    Args:
        file: value of the ``gr.File`` component; its ``.name`` attribute is
            used as the document path. Falsy when the box was cleared.

    Returns:
        One value per wired output component: (document state, preview
        gallery, image-clear button, JSON output, URL error box). When a file
        is present, the work is delegated to ``process_path``.
    """
    if file:
        return process_path(file.name)
    # Upload cleared: reset state and hide every dependent component. Exactly
    # five values, matching the five outputs of upload.change.
    return (
        None,
        gr.update(visible=False, value=None),
        gr.update(visible=False),
        gr.update(visible=False, value=None),
        gr.update(visible=False, value=None),
    )
def get_schema(schema_str):
    """Parse a schema string from the UI into a UIE schema and its language.

    Entries are separated by ";". A plain entry such as "Name" becomes an
    entity type. An entry "Subject|Relation" (further relations may follow,
    each separated by "|") adds relation types under that subject. Repeated
    subjects are merged into one dict, e.g.
    "学校|时间;学校|专业" -> [{"学校": ["时间", "专业"]}].

    Whitespace around names is stripped. Empty entries (from a trailing or
    doubled ";", or a dangling "|") are dropped, so they never reach the model
    as empty prompts.

    Args:
        schema_str: schema text using half-width ";" and "|" separators.

    Returns:
        ``(schema_list, schema_lang)``. ``schema_lang`` is "ch" if the text
        contains any character in U+4E00..U+9FFF, otherwise "en".
    """
    schema_lang = (
        "ch" if any("\u4e00" <= ch <= "\u9fff" for ch in schema_str) else "en"
    )
    schema_list = []
    # subject -> the relation list stored inside its dict in schema_list
    relations_by_subject = {}
    for entry in schema_str.split(";"):
        subject, *relations = (part.strip() for part in entry.split("|"))
        relations = [r for r in relations if r]
        if not subject:
            continue
        if not relations:
            schema_list.append(subject)
        elif subject in relations_by_subject:
            relations_by_subject[subject].extend(relations)
        else:
            relations_by_subject[subject] = relations
            schema_list.append({subject: relations})
    return schema_list, schema_lang
def run_taskflow(document, schema, argument):
    """Reconfigure the shared UIE-X taskflow, then run it on one document.

    Args:
        document: the document passed to the taskflow under the 'doc' key.
        schema: parsed schema list (see get_schema).
        argument: dict of runtime options forwarded to ``set_argument``.

    Returns:
        Whatever the taskflow returns; callers index ``[0]``, presumably one
        prediction per input — TODO confirm against PaddleNLP.

    NOTE(review): ``task_instance`` is a module-level singleton whose schema
    and arguments are mutated on every call. Concurrent requests could
    interleave these calls; confirm how Gradio schedules handlers here.
    """
    model = task_instance
    model.set_schema(schema)
    model.set_argument(argument)
    inputs = {'doc': document}
    return model(inputs)
def process_doc(document, schema, ocr_lang, layout_analysis):
    """Run UIE-X extraction on ``document`` and build the result UI updates.

    Args:
        document: document path held in the ``document`` state, or None.
        schema: raw schema text from the UI. It may use full-width separators
            and defaults to '时间;组织机构;人物' when empty.
        ocr_lang: OCR language, "ch" or "en".
        layout_analysis: either a bool (from load_example_document) or the
            "yes"/"no" string produced by the layout-analysis Radio.

    Returns:
        ``(gallery update, JSON update)`` showing the annotated image and the
        raw prediction, or ``(None, None)`` when no document is loaded.
    """
    if document is None:
        return None, None
    if not schema:
        schema = '时间;组织机构;人物'
    # The Radio yields "yes"/"no". Both are non-empty (truthy) strings, so the
    # raw value would switch layout analysis on even when the user chose "no".
    if isinstance(layout_analysis, str):
        layout_analysis = layout_analysis == "yes"
    schema, schema_lang = get_schema(dbc2sbc(schema))
    argument = {
        "ocr_lang": ocr_lang,
        "schema_lang": schema_lang,
        "layout_analysis": layout_analysis,
    }
    prediction = run_taskflow(document, schema, argument)[0]
    img_show = doc_parser.write_image_with_results(
        document,
        result=prediction,
        return_image=True,
    )
    return (
        gr.update(visible=True, value=[img_show]),
        gr.update(visible=True, value=prediction),
    )
def load_example_document(img, schema, ocr_lang, layout_analysis):
    """Load a clicked example and run extraction on it immediately.

    The example's image file is looked up from its schema string, and the OCR
    language / layout-analysis flag come from ``lang_map``. The
    ``ocr_lang`` and ``layout_analysis`` values passed in from the UI are
    therefore ignored for examples.

    Returns:
        A 5-tuple for (document state, schema box, preview gallery,
        image-clear button, JSON output). Everything is cleared when
        ``img`` is None.
    """
    if img is None:
        return None, None, None, gr.update(visible=False), None
    document = example_files[schema]
    # lang_map values look like "<lang>" or "<lang>-<suffix>"; any suffix
    # turns layout analysis on.
    lang, *suffix = lang_map[document].split("-")
    preview, answer = process_doc(document, schema, lang, bool(suffix))
    return document, schema, preview, gr.update(visible=True), answer
def read_content(file_path: str) -> str:
    """Return the entire text of ``file_path`` decoded as UTF-8.

    Used to inline the page's header/footer HTML fragments.
    """
    with open(file_path, encoding='utf-8') as handle:
        return handle.read()
# Custom stylesheet passed to gr.Blocks(css=CSS). Its ID selectors match
# elem_id values assigned in the UI (select-a-file, file-clear, url-textbox,
# url-error, short-upload-box, submit-button).
# NOTE(review): #prompt and #answer match no elem_id in this file; they look
# like leftovers — confirm before removing.
CSS = """
#prompt input {
font-size: 16px;
}
#url-textbox {
padding: 0 !important;
}
#short-upload-box .w-full {
min-height: 10rem !important;
}
/* I think something like this can be used to re-shape
* the table
*/
/*
.gr-samples-table tr {
display: inline;
}
.gr-samples-table .p-2 {
width: 100px;
}
*/
#select-a-file {
width: 100%;
}
#file-clear {
padding-top: 2px !important;
padding-bottom: 2px !important;
padding-left: 8px !important;
padding-right: 8px !important;
margin-top: 10px;
}
.gradio-container .gr-button-primary {
background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
border: 1px solid #B0DCCC;
border-radius: 8px;
color: #1B8700;
}
.gradio-container.dark button#submit-button {
background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
border: 1px solid #B0DCCC;
border-radius: 8px;
color: #1B8700
}
table.gr-samples-table tr td {
border: none;
outline: none;
}
table.gr-samples-table tr td:first-of-type {
width: 0%;
}
div#short-upload-box div.absolute {
display: none !important;
}
gradio-app > div > div > div > div.w-full > div, .gradio-app > div > div > div > div.w-full > div {
gap: 0px 2%;
}
gradio-app div div div div.w-full, .gradio-app div div div div.w-full {
gap: 0px;
}
gradio-app h2, .gradio-app h2 {
padding-top: 10px;
}
#answer {
overflow-y: scroll;
color: white;
background: #666;
border-color: #666;
font-size: 20px;
font-weight: bold;
}
#answer span {
color: white;
}
#answer textarea {
color:white;
background: #777;
border-color: #777;
font-size: 18px;
}
#url-error input {
color: red;
}
"""
with gr.Blocks(css=CSS) as demo:
    gr.HTML(read_content("header.html"))
    # Adjacent string literals are concatenated with no separator, so every
    # sentence but the last needs a trailing space; without it the page
    # renders "tasks.UIE-X", "language.Moreover", "workflow.For".
    gr.Markdown(
        "Open-sourced by PaddleNLP, **UIE-X** is a universal information extraction engine for both scanned document and text inputs. It supports Entity Extraction, Relation Extraction and Event Extraction tasks. "
        "UIE-X performs well in zero-shot settings, which is enabled by a flexible schema that allows you to specify extraction targets with simple natural language. "
        "Moreover, on PaddleNLP, we provide a comprehensive and easy-to-use fine-tuning and few-shot customization workflow. "
        "For more details, please visit our [GitHub](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/applications/information_extraction)."
    )
    # Hidden state: the currently loaded document path.
    document = gr.Variable()
    # NOTE(review): not referenced anywhere else in this file.
    is_text = gr.Variable()
    # Hidden targets filled by gr.Examples; a change of example_image
    # triggers load_example_document.
    example_schema = gr.Textbox(visible=False)
    example_image = gr.Image(visible=False)
    with gr.Row(equal_height=True):
        # Left column: choose a document (URL, upload or example).
        with gr.Column():
            with gr.Row():
                gr.Markdown("## 1. Select a file 选择文件", elem_id="select-a-file")
                img_clear_button = gr.Button(
                    "Clear", variant="secondary", elem_id="file-clear", visible=False
                )
            image = gr.Gallery(visible=False)
            with gr.Row(equal_height=True):
                with gr.Column():
                    with gr.Row():
                        url = gr.Textbox(
                            show_label=False,
                            placeholder="URL",
                            lines=1,
                            max_lines=1,
                            elem_id="url-textbox",
                        )
                        submit = gr.Button("Get")
                    url_error = gr.Textbox(
                        visible=False,
                        elem_id="url-error",
                        max_lines=1,
                        interactive=False,
                        label="Error",
                    )
            gr.Markdown("— or —")
            upload = gr.File(label=None, interactive=True, elem_id="short-upload-box")
            gr.Examples(
                examples=examples,
                inputs=[example_image, example_schema],
            )
        # Right column: schema, options and results.
        with gr.Column():
            gr.Markdown("## 2. Information Extraction 信息抽取 ")
            gr.Markdown("### 👉 Set a schema 设置schema")
            gr.Markdown("Entity extraction: entity type should be separated by ';', e.g. **Person;Organization**")
            gr.Markdown("实体抽取:实体类别之间以';'分割,例如 **人物;组织机构**")
            gr.Markdown("Relation extraction: set the subject and relation type, separated by '|', e.g. **Person|Date;Person|Email**")
            gr.Markdown("关系抽取:需配置主体和关系类别,中间以'|'分割,例如 **人物|出生时间;人物|邮箱**")
            gr.Markdown("### 👉 Model customization 模型定制")
            gr.Markdown("We recommend to further improve the extraction performance in specific domain through the process of [data annotation & fine-tuning](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/applications/information_extraction/document)")
            gr.Markdown("我们建议通过[数据标注+微调](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/applications/information_extraction/document)的流程进一步增强模型在特定场景的效果")
            schema = gr.Textbox(
                label="Schema",
                placeholder="e.g. Name|Company;Name|Position;Email;Phone Number",
                lines=1,
                max_lines=1,
            )
            ocr_lang = gr.Radio(
                choices=["ch", "en"],
                value="en",
                label="OCR语言 / OCR Language (Please choose ch for Chinese images.)",
            )
            # Yields the strings "yes"/"no"; process_doc converts them.
            layout_analysis = gr.Radio(
                choices=["yes", "no"],
                value="no",
                label="版面分析 / Layout analysis (Better extraction for multi-line text)",
            )
            with gr.Row():
                clear_button = gr.Button("Clear", variant="secondary")
                submit_button = gr.Button(
                    "Submit", variant="primary", elem_id="submit-button"
                )
            with gr.Column():
                output = gr.JSON(label="Output", visible=False)

    # Both "Clear" buttons reset the whole UI. The returned tuple has exactly
    # one entry per component in ``outputs``, in the same order.
    for cb in [img_clear_button, clear_button]:
        cb.click(
            lambda _: (
                gr.update(visible=False, value=None),
                None,
                gr.update(visible=False, value=None),
                gr.update(visible=False),
                None,
                None,
                None,
                gr.update(visible=False, value=None),
                None,
            ),
            inputs=clear_button,
            outputs=[
                image,
                document,
                output,
                img_clear_button,
                example_image,
                upload,
                url,
                url_error,
                schema,
            ],
        )

    # Loading a document: from an upload or from a URL.
    upload.change(
        fn=process_upload,
        inputs=[upload],
        outputs=[document, image, img_clear_button, output, url_error],
    )
    submit.click(
        fn=process_path,
        inputs=[url],
        outputs=[document, image, img_clear_button, output, url_error],
    )

    # Running extraction: Enter in the schema box or the Submit button.
    schema.submit(
        fn=process_doc,
        inputs=[document, schema, ocr_lang, layout_analysis],
        outputs=[image, output],
    )
    submit_button.click(
        fn=process_doc,
        inputs=[document, schema, ocr_lang, layout_analysis],
        outputs=[image, output],
    )

    # Clicking an example loads it and runs extraction immediately.
    example_image.change(
        fn=load_example_document,
        inputs=[example_image, example_schema, ocr_lang, layout_analysis],
        outputs=[document, schema, image, img_clear_button, output],
    )
    gr.HTML(read_content("footer.html"))

if __name__ == "__main__":
    demo.launch(enable_queue=False)