# Spaces: Running
import calendar
import functools
import re
from datetime import datetime
from decimal import Decimal, InvalidOperation

import torch
from dateutil import relativedelta
from PIL import Image
from textblob import Word
from transformers import DonutProcessor, VisionEncoderDecoderModel
from word2number import w2n
# Checkpoint identifier handed to ``from_pretrained`` for both the Donut
# processor and the vision-encoder-decoder model.
CHEQUE_PARSER_MODEL = "shivi/donut-base-cheque"

# Task prompt tokenized and fed to the decoder as its starting input ids.
TASK_PROMPT = "<s_cord-v2>"

# Use the GPU when PyTorch reports one is available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
@functools.lru_cache(maxsize=1)
def load_donut_model_and_processor():
    """Load the Donut processor and model for cheque parsing.

    Both objects are created from ``CHEQUE_PARSER_MODEL`` and the model is
    moved to ``device``. The result is cached, so the checkpoint is loaded
    only on the first call and every later call returns the same pair
    instead of reloading it for each cheque.

    Returns:
        A ``(donut_processor, model)`` tuple.
    """
    donut_processor = DonutProcessor.from_pretrained(CHEQUE_PARSER_MODEL)
    model = VisionEncoderDecoderModel.from_pretrained(CHEQUE_PARSER_MODEL)
    model.to(device)
    return donut_processor, model
def prepare_data_using_processor(donut_processor, image_path):
    """Build the encoder and decoder inputs for one cheque image.

    Args:
        donut_processor: Processor used both to turn the image into pixel
            values and, through its tokenizer, to encode the task prompt.
        image_path: Location of the cheque image; passed to ``load_image``.

    Returns:
        A ``(pixel_values, decoder_input_ids)`` tuple, both moved to
        ``device``.
    """
    image = load_image(image_path)
    print("type image:", type(image))

    # Image -> tensor via the processor's feature extractor.
    encoded_image = donut_processor(image, return_tensors="pt")
    pixel_values = encoded_image.pixel_values.to(device)

    # Task prompt -> token ids; special tokens are not added around it.
    encoded_prompt = donut_processor.tokenizer(
        TASK_PROMPT,
        add_special_tokens=False,
        return_tensors="pt",
    )
    decoder_input_ids = encoded_prompt["input_ids"].to(device)

    return pixel_values, decoder_input_ids
def load_image(image_path):
    """Open the image at ``image_path`` and return it converted to RGB.

    The source file is opened in a ``with`` block so its handle is closed as
    soon as the converted image has been produced, rather than whenever the
    intermediate image object happens to be garbage collected.

    Args:
        image_path: Anything accepted by ``PIL.Image.open``.

    Returns:
        The image in ``"RGB"`` mode.
    """
    with Image.open(image_path) as source_image:
        return source_image.convert("RGB")
def _get_cheque_field(cheque_details, field_name):
    """Return the value stored under ``field_name`` in the parsed details, or ``""``."""
    # The parsed details are normally a list of single-key dicts; accept a
    # plain dict as well and search by key so field order does not matter.
    entries = [cheque_details] if isinstance(cheque_details, dict) else cheque_details
    for entry in entries or []:
        if isinstance(entry, dict) and field_name in entry:
            return entry[field_name]
    return ""


def parse_cheque_with_donut(input_image_path):
    """Extract payee and amounts from a cheque image and validate them.

    The image is run through the Donut model, the generated token sequence is
    converted to JSON, and the amount in words, amount in figures and payee
    name are read from its ``'cheque_details'`` entry. Fields are looked up by
    key, so a missing or reordered field yields an empty string instead of an
    ``IndexError``/``KeyError``.

    Args:
        input_image_path: Location of the cheque image.

    Returns:
        A tuple ``(payee_name, amt_in_words, amt_in_figures, cheque_date,
        amounts_match, stale_cheque)``, where the last two are the results of
        ``match_legal_and_courstesy_amount`` and ``check_if_cheque_is_stale``.

    Raises:
        KeyError: If the parsed output has no ``'cheque_details'`` entry.
    """
    donut_processor, model = load_donut_model_and_processor()
    cheque_image_tensor, input_for_decoder = prepare_data_using_processor(
        donut_processor, input_image_path
    )
    tokenizer = donut_processor.tokenizer

    outputs = model.generate(
        cheque_image_tensor,
        decoder_input_ids=input_for_decoder,
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        # Forbid the unknown token in the generated sequence.
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
        output_scores=True,
    )

    decoded_output_sequence = donut_processor.batch_decode(outputs.sequences)[0]
    extracted_cheque_details = decoded_output_sequence.replace(
        tokenizer.eos_token, ""
    ).replace(tokenizer.pad_token, "")

    # Remove the task prompt: only the first <...> tag is stripped.
    cleaned_cheque_details = re.sub(
        r"<.*?>", "", extracted_cheque_details, count=1
    ).strip()

    # Generate ordered JSON from the output token sequence.
    cheque_details_json = donut_processor.token2json(cleaned_cheque_details)
    cheque_details = cheque_details_json['cheque_details']
    print("cheque_details_json:", cheque_details)

    # Extract the required fields from the predicted JSON.
    amt_in_words = _get_cheque_field(cheque_details, 'amt_in_words')
    amt_in_figures = _get_cheque_field(cheque_details, 'amt_in_figures')
    payee_name = _get_cheque_field(cheque_details, 'payee_name')
    amounts_match = match_legal_and_courstesy_amount(amt_in_words, amt_in_figures)

    # In the cheques dataset used to train the model, all cheques are dated
    # '06/05/22', so the date is hard-coded for now.
    # TODO: train the model to extract the cheque date.
    cheque_date = '06/05/2022'
    stale_cheque = check_if_cheque_is_stale(cheque_date)

    return payee_name, amt_in_words, amt_in_figures, cheque_date, amounts_match, stale_cheque
def spell_correction(amt_in_words):
    """Spell-correct each word of an amount written in words.

    The text is split on whitespace; every word is lower-cased and replaced
    by the result of TextBlob's ``Word.correct()``.

    Args:
        amt_in_words: The amount as a string of words.

    Returns:
        The corrected words, each followed by one space (so a non-empty
        result ends with a trailing space), or ``''`` when there are no words.
    """
    corrected_words = (
        Word(token.lower()).correct() for token in amt_in_words.split()
    )
    return "".join(corrected + ' ' for corrected in corrected_words)
def _parse_courtesy_amount(courtesy_amount):
    """Convert the amount in figures to a finite ``Decimal``, or ``None`` if it is not numeric."""
    if courtesy_amount is None:
        return None
    text = str(courtesy_amount).strip()
    # Figures on cheques are often decorated, e.g. "1,500/-" or "1 500.00";
    # drop the decoration so only the number itself is parsed.
    if text.endswith("/-"):
        text = text[:-2]
    text = text.replace(",", "").replace(" ", "")
    try:
        value = Decimal(text)
    except InvalidOperation:
        return None
    return value if value.is_finite() else None


def match_legal_and_courstesy_amount(legal_amount, courtesy_amount):
    """Check that the amount in words equals the amount in figures.

    The words are spell-corrected and converted to a number, then compared
    exactly (no truncation of fractional parts) with the parsed figures.
    Inputs that are empty or cannot be interpreted as numbers are treated as
    a mismatch instead of raising.

    Args:
        legal_amount: The amount written in words.
        courtesy_amount: The amount written in figures (string or number).

    Returns:
        ``True`` when both amounts are numeric and equal, otherwise ``False``.
    """
    if not legal_amount:
        return False

    # Validate the cheap input first so obviously unusable figures never
    # reach spell correction.
    courtesy_value = _parse_courtesy_amount(courtesy_amount)
    if courtesy_value is None:
        return False

    corrected_amt_in_words = spell_correction(legal_amount)
    print("corrected_amt_in_words:", corrected_amt_in_words)

    try:
        numeric_legal_amt = w2n.word_to_num(corrected_amt_in_words)
    except ValueError:
        # The words do not form a recognisable number.
        return False
    print("numeric_legal_amt:", numeric_legal_amt)

    try:
        legal_value = Decimal(str(numeric_legal_amt))
    except InvalidOperation:
        return False
    return legal_value == courtesy_value
def _add_months(day, months):
    """Return ``day`` shifted by ``months`` calendar months, clamped to the month's last day."""
    month_index = day.year * 12 + (day.month - 1) + months
    year, month_zero_based = divmod(month_index, 12)
    month = month_zero_based + 1
    last_day = calendar.monthrange(year, month)[1]
    return day.replace(year=year, month=month, day=min(day.day, last_day))


def _whole_months_between(start, end):
    """Count whole calendar months from ``start`` to ``end`` (negative when ``end`` is earlier)."""
    if end < start:
        return -_whole_months_between(end, start)
    months = (end.year - start.year) * 12 + (end.month - start.month)
    if _add_months(start, months) > end:
        months -= 1
    return months


def check_if_cheque_is_stale(cheque_issue_date, *, validity_months=3, today=None):
    """Tell whether a cheque is older than its validity period.

    A cheque is stale once the current date is later than the issue date plus
    ``validity_months`` calendar months. Leftover days count: a cheque that
    is three months and one day old is stale, whereas comparing whole months
    only would keep it valid until a full fourth month had passed.

    Args:
        cheque_issue_date: Issue date as a ``'%d/%m/%Y'`` string.
        validity_months: Number of months the cheque stays valid.
        today: Date (or datetime) to compare against; defaults to the
            current local date.

    Returns:
        ``True`` if the cheque is stale, otherwise ``False``.

    Raises:
        ValueError: If ``cheque_issue_date`` does not match ``'%d/%m/%Y'``.
    """
    if today is None:
        today = datetime.now()
    # Only the calendar date matters, so drop any time-of-day component.
    current_day = today.date() if isinstance(today, datetime) else today
    issue_day = datetime.strptime(cheque_issue_date, "%d/%m/%Y").date()

    months_difference = _whole_months_between(issue_day, current_day)
    print("months_difference:", months_difference)

    return current_day > _add_months(issue_day, validity_months)