Spaces:
Runtime error
Runtime error
import cv2 | |
from PIL import Image | |
from ultralyticsplus import YOLO | |
from transformers import pipeline | |
import pandas as pd | |
import numpy as np | |
import easyocr | |
from utils import * | |
# Invoice field labels (French) that the __main__ demo asks extract_data to find.
INVOICE = ["Numéro de facture", "Date", "Numéro de commande", "Echéance", "Total"]

# Table *detector*: a YOLOv8 model that locates whole tables on a page;
# detect_tables() calls model.predict() on it.
model = YOLO('keremberke/yolov8s-table-extraction')
# Prediction overrides for the detector (inserted in the same order as before).
model.overrides.update(
    conf=0.25,  # confidence threshold for keeping a detection
    iou=0.45,  # IoU threshold used by non-maximum suppression
    agnostic_nms=False,  # NMS is not class-agnostic
    max_det=1000,  # maximum number of detections per image
)

# Table *structure* recogniser: a Hugging Face object-detection pipeline whose
# 'table row' / 'table column' detections rec_table() uses to split a table into cells.
pipe = pipeline("object-detection", model="bilguun/table-transformer-structure-recognition")
def detect_tables(image, min_score=0.5):
    """Detect tables on a page image and return each one as a cropped PIL image.

    Args:
        image: page image as a NumPy array indexed ``image[y, x]`` (the
            ``__main__`` demo passes the result of ``cv2.imread``).
        min_score: detections whose confidence is below this value are
            dropped. Defaults to 0.5, the previously hard-coded threshold.
            Values below the detector's own ``conf`` override have no
            additional effect, because such boxes are never returned.

    Returns:
        list of ``PIL.Image.Image`` crops, in the order the detector reported
        them. The list is empty when no table passes the threshold.
    """
    result = model.predict(image)[0]
    tables = []
    for box, score in zip(result.boxes.xyxy, result.boxes.conf):
        if score < min_score:
            continue
        # Clamp the start indices at 0: a negative start would make NumPy
        # slice from the far end of the array instead of the image edge.
        x_min = max(0, int(box[0]))
        y_min = max(0, int(box[1]))
        x_max = int(box[2])
        y_max = int(box[3])
        crop = image[y_min:y_max, x_min:x_max]
        if crop.size == 0:
            # Degenerate (zero-area) box: there is nothing to crop.
            continue
        tables.append(Image.fromarray(crop))
    return tables
def insert(el, listt, pos):
    """Insert *el* into *listt* in place, ordered by the key ``[pos]``.

    *el* goes in front of the first existing item whose ``[pos]`` value is
    greater than or equal to ``el[pos]``, which puts it ahead of items with an
    equal key. If no such item exists, *el* is appended. A list built only
    through this function therefore stays sorted by ``[pos]``. Returns None.
    """
    # Position of the first item that el[pos] does not exceed; len(listt) if none.
    index = next(
        (i for i, item in enumerate(listt) if el[pos] <= item[pos]),
        len(listt),
    )
    listt.insert(index, el)
def rec_table(table, reader):
    """Split a cropped table into cells and OCR every cell.

    Rows and columns come from the structure-recognition pipeline ``pipe``.
    They are kept sorted top-to-bottom (by ``ymin``) and left-to-right (by
    ``xmin``). Each cell is the overlap of one row box and one column box.

    Args:
        table: cropped table image (a PIL image, as produced by detect_tables).
        reader: OCR reader forwarded to ``get_ocr`` (extract_tables passes an
            ``easyocr.Reader``).

    Returns:
        pandas.DataFrame with one row per detected table row and one column
        per detected table column. A cell whose row and column boxes do not
        overlap is left as the empty string.
    """
    structure = pipe(table)
    cols = []
    rows = []
    for element in structure:
        if element["label"] == 'table column':
            insert(element["box"], cols, pos="xmin")
        elif element["label"] == 'table row':
            insert(element["box"], rows, pos="ymin")

    pixels = np.array(table)
    data = []
    for row in rows:
        cells = []
        for col in cols:
            box = intersection(row, col)
            if box is None:
                # Non-overlapping row/column boxes: nothing to OCR. Indexing
                # None here used to raise TypeError.
                cells.append("")
                continue
            # Clamp starts at 0 so a negative coordinate cannot wrap the slice.
            crop = pixels[max(0, box['ymin']):box['ymax'], max(0, box['xmin']):box['xmax']]
            cells.append(get_input(get_ocr(crop, reader)))
        data.append(cells)
    return pd.DataFrame(data)
def intersection(box1, box2):
    """Return the overlap of two axis-aligned boxes, or None if they do not overlap.

    Both boxes are mappings with ``'xmin'``, ``'ymin'``, ``'xmax'`` and
    ``'ymax'`` keys. Boxes that merely touch along an edge (zero-area overlap)
    count as not overlapping.
    """
    corner_keys = ('xmin', 'ymin', 'xmax', 'ymax')
    ax0, ay0, ax1, ay1 = [box1[k] for k in corner_keys]
    bx0, by0, bx1, by1 = [box2[k] for k in corner_keys]

    # Overlap: the larger of the two mins and the smaller of the two maxes.
    left, top = max(ax0, bx0), max(ay0, by0)
    right, bottom = min(ax1, bx1), min(ay1, by1)

    if left >= right or top >= bottom:
        return None
    return {'xmin': left, 'ymin': top, 'xmax': right, 'ymax': bottom}
# def extract_tables(lang, image): | |
# reader = easyocr.Reader([langs[lang]]) | |
# tables = detect_tables(image) | |
# for i in range(len(tables)): | |
# df = rec_table(tables[i], reader) | |
# df.to_excel(f'table_{i+1}.xlsx', index=False, header=False) | |
def extract_tables(lang, image, *, all_tables=False):
    """Detect the tables in *image*, OCR them, and return their content as CSV text.

    Args:
        lang: key into the ``langs`` mapping. ``langs`` is not defined in
            this file; presumably it comes from ``from utils import *``
            (confirm). The mapped value is passed to ``easyocr.Reader``.
        image: page image as a NumPy array (see detect_tables).
        all_tables: keyword-only. When True, return one CSV string per
            detected table instead of only the first.

    Returns:
        By default, the CSV text (no header row, no index column) of the
        first detected table, or ``""`` if no table was detected. With
        ``all_tables=True``, a list of such strings, possibly empty.
    """
    reader = easyocr.Reader([langs[lang]])
    tables = detect_tables(image)
    if not all_tables:
        # Only the first table is returned, so skip OCR on the rest.
        tables = tables[:1]
    # Previously each CSV was discarded (`csvs.append(csvs)`) and the list itself
    # was returned; with no tables, `csvs[0]` raised IndexError.
    csvs = [rec_table(table, reader).to_csv(index=False, header=False) for table in tables]
    if all_tables:
        return csvs
    return csvs[0] if csvs else ""
if __name__ == '__main__':
    # Demo: pull the INVOICE fields out of a sample French invoice image.
    lang = "french"
    to_be_extracted = INVOICE
    image_path = "./docs for ocr/invoices/facture.png"
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports a missing or unreadable file by returning None
        # rather than raising, which used to crash on `image.shape` below.
        raise SystemExit(f"Could not read image: {image_path}")
    print(image.shape)
    # NOTE(review): extract_data is not defined in this file; presumably it is
    # provided by `from utils import *` — confirm.
    text_data = extract_data(lang, to_be_extracted, image)
    print(text_data)
    # print(extract_tables(lang, image))  # CSV text of the first table detected in the image