# Author: LouiseWiilog
# Commit message: Update app.py
# Commit: c745fd7
import io
import os
import boto3
import traceback
import re
import logging
import gradio as gr
from PIL import Image, ImageDraw
from docquery.document import load_document, ImageDocument
from docquery.ocr_reader import get_ocr_reader
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from transformers import DonutProcessor, VisionEncoderDecoderModel
from transformers import pipeline
# Work around certificate errors by replacing ssl's default-HTTPS-context
# factory with the unverified one.
# NOTE(review): this globally disables TLS certificate verification for any
# client that builds its context through this stdlib hook (e.g. urllib /
# http.client). It is a security downgrade; libraries with their own TLS
# setup may not even be affected. Confirm it is still needed and prefer
# fixing the CA bundle instead.
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# Set before any tokenizer below is instantiated.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Root logger at DEBUG: very verbose, including third-party library output.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Init models: each is loaded once at import time and shared by every request.
layoutlm_pipeline = pipeline(
    "document-question-answering",
    model="impira/layoutlm-document-qa",
)
# NOTE(review): tokenizer and model come from different checkpoints; verify the
# vocabularies match. "nielsr/lilt-xlm-roberta-base" also looks like a base
# (not QA-finetuned) checkpoint; if so, the QA head loaded by
# AutoModelForQuestionAnswering is untrained. Confirm.
lilt_tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-infoxlm-base")
lilt_model = AutoModelForQuestionAnswering.from_pretrained(
    "nielsr/lilt-xlm-roberta-base"
)
# Donut processor and model share the same DocVQA-finetuned checkpoint.
donut_processor = DonutProcessor.from_pretrained(
    "naver-clova-ix/donut-base-finetuned-docvqa"
)
donut_model = VisionEncoderDecoderModel.from_pretrained(
    "naver-clova-ix/donut-base-finetuned-docvqa"
)
# UI labels for each backend; they are also the keys of the MODELS dispatch table.
TEXTRACT = "Textract Query"
LAYOUTLM = "LayoutLM"
DONUT = "Donut"
LILT = "LiLT"
def image_to_byte_array(image: "Image.Image", image_format: str = "PNG") -> bytes:
    """Encode an image into an in-memory byte string.

    Args:
        image: the PIL image to serialize. Anything exposing a PIL-style
            ``save(fp, format=...)`` method works.
        image_format: format name handed to ``save``; defaults to ``"PNG"``,
            which is what every existing caller relies on.

    Returns:
        The encoded file contents as raw ``bytes``.
    """
    # The annotation is a string because ``Image`` is the PIL *module*; the
    # class is ``Image.Image``.
    with io.BytesIO() as buffer:
        image.save(buffer, format=image_format)
        # Read the bytes out before the context manager closes the buffer.
        return buffer.getvalue()
def run_textract(question, document):
    """Answer *question* about the first page of *document* with AWS Textract Queries.

    The page image is PNG-encoded and sent to Textract ``analyze_document``
    with the ``QUERIES`` feature. The first ``QUERY_RESULT`` block is returned.

    Returns:
        dict with ``score`` (Textract's ``Confidence``, reported on a 0-100
        scale) and ``answer`` (the result ``Text``). There are no
        ``word_ids``, so no answer box is drawn for this model.

    Raises:
        ValueError: if the response contains no ``QUERY_RESULT`` block.
            Previously the exception was built but never raised, so callers
            received ``None`` and failed later with an unrelated TypeError.
    """
    logger.info("Running Textract model.")
    # ``preview`` is the list of rendered page images (the same list
    # process_question draws on), so this works for whatever document type
    # load_document returned, not only raw images.
    image_bytes = image_to_byte_array(image=document.preview[0])
    # boto3 base64-encodes ``Bytes`` itself; pass the raw PNG bytes.
    response = boto3.client("textract").analyze_document(
        Document={"Bytes": image_bytes},
        FeatureTypes=["QUERIES"],
        QueriesConfig={"Queries": [{"Text": question, "Pages": ["*"]}]},
    )
    logger.info(f"Output of Textract model {response}.")
    for element in response.get("Blocks", []):
        if element["BlockType"] == "QUERY_RESULT":
            return {
                "score": element["Confidence"],
                "answer": element["Text"],
            }
    raise ValueError("No QUERY_RESULT found in the response from Textract.")
def run_layoutlm(question, document):
    """Answer *question* on the first page with the LayoutLM DocQA pipeline.

    docquery's OCR words and boxes for the page (normalized to 0-1000) are
    passed to the pipeline. The word indices it returns then refer to the same
    word list that process_question uses to draw the answer box. Without them
    the pipeline runs its own OCR pass, whose word list can differ.

    Returns:
        dict with ``score``, ``answer``, ``word_ids`` (every word index of
        the answer span, inclusive) and ``page`` (always 0).
    """
    logger.info("Running layoutlm model.")
    page_image = document.context["image"][0][0]
    word_boxes = document.context["image"][0][1]
    result = layoutlm_pipeline(page_image, question, word_boxes=word_boxes)[0]
    logger.info(f"Output of layoutlm model {result}.")
    # e.g. [{'score': 0.9999411106109619, 'answer': 'LETTER OF CREDIT', 'start': 106, 'end': 108}]
    # start/end are inclusive word indices (3 words -> 106..108). Expand to
    # the whole span so multi-line answers get a box covering every word.
    return {
        "score": result["score"],
        "answer": result["answer"],
        "word_ids": list(range(result["start"], result["end"] + 1)),
        "page": 0,
    }
def run_lilt(question, document):
    """Answer *question* with the LiLT extractive-QA model on the first page.

    The question is tokenized together with docquery's OCR words and boxes.
    The model's argmax start/end logits select a token span, which is decoded
    as the answer. No confidence is computed, so ``score`` is ``"n/a"``.
    """
    import torch  # local import: only needed for the no_grad context below

    logger.info("Running lilt model.")
    # (word, box) pairs for page 0, as produced by docquery's OCR.
    processed_document = document.context["image"][0][1]
    words = [entry[0] for entry in processed_document]
    boxes = [entry[1] for entry in processed_document]
    encoding = lilt_tokenizer(
        text=question,
        text_pair=words,
        boxes=boxes,
        add_special_tokens=True,
        return_tensors="pt",
    )
    # Inference only: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        outputs = lilt_model(**encoding)
    logger.info(f"Output for lilt model {outputs}.")
    answer_start_index = outputs.start_logits.argmax()
    answer_end_index = outputs.end_logits.argmax()
    # When the end index precedes the start index the slice is empty and the
    # decoded answer is "".
    predict_answer_tokens = encoding.input_ids[
        0, answer_start_index : answer_end_index + 1
    ]
    predict_answer = lilt_tokenizer.decode(
        predict_answer_tokens, skip_special_tokens=True
    )
    return {
        "score": "n/a",
        "answer": predict_answer,
    }
def run_donut(question, document):
    """Answer *question* with the OCR-free Donut DocVQA model on the first page.

    The page image feeds the vision encoder. The DocVQA task prompt, with the
    question filled in, primes the text decoder. The generated sequence is
    parsed with ``token2json`` and its ``answer`` field is returned.

    Returns:
        dict with ``score`` (``"n/a"``, not computed) and ``answer``. If the
        generated text has no ``answer`` field (e.g. a malformed or truncated
        generation), ``answer`` falls back to the unparsed text, or ``""``,
        instead of raising KeyError.
    """
    logger.info("Running donut model.")
    # Encoder input: pixel values of the first page image.
    pixel_values = donut_processor(
        document.context["image"][0][0], return_tensors="pt"
    ).pixel_values
    # Decoder input: the DocVQA prompt ending in an open <s_answer> tag, which
    # the model continues.
    task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
    prompt = task_prompt.replace("{user_input}", question)
    decoder_input_ids = donut_processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids
    # Greedy decoding; the unknown token is banned from the output.
    outputs = donut_model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=donut_model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=donut_processor.tokenizer.pad_token_id,
        eos_token_id=donut_processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[donut_processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    logger.info(f"Output for donut {outputs}")
    sequence = donut_processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(donut_processor.tokenizer.eos_token, "").replace(
        donut_processor.tokenizer.pad_token, ""
    )
    # Remove the first task start token (<s_docvqa>).
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    result = donut_processor.token2json(sequence)
    answer = result.get("answer")
    if answer is None:
        logger.warning(f"Donut output has no answer field: {result}")
        answer = result.get("text_sequence", "")
    return {
        "score": "n/a",
        "answer": answer,
    }
def process_path(path):
    """Load the document at *path* and build the UI updates for it.

    Always returns a 6-tuple:
    ``(document, image_update, clear_button_update, output_update,
    output_text_update, error_slot)``.

    On success the preview gallery and clear button are shown and previous
    answers are hidden. On failure, or for an empty *path*, everything is
    hidden and ``document`` is None. When loading failed, ``error_slot`` is a
    visible update carrying the error message; otherwise it is None.

    NOTE(review): ``upload.change`` wires only five outputs, so the error slot
    is currently not displayed anywhere. Consider adding an error component.
    """
    error = None
    if path:
        try:
            # preview rendering is inside the try too: it can fail for
            # malformed files just like loading.
            document = load_document(path)
            return (
                document,
                gr.update(visible=True, value=document.preview),
                gr.update(visible=True),
                gr.update(visible=False, value=None),
                gr.update(visible=False, value=None),
                None,
            )
        except Exception as e:
            # UI boundary: report the failure instead of crashing the event.
            traceback.print_exc()
            error = str(e)
    # Same arity as the success path (the previous version returned a
    # spurious seventh value here).
    return (
        None,
        gr.update(visible=False, value=None),
        gr.update(visible=False),
        gr.update(visible=False, value=None),
        gr.update(visible=False, value=None),
        gr.update(visible=True, value=error) if error is not None else None,
    )
def process_upload(file):
    """Handle a change of the upload widget.

    A non-empty upload is forwarded to ``process_path`` using the uploaded
    object's ``.name`` as the path. A cleared upload resets the document and
    hides the preview, the clear button and both answer widgets.
    """
    if not file:
        return (
            None,
            gr.update(visible=False, value=None),
            gr.update(visible=False),
            gr.update(visible=False, value=None),
            gr.update(visible=False, value=None),
            None,
        )
    return process_path(file.name)
def lift_word_boxes(document, page):
    """Return the OCR ``(word, box)`` entries stored for page index *page*."""
    page_entries = document.context["image"]
    # Each page entry is (page_image, word_boxes); take the second element.
    return page_entries[page][1]
def expand_bbox(word_boxes):
    """Return the smallest box that encloses every box in *word_boxes*.

    Each entry is a ``(word, (x0, y0, x1, y1))`` pair; only the box is used.
    Returns ``[x0, y0, x1, y1]``, or ``None`` when *word_boxes* is empty.
    """
    if len(word_boxes) == 0:
        return None
    lefts, tops, rights, bottoms = zip(*(entry[1] for entry in word_boxes))
    return [min(lefts), min(tops), max(rights), max(bottoms)]
def normalize_bbox(box, width, height, padding=0.005):
    """Map a LayoutLM-style box (coordinates in 0..1000) onto pixel space.

    The box is first scaled to 0..1. Unless *padding* is 0, it is grown by
    *padding* on every side and clamped to the unit square. The result is then
    scaled to an image of *width* x *height* pixels.

    Returns ``[x0, y0, x1, y1]`` in pixels.
    """
    left, top, right, bottom = (coord / 1000 for coord in box)
    if padding != 0:
        left, top = max(0, left - padding), max(0, top - padding)
        right, bottom = min(right + padding, 1), min(bottom + padding, 1)
    return [left * width, top * height, right * width, bottom * height]
# Dispatch table from the model label shown in the UI to the function that
# answers a question with that backend. Insertion order matters: the first
# entry is the default model. Both process_question's default argument and
# the model Radio's initial value take the first key.
MODELS = dict(
    [
        (LAYOUTLM, run_layoutlm),
        (DONUT, run_donut),
        (LILT, run_lilt),
        (TEXTRACT, run_textract),
    ]
)
def _highlight_answer(pages, document, prediction):
    """Shade the answer's word boxes on its page image in translucent green.

    Word indices that fall outside the page's OCR word list are ignored. If no
    valid index remains, nothing is drawn and a warning is logged; the answer
    is still shown.
    """
    logger.info("Setting bounding boxes.")
    page_index = prediction["page"]
    image = pages[page_index]
    word_boxes = lift_word_boxes(document, page_index)
    selected = [
        word_boxes[i] for i in prediction["word_ids"] if 0 <= i < len(word_boxes)
    ]
    bbox = expand_bbox(selected)
    if bbox is None:
        logger.warning(f"No drawable word boxes for prediction {prediction}")
        return
    x1, y1, x2, y2 = normalize_bbox(bbox, image.width, image.height)
    draw = ImageDraw.Draw(image, "RGBA")
    draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))


def process_question(question, document, model=list(MODELS.keys())[0]):
    """Answer *question* about *document* with *model* and build the UI updates.

    Args:
        question: the user's question; empty means "nothing to do".
        document: the loaded docquery document, or None.
        model: a key of ``MODELS``; defaults to the first registered model.

    Returns:
        ``(gallery_update, json_update, answer_text_update)``, or
        ``(None, None, None)`` when there is no question or no document.
        When the prediction carries ``word_ids``, the answer region is
        highlighted on a copy of the page.
    """
    if not question or document is None:
        return None, None, None
    logger.info(f"Running for model {model}")
    prediction = MODELS[model](question=question, document=document)
    logger.info(f"Got prediction {prediction}")
    # Image.convert returns a new image, so drawing never touches the
    # document's own preview images.
    pages = [page.convert("RGB") for page in document.preview]
    if "word_ids" in prediction:
        _highlight_answer(pages, document, prediction)
    return (
        gr.update(visible=True, value=pages),
        gr.update(visible=True, value=prediction),
        gr.update(visible=True, value=prediction["answer"]),
    )
def load_example_document(img, question, model):
    """Load a clicked example image and immediately answer its question.

    *img* is the example image's pixel array, as handed over by the hidden
    ``gr.Image``. Returns updates for
    ``(document, question, image, img_clear_button, output, output_text)``.
    When *img* is None, everything is cleared and the clear button is hidden.
    """
    if img is None:
        return None, None, None, gr.update(visible=False), None, None
    example_doc = ImageDocument(Image.fromarray(img), get_ocr_reader())
    preview, answer, answer_text = process_question(question, example_doc, model)
    return example_doc, question, preview, gr.update(visible=True), answer, answer_text
# Custom stylesheet passed to gr.Blocks(css=...). Most rules target elem_ids
# set on components in the UI below (#select-a-file, #file-clear,
# #short-upload-box, #submit-button, #answer) or Gradio's own classes.
# NOTE(review): #question, #url-textbox and #url-error match no elem_id set in
# this file; they look like leftovers. Confirm before removing.
CSS = """
#question input {
font-size: 16px;
}
#url-textbox {
padding: 0 !important;
}
#short-upload-box .w-full {
min-height: 10rem !important;
}
/* I think something like this can be used to re-shape
* the table
*/
/*
.gr-samples-table tr {
display: inline;
}
.gr-samples-table .p-2 {
width: 100px;
}
*/
#select-a-file {
width: 100%;
}
#file-clear {
padding-top: 2px !important;
padding-bottom: 2px !important;
padding-left: 8px !important;
padding-right: 8px !important;
margin-top: 10px;
}
.gradio-container .gr-button-primary {
background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
border: 1px solid #B0DCCC;
border-radius: 8px;
color: #1B8700;
}
.gradio-container.dark button#submit-button {
background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
border: 1px solid #B0DCCC;
border-radius: 8px;
color: #1B8700
}
table.gr-samples-table tr td {
border: none;
outline: none;
}
table.gr-samples-table tr td:first-of-type {
width: 0%;
}
div#short-upload-box div.absolute {
display: none !important;
}
gradio-app > div > div > div > div.w-full > div, .gradio-app > div > div > div > div.w-full > div {
gap: 0px 2%;
}
gradio-app div div div div.w-full, .gradio-app div div div div.w-full {
gap: 0px;
}
gradio-app h2, .gradio-app h2 {
padding-top: 10px;
}
#answer {
overflow-y: scroll;
color: white;
background: #666;
border-color: #666;
font-size: 20px;
font-weight: bold;
}
#answer span {
color: white;
}
#answer textarea {
color:white;
background: #777;
border-color: #777;
font-size: 18px;
}
#url-error input {
color: red;
}
"""
# Example (image file, question) pairs offered under the upload box. Each
# inner item must be a two-element list matching gr.Examples'
# inputs=[example_image, example_question].
examples = [
    [file_name, prompt]
    for file_name, prompt in (
        ("scenario-1.png", "What is the final consignee?"),
        ("scenario-1.png", "What are the payment terms?"),
        ("scenario-2.png", "What is the actual manufacturer?"),
        ("scenario-3.png", 'What is the "ship to" destination?'),
        ("scenario-4.png", "What is the color?"),
        ("scenario-5.png", 'What is the "said to contain"?'),
        ("scenario-5.png", 'What is the "Net Weight"?'),
        ("scenario-5.png", 'What is the "Freight Collect"?'),
        ("bill_of_lading_1.png", "What is the shipper?"),
        ("japanese-invoice.png", "What is the total amount?"),
        ("example-10.jpeg", "What is mineral water price amount?"),
    )
]
# Build the Gradio UI. Components render in creation order inside their
# enclosing Row/Column contexts, so statement order here defines the layout.
with gr.Blocks(css=CSS) as demo:
    gr.Markdown("# Document Question Answer Comparator")
    gr.Markdown("""
This space compares some of the latest models that can be used commercially.
- [LayoutLM](https://huggingface.co/impira/layoutlm-document-qa) uses text/layout and images. Uses tesseract for OCR.
- [Donut](https://huggingface.co/naver-clova-ix/donut-base-finetuned-docvqa) OCR free document understanding. Uses vision encoder for OCR and a text decoder for providing the answer
""")
    # Per-session holder for the loaded document, shared between events.
    document = gr.Variable()
    # Hidden inputs filled by gr.Examples; a change of example_image triggers
    # load_example_document (wired below).
    example_question = gr.Textbox(visible=False)
    example_image = gr.Image(visible=False)
    with gr.Row(equal_height=True):
        # Left column: file selection, preview gallery and examples.
        with gr.Column():
            with gr.Row():
                gr.Markdown("## 1. Select a file", elem_id="select-a-file")
                img_clear_button = gr.Button(
                    "Clear", variant="secondary", elem_id="file-clear", visible=False
                )
            image = gr.Gallery(visible=False)
            upload = gr.File(label=None, interactive=True, elem_id="short-upload-box")
            gr.Examples(
                examples=examples,
                inputs=[example_image, example_question],
            )
        # Right column: question, model choice, actions and answers.
        with gr.Column() as col:
            gr.Markdown("## 2. Ask a question")
            question = gr.Textbox(
                label="Question",
                placeholder="e.g. What is the invoice number?",
                lines=1,
                max_lines=1,
            )
            model = gr.Radio(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Model",
            )
            with gr.Row():
                clear_button = gr.Button("Clear", variant="secondary")
                submit_button = gr.Button(
                    "Submit", variant="primary", elem_id="submit-button"
                )
            with gr.Column():
                output_text = gr.Textbox(
                    label="Top Answer", visible=False, elem_id="answer"
                )
                output = gr.JSON(label="Output", visible=False)

    # Both Clear buttons reset the page. The handler returns exactly one value
    # per component in `outputs`, in the same order. The previous version
    # returned two extra trailing values with no matching component.
    for cb in [img_clear_button, clear_button]:
        cb.click(
            lambda _: (
                gr.update(visible=False, value=None),  # image
                None,  # document
                gr.update(visible=False, value=None),  # output
                gr.update(visible=False, value=None),  # output_text
                gr.update(visible=False),  # img_clear_button
                None,  # example_image
                None,  # upload
                None,  # question
            ),
            inputs=clear_button,
            outputs=[
                image,
                document,
                output,
                output_text,
                img_clear_button,
                example_image,
                upload,
                question,
            ],
        )

    upload.change(
        fn=process_upload,
        inputs=[upload],
        outputs=[document, image, img_clear_button, output, output_text],
    )
    # Pressing Enter, clicking Submit and switching models all re-run the question.
    question.submit(
        fn=process_question,
        inputs=[question, document, model],
        outputs=[image, output, output_text],
    )
    submit_button.click(
        process_question,
        inputs=[question, document, model],
        outputs=[image, output, output_text],
    )
    model.change(
        process_question,
        inputs=[question, document, model],
        outputs=[image, output, output_text],
    )
    example_image.change(
        fn=load_example_document,
        inputs=[example_image, example_question, model],
        outputs=[document, question, image, img_clear_button, output, output_text],
    )

if __name__ == "__main__":
    demo.launch(enable_queue=False)