Donut-spaces / app.py
xelpmocAI's picture
Commented article var for gr.Interface
c659e77 verified
raw
history blame contribute delete
No virus
1.97 kB
import re
import gradio as gr
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
# Hub checkpoint shared by the processor and the model, so both are always
# loaded from the same place.
CHECKPOINT = "naver-clova-ix/donut-base-finetuned-cord-v2"

processor = DonutProcessor.from_pretrained(CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(CHECKPOINT)

# Run on the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
def process_document(image):
    """Run the Donut model on one image and return the parsed result.

    The image is converted to pixel values by the module-level ``processor``,
    decoded by the module-level ``model`` starting from the ``<s_cord-v2>``
    task prompt, and the generated token sequence is turned into structured
    output with ``processor.token2json``.

    Args:
        image: The image supplied by the Gradio ``"image"`` input component
            (presumably a numpy array; confirm against the Gradio version
            in use).

    Returns:
        Whatever ``processor.token2json`` produces for the cleaned sequence
        (presumably a JSON-serialisable dict, since the interface output is
        ``"json"``).

    Raises:
        ValueError: If ``image`` is ``None``, e.g. when the form is submitted
            without an image.
    """
    # Fail early with a readable message instead of an obscure error from
    # inside the processor.
    if image is None:
        raise ValueError("No image provided: please upload an image first.")

    print(image)
    # The original interpolated the image itself here, which only repeated
    # the line above; report the actual type as the message promises.
    print(f"Type of Image {type(image)}")

    # prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # prepare decoder inputs: generation is seeded with the task start token
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    # generate answer
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        # forbid the unknown token from being generated
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # postprocess: decode, strip EOS/PAD markers, then drop the leading tag
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    # count=1 so that only the first <...> tag (the task start token) is removed
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    return processor.token2json(sequence)
# article = "<p style='text-align: center'><a href='https://www.xelpmoc.in/' target='_blank'>Made by Xelpmoc</a></p>"
# Sample images offered in the UI; each inner list is one set of inputs.
example_inputs = [
    [filename]
    for filename in ("example.png", "example_2.png", "example_3.png")
]

# Web UI: one image in, the JSON returned by process_document out.
demo = gr.Interface(
    fn=process_document,
    inputs="image",
    outputs="json",
    title="Template-Free OCR model",
    # article=article,
    enable_queue=True,
    examples=example_inputs,
    cache_examples=False,
)

demo.launch()