from transformers import DonutProcessor, VisionEncoderDecoderModel from word2number import w2n from dateutil import relativedelta from datetime import datetime from word2number import w2n from textblob import Word from PIL import Image import torch import re CHEQUE_PARSER_MODEL = "shivi/donut-base-cheque" TASK_PROMPT = "" device = "cuda" if torch.cuda.is_available() else "cpu" def load_donut_model_and_processor(): donut_processor = DonutProcessor.from_pretrained(CHEQUE_PARSER_MODEL) model = VisionEncoderDecoderModel.from_pretrained(CHEQUE_PARSER_MODEL) model.to(device) return donut_processor, model def prepare_data_using_processor(donut_processor,image_path): ## Pass image through donut processor's feature extractor and retrieve image tensor image = load_image(image_path) print("type image:", type(image)) pixel_values = donut_processor(image, return_tensors="pt").pixel_values pixel_values = pixel_values.to(device) ## Pass task prompt for document (cheque) parsing task to donut processor's tokenizer and retrieve the input_ids decoder_input_ids = donut_processor.tokenizer(TASK_PROMPT, add_special_tokens=False, return_tensors="pt")["input_ids"] decoder_input_ids = decoder_input_ids.to(device) return pixel_values, decoder_input_ids def load_image(image_path): image = Image.open(image_path).convert("RGB") return image def parse_cheque_with_donut(input_image_path): donut_processor, model = load_donut_model_and_processor() cheque_image_tensor, input_for_decoder = prepare_data_using_processor(donut_processor,input_image_path) outputs = model.generate(cheque_image_tensor, decoder_input_ids=input_for_decoder, max_length=model.decoder.config.max_position_embeddings, early_stopping=True, pad_token_id=donut_processor.tokenizer.pad_token_id, eos_token_id=donut_processor.tokenizer.eos_token_id, use_cache=True, num_beams=1, bad_words_ids=[[donut_processor.tokenizer.unk_token_id]], return_dict_in_generate=True, output_scores=True,) decoded_output_sequence = donut_processor.batch_decode(outputs.sequences)[0] extracted_cheque_details = decoded_output_sequence.replace(donut_processor.tokenizer.eos_token, "").replace(donut_processor.tokenizer.pad_token, "") ## remove task prompt from token sequence cleaned_cheque_details = re.sub(r"<.*?>", "", extracted_cheque_details, count=1).strip() ## generate ordered json sequence from output token sequence cheque_details_json = donut_processor.token2json(cleaned_cheque_details) print("cheque_details_json:",cheque_details_json['cheque_details']) ## extract required fields from predicted json amt_in_words = cheque_details_json['cheque_details'][0]['amt_in_words'] amt_in_figures = cheque_details_json['cheque_details'][1]['amt_in_figures'] macthing_amts = match_legal_and_courstesy_amount(amt_in_words,amt_in_figures) payee_name = cheque_details_json['cheque_details'][2]['payee_name'] ## In the cheques dataset used to train the model -> all the cheques are dated '06/05/22' ## Train model to extract cheque date -> to do cheque_date = '06/05/2022' stale_cheque = check_if_cheque_is_stale(cheque_date) return payee_name,amt_in_words,amt_in_figures,cheque_date,macthing_amts,stale_cheque def spell_correction(amt_in_words): corrected_amt_in_words ='' words = amt_in_words.split() words = [word.lower() for word in words] for word in words: word = Word(word) corrected_word = word.correct()+' ' corrected_amt_in_words += corrected_word return corrected_amt_in_words def match_legal_and_courstesy_amount(legal_amount,courtesy_amount): macthing_amts = False if len(legal_amount) == 0: return macthing_amts corrected_amt_in_words = spell_correction(legal_amount) print("corrected_amt_in_words:",corrected_amt_in_words) numeric_legal_amt = w2n.word_to_num(corrected_amt_in_words) print("numeric_legal_amt:",numeric_legal_amt) if int(numeric_legal_amt) == int(courtesy_amount): macthing_amts = True return macthing_amts def check_if_cheque_is_stale(cheque_issue_date): stale_check = False current_date = datetime.now().strftime('%d/%m/%Y') current_date_ = datetime.strptime(current_date, "%d/%m/%Y") cheque_issue_date_ = datetime.strptime(cheque_issue_date, "%d/%m/%Y") relative_diff = relativedelta.relativedelta(current_date_, cheque_issue_date_) months_difference = (relative_diff.years * 12) + relative_diff.months print("months_difference:",months_difference) if months_difference > 3: stale_check = True return stale_check