# Source page header: Layoutlmv3_v2_space / utilitis.py — ITSAIDI — "ghg" — bc66c2d — 9.65 kB
import streamlit as st
from paddleocr import PaddleOCR
from PIL import ImageDraw, ImageFont,ImageEnhance
import torch
from transformers import AutoProcessor,LayoutLMv3ForTokenClassification
import numpy as np
import time
# Hub model identifier handed to `from_pretrained` in Encode (processor)
# and in Run_model (token-classification model).
model_Hugging_path = "Noureddinesa/Output_LayoutLMv3_v7"
#############################################################################
#############################################################################
def Labels():
    """Return the ``(id2label, label2id)`` lookup tables for the tag set.

    ``id2label`` maps each integer class index to its label name and
    ``label2id`` is the inverse mapping.
    """
    label_names = ['InvNum', 'InvDate', 'Fourni', 'TTC', 'TVA', 'TT', 'Autre']
    id2label = dict(enumerate(label_names))
    label2id = {name: index for index, name in id2label.items()}
    return id2label, label2id
#############################################################################
#############################################################################
def Paddle():
    """Create and return a PaddleOCR engine set up with ``lang='fr'``.

    The engine is built with ``use_angle_cls=False`` and ``rec=False``;
    a new instance is constructed on every call.
    """
    return PaddleOCR(lang='fr', use_angle_cls=False, rec=False)
def processbbox(BBOX, width, height):
    """Convert an OCR quadrilateral into a 0-1000 normalized ``[x1, y1, x2, y2]`` box.

    Args:
        BBOX: Sequence of corner points; point 0 supplies ``(x1, y1)`` and
            point 2 supplies ``(x2, y2)``.
        width: Image width in pixels, used to scale x coordinates.
        height: Image height in pixels, used to scale y coordinates.

    Returns:
        A list of four ints ``[x1, y1, x2, y2]``, each within ``[0, 1000]``.
    """
    def _scale(value, extent):
        # int() truncates toward zero; the clamp keeps points that fall
        # slightly outside the image inside the 0-1000 range the layout
        # model expects for its bbox input.
        return max(0, min(1000, int(1000 * value / extent)))

    top_left, bottom_right = BBOX[0], BBOX[2]
    return [
        _scale(top_left[0], width),       # X1
        _scale(top_left[1], height),      # Y1
        _scale(bottom_right[0], width),   # X2
        _scale(bottom_right[1], height),  # Y2
    ]
def Preprocess(image):
    """Run OCR on ``image`` and collect the recognized tokens with their boxes.

    Args:
        image: Image object exposing ``.size`` as ``(width, height)`` and
            convertible with ``np.array``.

    Returns:
        A dict with keys ``'image'`` (the input image), ``'tokens'`` (the
        recognized strings) and ``'bboxes'`` (one box per token, as
        produced by ``processbbox``). Both lists are empty when the OCR
        step yields no lines.
    """
    image_array = np.array(image)
    ocr = Paddle()
    width, height = image.size
    results = ocr.ocr(image_array, cls=True)
    # Only the first page of the result is used. It can be None when no
    # text was detected, so fall back to an empty list instead of letting
    # the loop below fail on a non-iterable.
    lines = (results[0] if results else None) or []
    test_dict = {'image': image ,'tokens':[], "bboxes":[]}
    for item in lines:
        # item[0] holds the corner points, item[1][0] the recognized text.
        bbox = processbbox(item[0], width, height)
        test_dict['tokens'].append(item[1][0])
        test_dict['bboxes'].append(bbox)
    print(test_dict['bboxes'])
    print(test_dict['tokens'])
    return test_dict
#############################################################################
#############################################################################
def Encode(image):
    """Turn an image into processor inputs for the token classifier.

    Runs ``Preprocess`` to get words and boxes, then feeds them to the
    processor loaded from ``model_Hugging_path`` (with ``apply_ocr=False``),
    padding/truncating to 512 positions.

    Returns:
        ``(encoding, offset_mapping, words)`` where ``offset_mapping`` has
        been removed from ``encoding``.
    """
    sample = Preprocess(image)
    words = sample["tokens"]
    processor = AutoProcessor.from_pretrained(model_Hugging_path, apply_ocr=False)
    encoding = processor(
        sample["image"],
        words,
        boxes=sample["bboxes"],
        return_offsets_mapping=True,
        truncation=True,
        max_length=512,
        padding="max_length",
        return_tensors="pt",
    )
    # The offsets are returned separately and not left in the encoding.
    offset_mapping = encoding.pop('offset_mapping')
    return encoding, offset_mapping, words
#############################################################################
#############################################################################
def unnormalize_box(bbox, width, height):
    """Scale a 0-1000 normalized ``[x1, y1, x2, y2]`` box back to pixel units.

    X coordinates are multiplied by ``width`` and Y coordinates by
    ``height`` after dividing by 1000. Returns a list of four numbers.
    """
    extents = (width, height, width, height)
    return [extent * (bbox[axis] / 1000) for axis, extent in enumerate(extents)]
#############################################################################
#############################################################################
def Run_model(image):
    """Classify the OCR tokens of ``image`` with the fine-tuned LayoutLMv3 model.

    Args:
        image: Image object exposing ``.size`` as ``(width, height)``; it is
            forwarded to ``Encode``.

    Returns:
        ``(true_predictions, true_boxes, words)``: the predicted label names
        and pixel-space boxes for every position whose offset starts at 0,
        plus the list of OCR words returned by ``Encode``.
    """
    encoding,offset_mapping,words = Encode(image)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # load the fine-tuned model from the hub
    model = LayoutLMv3ForTokenClassification.from_pretrained(model_Hugging_path)
    model.to(device)
    model.eval()
    # The inputs must live on the same device as the model, otherwise the
    # forward pass fails as soon as CUDA is available.
    encoding = encoding.to(device)
    # forward pass (inference only, so no autograd graph is needed)
    with torch.no_grad():
        outputs = model(**encoding)
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    token_boxes = encoding.bbox.squeeze().tolist()
    width, height = image.size
    id2label, _ = Labels()
    # A position is treated as a sub-word continuation when its character
    # offset does not start at 0; those positions are dropped below.
    is_subword = np.array(offset_mapping.squeeze().tolist())[:,0] != 0
    true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
    true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
    return true_predictions,true_boxes,words
#############################################################################
#############################################################################
def Get_Json(true_predictions,words):
    """Build the extracted-fields dict from per-position labels and OCR words.

    The prediction at position ``i`` is paired with ``words[i - 1]``
    (presumably because the first prediction belongs to a leading special
    token -- confirm against the tokenizer). When a label occurs several
    times, the last in-range occurrence wins.

    Args:
        true_predictions: Sequence of label names, one per kept position.
        words: Sequence of OCR words.

    Returns:
        A dict keyed by the French display name of each detected field
        (e.g. ``'Numéro de facture'``) with the matching word as value.
        Positions that have no corresponding word are ignored.
    """
    key_mapping = {'InvNum':'Numéro de facture','Fourni':'Fournisseur', 'InvDate':'Date Facture','TT':'Total HT','TTC':'Total TTC','TVA':'TVA'}
    Results = {}
    for position, prd in enumerate(true_predictions):
        if prd not in key_mapping:
            continue
        word_index = position - 1
        # Skip positions without a matching word: index -1 would silently
        # wrap around to the last word, and indices past the end (e.g.
        # padded positions) would raise IndexError.
        if 0 <= word_index < len(words):
            Results[prd] = words[word_index]
    return {key_mapping[key]: value for key, value in Results.items()}
#############################################################################
#############################################################################
def Draw(image):
    """Annotate ``image`` with the predicted fields and return the results.

    The image is first passed through ``enhance_image`` (brightness 1.3,
    contrast 1.5), then ``Run_model`` is run on the enhanced copy. Every
    prediction other than ``"Autre"`` gets an outlined box plus a filled
    label strip drawn on that copy.

    Args:
        image: Input image, forwarded to ``enhance_image``.

    Returns:
        ``(image, Results, execution_time)``: the annotated enhanced image,
        the dict built by ``Get_Json``, and the elapsed time in seconds.
    """
    start_time = time.time()
    image = enhance_image(image,1.3,1.5)
    true_predictions, true_boxes,words = Run_model(image)
    draw = ImageDraw.Draw(image)
    label2color = {
        'InvNum': 'blue',
        'InvDate': 'green',
        'Fourni': 'orange',
        'TTC':'purple',
        'TVA': 'magenta',
        'TT': 'red',
        'Autre': 'black'
    }
    # Adjust the thickness of the rectangle outline and label text position
    rectangle_thickness = 4
    label_x_offset = 20
    label_y_offset = -30
    # Custom font size
    custom_font_size = 25
    # Load a font with the custom size
    font_path = "arial.ttf"  # Specify the path to your font file
    try:
        custom_font = ImageFont.truetype(font_path, custom_font_size)
    except OSError:
        # The font file is not available on every host; fall back to the
        # built-in font rather than aborting the whole annotation step.
        custom_font = ImageFont.load_default()
    for predicted_label, box in zip(true_predictions, true_boxes):
        if predicted_label == "Autre":
            continue
        # Default color if label is not found
        color = label2color.get(predicted_label, 'black')
        draw.rectangle(box, outline=color, width=rectangle_thickness)
        # Filled strip shifted by the label offset, then the label text on top of it
        draw.rectangle((box[0], box[1]+ label_y_offset,box[2],box[3]+ label_y_offset), fill=color)
        draw.text((box[0] + label_x_offset, box[1] + label_y_offset), text=predicted_label, fill='white', font=custom_font)
    # Get the Results Json File
    Results = Get_Json(true_predictions,words)
    end_time = time.time()
    execution_time = end_time - start_time
    return image,Results,execution_time
#############################################################################
#############################################################################
def Add_Results(data):
    """Show one sidebar text input per entry of ``data`` and store the input back.

    ``data`` is updated in place: each value is replaced by whatever the
    corresponding sidebar text input returns. Nothing is returned.
    """
    # Iterate over a snapshot of the keys; only existing keys are reassigned.
    for field in list(data):
        data[field] = st.sidebar.text_input(field, data[field])
#############################################################################
#############################################################################
def check_if_changed(original_values, updated_values):
    """Return True if any key of ``original_values`` has a different value in ``updated_values``.

    Every key of ``original_values`` is looked up in ``updated_values``;
    a missing key raises ``KeyError``. Comparison stops at the first
    difference.
    """
    return any(
        updated_values[key] != original
        for key, original in original_values.items()
    )
#############################################################################
#############################################################################
def Update(Results):
    """Render an editable sidebar input for each known field present in ``Results``.

    Fields are shown in a fixed order and pre-filled with their current
    value. Returns a new dict holding, for every field that was present,
    the value returned by its sidebar input.
    """
    # (key in Results, label shown in the sidebar). The last entry uses a
    # label that differs from its key.
    fields = (
        ("Fournisseur", "Fournisseur"),
        ("Date Facture", "Date Facture"),
        ("Numéro de facture", "Numéro de facture"),
        ("Total HT", "Total HT"),
        ("TVA", "TVA"),
        ("Total TTC", "TTC"),
    )
    New_results = {}
    for key, label in fields:
        if key in Results:
            New_results[key] = st.sidebar.text_input(label, value=Results[key])
    return New_results
#############################################################################
#############################################################################
def Change_Image(image1,image2):
    """Display one of two images and let a sidebar button toggle between them.

    The current choice is kept in ``st.session_state.current_image`` and
    starts as ``'image1'``. ``image1`` is shown with the caption 'Output',
    ``image2`` with the caption 'Image initiale'.
    """
    state = st.session_state
    # First run: start on the first image
    if 'current_image' not in state:
        state.current_image = 'image1'
    # Each click on the button flips the selection
    if st.sidebar.button('Switcher'):
        state.current_image = 'image2' if state.current_image == 'image1' else 'image1'
    # Render whichever image is currently selected
    if state.current_image != 'image1':
        st.image(image2, caption='Image initiale', use_column_width=True)
    else:
        st.image(image1, caption='Output', use_column_width=True)
#############################################################################
#############################################################################
def enhance_image(image,brightness_factor, contrast_factor):
    """Return ``image`` with brightness, then contrast, adjusted.

    Brightness is enhanced first by ``brightness_factor``; the contrast
    enhancement by ``contrast_factor`` is applied to that result.
    """
    brightened = ImageEnhance.Brightness(image).enhance(brightness_factor)
    return ImageEnhance.Contrast(brightened).enhance(contrast_factor)