import streamlit as st
from paddleocr import PaddleOCR
from PIL import ImageDraw, ImageFont, ImageEnhance
import torch
from transformers import AutoProcessor, LayoutLMv3ForTokenClassification
import numpy as np
import time

model_Hugging_path = "Noureddinesa/Output_LayoutLMv3_v7"
#############################################################################
#############################################################################
def Labels():
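    """Return the id2label and label2id mappings for the invoice token classes."""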
    labels = ['InvNum', 'InvDate', 'Fourni', 'TTC', 'TVA', 'TT', 'Autre']
    id2label = {i: label for i, label in enumerate(labels)}
    label2id = {label: i for i, label in enumerate(labels)}
    return id2label, label2id
#############################################################################
#############################################################################
def Paddle():
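    """Initialize a PaddleOCR engine configured for French documents."""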
    ocr = PaddleOCR(use_angle_cls=False, lang='fr', rec=False)
    return ocr

def processbbox(BBOX, width, height):
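    """Convert a PaddleOCR quadrilateral (four corner points) into an [x1, y1, x2, y2]
    box normalized to the 0-1000 range expected by LayoutLMv3."""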
    # Keep the top-left and bottom-right corners of the quadrilateral
    bbox = [BBOX[0][0], BBOX[0][1], BBOX[2][0], BBOX[2][1]]
    # Scaling to the 0-1000 range
    bbox[0] = int(1000 * bbox[0] / width)   # X1
    bbox[1] = int(1000 * bbox[1] / height)  # Y1
    bbox[2] = int(1000 * bbox[2] / width)   # X2
    bbox[3] = int(1000 * bbox[3] / height)  # Y2
    return bbox

def Preprocess(image):
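    """Run OCR on the image and gather the recognized tokens and their normalized
    bounding boxes into the structure expected by the LayoutLMv3 processor."""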
    image_array = np.array(image)
    ocr = Paddle()
    width, height = image.size
    results = ocr.ocr(image_array, cls=True)
    results = results[0]
    test_dict = {'image': image, 'tokens': [], 'bboxes': []}
    for item in results:
        bbox = processbbox(item[0], width, height)
        test_dict['tokens'].append(item[1][0])
        test_dict['bboxes'].append(bbox)
    print(test_dict['bboxes'])
    print(test_dict['tokens'])
    return test_dict
#############################################################################
#############################################################################
def Encode(image):
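    """Encode the OCR output with the LayoutLMv3 processor, returning the model inputs,
    the offset mapping (used to detect subword tokens) and the raw word list."""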
    example = Preprocess(image)
    image = example["image"]
    words = example["tokens"]
    boxes = example["bboxes"]
    processor = AutoProcessor.from_pretrained(model_Hugging_path, apply_ocr=False)
    encoding = processor(image, words, boxes=boxes, return_offsets_mapping=True,
                         truncation=True, max_length=512, padding="max_length",
                         return_tensors="pt")
    offset_mapping = encoding.pop('offset_mapping')
    return encoding, offset_mapping, words
#############################################################################
#############################################################################
def unnormalize_box(bbox, width, height):
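    """Scale a 0-1000 normalized box back to pixel coordinates."""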
    return [
        width * (bbox[0] / 1000),
        height * (bbox[1] / 1000),
        width * (bbox[2] / 1000),
        height * (bbox[3] / 1000),
    ]
#############################################################################
#############################################################################
def Run_model(image):
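    """Run the fine-tuned LayoutLMv3 model on the image and return word-level
    predictions, their pixel-space boxes and the OCR words."""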
    encoding, offset_mapping, words = Encode(image)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Load the fine-tuned model from the hub
    model = LayoutLMv3ForTokenClassification.from_pretrained(model_Hugging_path)
    model.to(device)
    # Move the inputs to the same device as the model before the forward pass
    encoding = encoding.to(device)
    # Forward pass (no gradients needed at inference time)
    with torch.no_grad():
        outputs = model(**encoding)
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    token_boxes = encoding.bbox.squeeze().tolist()
    width, height = image.size
    id2label, _ = Labels()
    # Offset mappings starting at 0 mark special tokens and first subwords;
    # drop continuation subwords so predictions line up with whole words
    is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
    true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
    true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
    return true_predictions, true_boxes, words
#############################################################################
#############################################################################
def Get_Json(true_predictions, words):
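    """Map predicted labels back to OCR words and return the extracted invoice
    fields keyed by their French display names."""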
    Results = {}
    i = 0
    for prd in true_predictions:
        if prd in ['InvNum', 'Fourni', 'InvDate', 'TT', 'TTC', 'TVA']:
            # Predictions are shifted by one relative to words because
            # true_predictions still includes the [CLS] token at index 0
            Results[prd] = words[i - 1]
        i += 1
    key_mapping = {'InvNum': 'Numéro de facture', 'Fourni': 'Fournisseur', 'InvDate': 'Date Facture',
                   'TT': 'Total HT', 'TTC': 'Total TTC', 'TVA': 'TVA'}
    Results = {key_mapping.get(key, key): value for key, value in Results.items()}
    return Results
#############################################################################
#############################################################################
def Draw(image):
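    """Enhance the image, run the model, draw the predicted boxes and labels, and
    return the annotated image, the extracted fields and the execution time."""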
    start_time = time.time()
    image = enhance_image(image, 1.3, 1.5)
    true_predictions, true_boxes, words = Run_model(image)
    draw = ImageDraw.Draw(image)
    label2color = {
        'InvNum': 'blue',
        'InvDate': 'green',
        'Fourni': 'orange',
        'TTC': 'purple',
        'TVA': 'magenta',
        'TT': 'red',
        'Autre': 'black'
    }
    # Thickness of the rectangle outline and position of the label text
    rectangle_thickness = 4
    label_x_offset = 20
    label_y_offset = -30
    # Load a font with a custom size
    custom_font_size = 25
    font_path = "arial.ttf"  # Specify the path to your font file
    custom_font = ImageFont.truetype(font_path, custom_font_size)
    for prediction, box in zip(true_predictions, true_boxes):
        predicted_label = prediction
        # Default to black if the predicted label is not in the color map
        color = label2color.get(predicted_label, 'black')
        if predicted_label != "Autre":
            draw.rectangle(box, outline=color, width=rectangle_thickness)
            # Draw a filled label background and the text using the custom font
            draw.rectangle((box[0], box[1] + label_y_offset, box[2], box[3] + label_y_offset), fill=color)
            draw.text((box[0] + label_x_offset, box[1] + label_y_offset), text=predicted_label, fill='white', font=custom_font)
    # Get the results dictionary
    Results = Get_Json(true_predictions, words)
    end_time = time.time()
    execution_time = end_time - start_time
    return image, Results, execution_time
#############################################################################
#############################################################################
def Add_Results(data):
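    """Render each extracted field as an editable text input in the sidebar,
    writing the edited values back into data."""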
    # Render the table
    for key, value in data.items():
        data[key] = st.sidebar.text_input(key, value)
#############################################################################
#############################################################################
def check_if_changed(original_values, updated_values):
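    """Return True if any value differs between the original and updated dictionaries."""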
    for key, value in original_values.items():
        if updated_values[key] != value:
            return True
    return False
#############################################################################
#############################################################################
def Update(Results):
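    """Show a sidebar text input for each extracted field and return the
    (possibly edited) values."""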
    New_results = {}
    if "Fournisseur" in Results:
        text_fourni = st.sidebar.text_input("Fournisseur", value=Results["Fournisseur"])
        New_results["Fournisseur"] = text_fourni
    if "Date Facture" in Results:
        text_InvDate = st.sidebar.text_input("Date Facture", value=Results["Date Facture"])
        New_results["Date Facture"] = text_InvDate
    if "Numéro de facture" in Results:
        text_InvNum = st.sidebar.text_input("Numéro de facture", value=Results["Numéro de facture"])
        New_results["Numéro de facture"] = text_InvNum
    if "Total HT" in Results:
        text_TT = st.sidebar.text_input("Total HT", value=Results["Total HT"])
        New_results["Total HT"] = text_TT
    if "TVA" in Results:
        text_TVA = st.sidebar.text_input("TVA", value=Results["TVA"])
        New_results["TVA"] = text_TVA
    if "Total TTC" in Results:
        text_TTC = st.sidebar.text_input("Total TTC", value=Results["Total TTC"])
        New_results["Total TTC"] = text_TTC
    return New_results
#############################################################################
#############################################################################
def Change_Image(image1, image2):
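    """Toggle between the annotated output and the original image with a sidebar
    button, storing the current choice in the Streamlit session state."""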
    # Initialize session state
    if 'current_image' not in st.session_state:
        st.session_state.current_image = 'image1'
    # Button to switch between images
    if st.sidebar.button('Switcher'):
        if st.session_state.current_image == 'image1':
            st.session_state.current_image = 'image2'
        else:
            st.session_state.current_image = 'image1'
    # Display the selected image
    if st.session_state.current_image == 'image1':
        st.image(image1, caption='Output', use_column_width=True)
    else:
        st.image(image2, caption='Image initiale', use_column_width=True)
#############################################################################
#############################################################################
def enhance_image(image, brightness_factor, contrast_factor):
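    """Return a copy of the image with brightness and contrast adjusted by the given factors."""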
    enhancer = ImageEnhance.Brightness(image)
    brightened_image = enhancer.enhance(brightness_factor)
    enhancer = ImageEnhance.Contrast(brightened_image)
    enhanced_image = enhancer.enhance(contrast_factor)
    return enhanced_image