# FactureOCR / invoice.py
# Soufiane
# added table extraction
# 8a6a4ae
import cv2
from PIL import Image
from ultralyticsplus import YOLO
from transformers import pipeline
import pandas as pd
import numpy as np
import easyocr
from utils import *
# Invoice header fields to look for (French labels as printed on the invoices).
INVOICE = ["Numéro de facture", "Date", "Numéro de commande", "Echéance", "Total"]

# YOLOv8 detector used to locate tables on a page.
model = YOLO('keremberke/yolov8s-table-extraction')
# Prediction settings: NMS confidence threshold, NMS IoU threshold,
# class-aware NMS, and the cap on detections per image.
model.overrides.update({
    'conf': 0.25,
    'iou': 0.45,
    'agnostic_nms': False,
    'max_det': 1000,
})

# Table-structure model that labels the rows and columns inside a cropped table.
pipe = pipeline("object-detection", model="bilguun/table-transformer-structure-recognition")
def detect_tables(image, min_score=0.5):
    """Detect tables in a page image and return them as cropped PIL images.

    Runs the module-level YOLO table detector on ``image`` and crops every
    detection whose confidence is at least ``min_score``.

    Args:
        image: the page as a NumPy array (e.g. as returned by ``cv2.imread``).
        min_score: minimum detection confidence for a table to be kept.
            Defaults to 0.5, which is stricter than the 0.25 NMS confidence
            configured on ``model``.

    Returns:
        A list of ``PIL.Image`` crops in the detector's output order
        (empty when nothing passes the threshold).
    """
    result = model.predict(image)[0]
    # Convert once to plain Python lists instead of indexing tensors per box.
    boxes = result.boxes.xyxy.tolist()
    scores = result.boxes.conf.tolist()
    tables = []
    for (x_min, y_min, x_max, y_max), score in zip(boxes, scores):
        if score < min_score:
            continue
        # Clamp at 0: a negative start index would silently wrap around in NumPy slicing.
        top = max(int(y_min), 0)
        left = max(int(x_min), 0)
        crop = image[top:int(y_max), left:int(x_max)]
        # Truncating tiny boxes to ints can yield an empty crop, which Image.fromarray rejects.
        if crop.size == 0:
            continue
        # NOTE(review): cv2.imread yields BGR and Image.fromarray does not reorder
        # channels — confirm the structure model / OCR are fine with that.
        tables.append(Image.fromarray(crop))
    return tables
def insert(el, listt, pos):
    """Insert ``el`` into ``listt`` in place, keeping the list ordered by key ``pos``.

    ``el`` goes before the first existing item whose ``[pos]`` value is greater
    than or equal to ``el[pos]``; it is appended when there is no such item.
    The list is mutated and nothing is returned.
    """
    for index, existing in enumerate(listt):
        if el[pos] <= existing[pos]:
            listt.insert(index, el)
            return
    # Empty list, or el sorts after every existing item.
    listt.append(el)
def rec_table(table, reader):
    """Recognise a table's row/column grid and OCR every cell into a DataFrame.

    Args:
        table: a cropped table image, as produced by ``detect_tables``.
        reader: the OCR reader passed through to ``get_ocr``
            (``extract_tables`` supplies an ``easyocr.Reader``).

    Returns:
        A ``pandas.DataFrame`` with one row per detected table row (top to
        bottom) and one column per detected table column (left to right).
        Cells whose row and column boxes do not overlap are left as "".
    """
    cols = []
    rows = []
    for el in pipe(table):
        if el["label"] == 'table column':
            insert(el["box"], cols, pos="xmin")
        elif el["label"] == 'table row':
            insert(el["box"], rows, pos="ymin")
    pixels = np.array(table)
    grid = []
    for row in rows:
        cells = []
        for col in cols:
            box = intersection(row, col)
            if box is None:
                # Row and column boxes do not overlap: there is no cell to OCR.
                # Previously this crashed with TypeError on box['ymin'].
                cells.append("")
                continue
            cell = pixels[box['ymin']:box['ymax'], box['xmin']:box['xmax']]
            cells.append(get_input(get_ocr(cell, reader)))
        grid.append(cells)
    return pd.DataFrame(grid)
def intersection(box1, box2):
    """Return the overlapping region of two axis-aligned boxes, or None.

    Both boxes are dicts with 'xmin', 'ymin', 'xmax' and 'ymax' keys. The
    result has the same keys. Boxes that merely touch along an edge
    (zero-width or zero-height overlap) are treated as not overlapping.
    """
    keys = ('xmin', 'ymin', 'xmax', 'ymax')
    ax0, ay0, ax1, ay1 = (box1[k] for k in keys)
    bx0, by0, bx1, by1 = (box2[k] for k in keys)
    # The overlap starts at the larger of the two mins and ends at the smaller of the two maxes.
    left, top = max(ax0, bx0), max(ay0, by0)
    right, bottom = min(ax1, bx1), min(ay1, by1)
    if left >= right or top >= bottom:
        return None
    return dict(zip(keys, (left, top, right, bottom)))
# def extract_tables(lang, image):
# reader = easyocr.Reader([langs[lang]])
# tables = detect_tables(image)
# for i in range(len(tables)):
# df = rec_table(tables[i], reader)
# df.to_excel(f'table_{i+1}.xlsx', index=False, header=False)
def extract_tables(lang, image, return_all=False):
    """Detect the tables in ``image``, OCR them and return them as CSV text.

    Args:
        lang: key into ``langs`` (from ``utils``) selecting the EasyOCR language.
        image: the page as a NumPy array (e.g. from ``cv2.imread``).
        return_all: when False (default), only the first detected table is
            OCR'd and its CSV string is returned. When True, every detected
            table is OCR'd and a list of CSV strings is returned.

    Returns:
        The first table's CSV text ("" if no table was detected), or, when
        ``return_all`` is True, a list with one CSV string per table (possibly
        empty). CSV text has no header row and no index column.
    """
    # Resolve the language first so an unknown key fails before any model work.
    lang_code = langs[lang]
    tables = detect_tables(image)
    if not return_all:
        # Only the first table is returned, so don't spend OCR time on the rest.
        tables = tables[:1]
    if not tables:
        # Previously this raised IndexError on csvs[0].
        return [] if return_all else ""
    # Building an EasyOCR reader loads its models, so only do it when there is something to read.
    reader = easyocr.Reader([lang_code])
    # Previously the loop did csvs.append(csvs), appending the list to itself instead of the CSV text.
    csvs = [
        rec_table(table, reader).to_csv(index=False, header=False)
        for table in tables
    ]
    return csvs if return_all else csvs[0]
if __name__ == '__main__':
    lang = "french"
    to_be_extracted = INVOICE
    image_path = "./docs for ocr/invoices/facture.png"
    image = cv2.imread(image_path)
    # cv2.imread reports a missing/unreadable file by returning None instead of raising.
    if image is None:
        raise SystemExit(f"Could not read image: {image_path}")
    print(image.shape)
    # NOTE(review): extract_data is not defined in this file; presumably it comes from `from utils import *`.
    text_data = extract_data(lang, to_be_extracted, image)
    print(text_data)
    # print(extract_tables(lang, image))  # print the detected table as CSV text