shivi commited on
Commit
3bf331b
1 Parent(s): 136cca7

Upload 2 files

Browse files

Added gradio app set up for cheque easy

Files changed (2) hide show
  1. app.py +82 -0
  2. predict_cheque_parser.py +108 -0
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import gradio as gr
4
+ from predict_cheque_parser import parse_cheque_with_donut
5
+
6
+ ##Create list of examples to be loaded
7
+ example_list = glob.glob("examples/cheque_parser/*")
8
+ faulty_cheques_list = glob.glob("examples/cheque_analyze/*")
9
+ example_list = list(map(lambda el:[el], example_list))
10
+ faulty_cheques_list = list(map(lambda el:[el], faulty_cheques_list))
11
+
12
+ demo = gr.Blocks(css="#warning {color: red}")
13
+
14
+ with demo:
15
+
16
+ gr.Markdown("# **<p align='center'>ChequeEasy: Banking with Transformers </p>**")
17
+ gr.Markdown("This space demonstrates the use of Donut proposed in this <a href=\"https://arxiv.org/abs/2111.15664/\">paper </a>")
18
+
19
+ with gr.Tabs():
20
+
21
+ with gr.TabItem("Cheque Parser"):
22
+ gr.Markdown("The module is used to extract details filled by a bank customer from cheques. At present the model is trained to extract details like - payee_name, amount_in_words, amount_in_figures. This model can be further trained to parse additional details like micr_code, cheque_number, account_number, etc")
23
+ with gr.Box():
24
+ gr.Markdown("**Upload Cheque**")
25
+ input_image_parse = gr.Image(type='filepath', label="Input Cheque")
26
+ with gr.Box():
27
+ gr.Markdown("**Parsed Cheque Data**")
28
+
29
+ payee_name = gr.Textbox(label="Payee Name")
30
+ amt_in_words = gr.Textbox(label="Courtesy Amount")
31
+ amt_in_figures = gr.Textbox(label="Legal Amount")
32
+ cheque_date = gr.Textbox(label="Cheque Date")
33
+
34
+ # micr_code = gr.Textbox(label="MICR code")
35
+ # cheque_number = gr.Textbox(label="Cheque Number")
36
+ # account_number = gr.Textbox(label="Account Number")
37
+
38
+ amts_matching = gr.Checkbox(label="Legal & Courtesy Amount Matching", elem_id="warning")
39
+ stale_check = gr.Checkbox(label="Stale Cheque")
40
+
41
+ with gr.Box():
42
+ gr.Markdown("**Predict**")
43
+ with gr.Row():
44
+ parse_cheque = gr.Button("Call Donut 🍩")
45
+
46
+ with gr.Column():
47
+ gr.Examples(example_list, [input_image_parse],
48
+ [payee_name,amt_in_words,amt_in_figures,cheque_date],parse_cheque_with_donut,cache_examples=False)
49
+ # micr_code,cheque_number,account_number,
50
+ # amts_matching, stale_check]#,cache_examples=True)
51
+
52
+
53
+ with gr.TabItem("Quality Analyzer"):
54
+ gr.Markdown("The module is used to detect any mistakes made by bank customers while filling out the cheque or while taking a snapshot of the cheque. At present the model is trained to find mistakes like -'object blocking cheque', 'overwriting in cheque'. ")
55
+ with gr.Box():
56
+ gr.Markdown("**Upload Cheque**")
57
+ input_image_detect = gr.Image(type='filepath',label="Input Cheque", show_label=True)
58
+
59
+ with gr.Box(): # with gr.Column():
60
+ gr.Markdown("**Cheque Quality Results:**")
61
+ output_detections = gr.Image(label="Analyzed Cheque Image", show_label=True)
62
+ output_text = gr.Textbox()
63
+
64
+ with gr.Box():
65
+ gr.Markdown("**Predict**")
66
+ with gr.Row():
67
+ analyze_cheque = gr.Button("Call YOLOS 🤙")
68
+
69
+ gr.Markdown("**Examples:**")
70
+
71
+ with gr.Column():
72
+ gr.Examples(faulty_cheques_list, input_image_detect, [output_detections, output_text])#, predict, cache_examples=True)
73
+
74
+
75
+ parse_cheque.click(parse_cheque_with_donut, inputs=input_image_parse, outputs=[payee_name,amt_in_words,amt_in_figures,cheque_date,amts_matching,stale_check])
76
+ # micr_code,cheque_number,account_number,
77
+ # amts_matching, stale_check])
78
+ # analyze_cheque.click(predict, inputs=input_image_detect, outputs=[output_detections, output_text])
79
+
80
+ gr.Markdown('\n Solution built by: <a href=\"https://www.linkedin.com/in/shivalika-singh/\">Shivalika Singh</a>')
81
+
82
+ demo.launch(share=True, debug=True)
predict_cheque_parser.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
2
+ from word2number import w2n
3
+ from dateutil import relativedelta
4
+ from datetime import datetime
5
+ from word2number import w2n
6
+ from textblob import Word
7
+ from PIL import Image
8
+ import torch
9
+ import re
10
+
11
+ CHEQUE_PARSER_MODEL = "shivi/donut-base-cheque"
12
+ TASK_PROMPT = "<s_cord-v2>"
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+
15
+ def load_donut_model_and_processor():
16
+ donut_processor = DonutProcessor.from_pretrained(CHEQUE_PARSER_MODEL)
17
+ model = VisionEncoderDecoderModel.from_pretrained(CHEQUE_PARSER_MODEL)
18
+ model.to(device)
19
+ return donut_processor, model
20
+
21
+ def prepare_data_using_processor(donut_processor,image_path):
22
+ ## Pass image through donut processor's feature extractor and retrieve image tensor
23
+ image = load_image(image_path)
24
+ print("type image:", type(image))
25
+ pixel_values = donut_processor(image, return_tensors="pt").pixel_values
26
+ pixel_values = pixel_values.to(device)
27
+
28
+ ## Pass task prompt for document (cheque) parsing task to donut processor's tokenizer and retrieve the input_ids
29
+ decoder_input_ids = donut_processor.tokenizer(TASK_PROMPT, add_special_tokens=False, return_tensors="pt")["input_ids"]
30
+ decoder_input_ids = decoder_input_ids.to(device)
31
+
32
+ return pixel_values, decoder_input_ids
33
+
34
+ def load_image(image_path):
35
+ image = Image.open(image_path).convert("RGB")
36
+ return image
37
+
38
+ def parse_cheque_with_donut(input_image_path):
39
+
40
+ donut_processor, model = load_donut_model_and_processor()
41
+
42
+ cheque_image_tensor, input_for_decoder = prepare_data_using_processor(donut_processor,input_image_path)
43
+
44
+ outputs = model.generate(cheque_image_tensor,
45
+ decoder_input_ids=input_for_decoder,
46
+ max_length=model.decoder.config.max_position_embeddings,
47
+ early_stopping=True,
48
+ pad_token_id=donut_processor.tokenizer.pad_token_id,
49
+ eos_token_id=donut_processor.tokenizer.eos_token_id,
50
+ use_cache=True,
51
+ num_beams=1,
52
+ bad_words_ids=[[donut_processor.tokenizer.unk_token_id]],
53
+ return_dict_in_generate=True,
54
+ output_scores=True,)
55
+
56
+ decoded_output_sequence = donut_processor.batch_decode(outputs.sequences)[0]
57
+
58
+ extracted_cheque_details = decoded_output_sequence.replace(donut_processor.tokenizer.eos_token, "").replace(donut_processor.tokenizer.pad_token, "")
59
+ ## remove task prompt from token sequence
60
+ cleaned_cheque_details = re.sub(r"<.*?>", "", extracted_cheque_details, count=1).strip()
61
+ ## generate ordered json sequence from output token sequence
62
+ cheque_details_json = donut_processor.token2json(cleaned_cheque_details)
63
+ print("cheque_details_json:",cheque_details_json['cheque_details'])
64
+
65
+ ## extract required fields from predicted json
66
+
67
+ amt_in_words = cheque_details_json['cheque_details'][0]['amt_in_words']
68
+ amt_in_figures = cheque_details_json['cheque_details'][1]['amt_in_figures']
69
+ macthing_amts = match_legal_and_courstesy_amount(amt_in_words,amt_in_figures)
70
+
71
+ payee_name = cheque_details_json['cheque_details'][2]['payee_name']
72
+ cheque_date = '06/05/2022'
73
+ stale_cheque = check_if_cheque_is_stale(cheque_date)
74
+
75
+ return payee_name,amt_in_words,amt_in_figures,cheque_date,macthing_amts,stale_cheque
76
+
77
+ def spell_correction(amt_in_words):
78
+ corrected_amt_in_words =''
79
+ words = amt_in_words.split()
80
+ words = [word.lower() for word in words]
81
+ for word in words:
82
+ word = Word(word)
83
+ corrected_word = word.correct()+' '
84
+ corrected_amt_in_words += corrected_word
85
+ return corrected_amt_in_words
86
+
87
+ def match_legal_and_courstesy_amount(legal_amount,courtesy_amount):
88
+ macthing_amts = False
89
+ corrected_amt_in_words = spell_correction(legal_amount)
90
+ print("corrected_amt_in_words:",corrected_amt_in_words)
91
+ numeric_legal_amt = w2n.word_to_num(corrected_amt_in_words)
92
+ print("numeric_legal_amt:",numeric_legal_amt)
93
+ if int(numeric_legal_amt) == int(courtesy_amount):
94
+ macthing_amts = True
95
+ return macthing_amts
96
+
97
+ def check_if_cheque_is_stale(cheque_issue_date):
98
+ stale_check = False
99
+ current_date = datetime.now().strftime('%d/%m/%Y')
100
+ current_date_ = datetime.strptime(current_date, "%d/%m/%Y")
101
+ cheque_issue_date_ = datetime.strptime(cheque_issue_date, "%d/%m/%Y")
102
+ relative_diff = relativedelta.relativedelta(current_date_, cheque_issue_date_)
103
+ months_difference = (relative_diff.years * 12) + relative_diff.months
104
+ print("months_difference:",months_difference)
105
+ if months_difference > 3:
106
+ stale_check = True
107
+ return stale_check
108
+