Spaces:
Running
Running
import streamlit as st | |
from PIL import Image, ImageEnhance | |
import torch | |
from torchvision.transforms import functional as F | |
import gc | |
import psutil | |
import copy | |
import xml.etree.ElementTree as ET | |
import numpy as np | |
from pathlib import Path | |
import gdown | |
from modules.OCR import text_prediction, filter_text, mapping_text | |
from modules.utils import class_dict, arrow_dict, object_dict | |
from modules.display import draw_stream | |
from modules.eval import full_prediction | |
from modules.train import get_faster_rcnn_model, get_arrow_model | |
from streamlit_image_comparison import image_comparison | |
def get_memory_usage():
    """Return the resident set size (RSS) of this process, expressed in MB.

    The raw byte count reported by ``psutil`` is divided by ``1024 ** 2``.
    """
    bytes_per_mb = 1024 ** 2
    rss_bytes = psutil.Process().memory_info().rss
    return rss_bytes / bytes_per_mb
def clear_memory():
    """Empty the Streamlit session state, then trigger a garbage-collection pass."""
    session = st.session_state
    session.clear()
    # Collect right away so objects that were only referenced from the
    # session state can be reclaimed.
    gc.collect()
def sidebar():
    """Populate the Streamlit sidebar with credits, usage steps and troubleshooting tips."""
    panel = st.sidebar
    panel.header("This BPMN AI model recognition is proposed by: \n ELCA in collaboration with EPFL.")

    # Step-by-step usage instructions, rendered one text element per step.
    panel.subheader("Instructions:")
    usage_steps = (
        "1. Upload you image",
        "2. Crop the image \n (try to put the BPMN diagram \n in the center of the image)",
        "3. Set the score threshold \n for prediction (default is 0.5)",
        "4. Click on 'Launch Prediction'",
        "5. You can now see the annotation \n and the BPMN XML result",
        "6. You can change the scale for \n the XML file (default is 1.0)",
        "7. You can modify and download \n the result in right format",
    )
    for step in usage_steps:
        panel.text(step)

    # Recovery hints shown under their own heading.
    panel.subheader("If there is an error, try to:")
    recovery_steps = (
        "1. Change the score threshold",
        "2. Re-crop the image by placing\n the BPMN diagram in the center\n of the image",
        "3. Re-Launch the prediction",
    )
    for step in recovery_steps:
        panel.text(step)

    panel.subheader("You can close this sidebar")
# Function to read XML content from a file
def read_xml_file(filepath):
    """Return the full text of the file at ``filepath``, decoded as UTF-8."""
    with open(filepath, mode='r', encoding='utf-8') as handle:
        content = handle.read()
    return content
# Function to load the models and record them in the Streamlit session state
def load_models():
    """Build the object and arrow models, load their weights and store them in session state.

    Weight files that are not already present in the working directory are
    fetched with ``gdown``; otherwise a message is printed and the local file
    is used. The state dicts are loaded with ``map_location`` set to CUDA when
    available, else CPU.

    Side effects:
        Sets ``st.session_state.model_loaded``, ``st.session_state.model_arrow``
        and ``st.session_state.model_object``.

    Returns:
        The tuple ``(model_object, model_arrow)``.
    """
    with st.spinner('Loading model...'):
        model_object = get_faster_rcnn_model(len(object_dict))
        model_arrow = get_arrow_model(len(arrow_dict), 2)

        arrow_weights = 'model_arrow.pth'
        object_weights = 'model_object.pth'

        # (download URL, local path, message printed when the file already exists)
        checkpoints = (
            ('https://drive.google.com/uc?id=1vv1X_r_lZ8gnzMAIKxcVEb_T_Qb-NkyA',
             arrow_weights, 'Model arrow downloaded from local'),
            ('https://drive.google.com/uc?id=1b1bqogxqdPS-SnvaOfWJGV1I1qOrTKh5',
             object_weights, 'Model object downloaded from local'),
        )
        for url, weights_path, cached_message in checkpoints:
            if Path(weights_path).exists():
                print(cached_message)
            else:
                gdown.download(url, weights_path, quiet=False)

        # Load the weights, arrow model first and then the object model.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        for model, weights_path in ((model_arrow, arrow_weights), (model_object, object_weights)):
            state = torch.load(weights_path, map_location=device)
            model.load_state_dict(state)

        session = st.session_state
        session.model_loaded = True
        session.model_arrow = model_arrow
        session.model_object = model_object

    return model_object, model_arrow
# Function to prepare the image for processing
def prepare_image(image, pad=True, new_size=(1333, 1333)):
    """Resize ``image`` to fit inside ``new_size`` and optionally pad it to that exact size.

    Args:
        image: Image exposing ``.size`` as ``(width, height)`` (a PIL image in
            the visible call sites).
        pad: When true, the resized image is padded on the right and bottom
            so that it measures exactly ``new_size``.
        new_size: Target ``(width, height)``.

    Returns:
        The resized (and, if requested, padded) image.
    """
    src_w, src_h = image.size
    target_w, target_h = new_size[0], new_size[1]

    # Largest uniform scale at which both dimensions still fit in the target.
    scale = min(target_w / src_w, target_h / src_h)
    fit_w = int(src_w * scale)
    fit_h = int(src_h * scale)

    # torchvision expects the size as (height, width).
    resized = F.resize(image, (fit_h, fit_w))
    if not pad:
        return resized

    # Brightness hook kept from the original; tune the factor if necessary.
    brightened = ImageEnhance.Brightness(resized).enhance(1.0)

    # Padding order is [left, top, right, bottom]: only the right and bottom grow.
    padding = [0, 0, target_w - fit_w, target_h - fit_h]
    return F.pad(brightened, padding, fill=200, padding_mode='edge')
# Function to display various options for image annotation
def display_options(image, score_threshold, is_mobile, screen_width):
    """Render the annotation toggles and show the annotated image next to the original.

    Reads ``st.session_state.prediction`` and ``st.session_state.text_pred``,
    draws the annotations selected by the user with ``draw_stream`` and
    displays the result through an ``image_comparison`` slider.

    Args:
        image: The image to annotate and to show as the "original" side.
        score_threshold: Score threshold forwarded to ``draw_stream``.
        is_mobile: When exactly ``True`` the comparison uses the full
            ``screen_width``; otherwise half of it.
        screen_width: Available screen width.
    """
    shape_col, text_col, meta_col, class_col, _unused_col = st.columns(5)

    with shape_col:
        write_class = st.toggle("Write Class", value=True)
        draw_keypoints = st.toggle("Draw Keypoints", value=True)
        draw_boxes = st.toggle("Draw Boxes", value=True)

    with text_col:
        draw_text = st.toggle("Draw Text", value=False)
        write_text = st.toggle("Write Text", value=False)
        draw_links = st.toggle("Draw Links", value=False)

    with meta_col:
        write_score = st.toggle("Write Score", value=True)
        write_idx = st.toggle("Write Index", value=False)

    with class_col:
        # Dropdown lists every class name; the first entry is replaced by 'all'.
        dropdown_options = list(class_dict.values())
        dropdown_options[0] = 'all'
        selected_option = st.selectbox("Show class", dropdown_options)

    # Draw the annotated image with the selected options.
    annotated_image = draw_stream(
        np.array(image),
        prediction=st.session_state.prediction,
        text_predictions=st.session_state.text_pred,
        draw_keypoints=draw_keypoints,
        draw_boxes=draw_boxes,
        draw_links=draw_links,
        draw_twins=False,
        draw_grouped_text=draw_text,
        write_class=write_class,
        write_text=write_text,
        keypoints_correction=True,
        write_idx=write_idx,
        only_show=selected_option,
        score_threshold=score_threshold,
        write_score=write_score,
        resize=True,
        return_image=True,
        axis=True,
    )

    # Full width on mobile, half the screen otherwise.
    width = screen_width if is_mobile is True else screen_width // 2

    # Display the annotated and original images in a comparison slider.
    image_comparison(
        img1=annotated_image,
        img2=image,
        label1="Annotated Image",
        label2="Original Image",
        starting_position=99,
        width=width,
    )
# Function to perform inference on the uploaded image using the loaded models
def perform_inference(model_object, model_arrow, image, score_threshold, is_mobile, screen_width, iou_threshold=0.5, distance_treshold=30, percentage_text_dist_thresh=0.5):
    """Run detection and OCR on ``image`` and store the results in session state.

    While the models run, the resized original image is shown in a temporary
    placeholder, which is cleared once the results are ready.

    Args:
        model_object: Object detection model passed to ``full_prediction``.
        model_arrow: Arrow detection model passed to ``full_prediction``.
        image: Input image; it must support ``.convert('RGB')`` and is passed
            to ``prepare_image``.
        score_threshold: Score threshold forwarded to ``full_prediction``.
        is_mobile: When exactly ``True`` the preview uses the full
            ``screen_width``; otherwise half of it.
        screen_width: Available screen width.
        iou_threshold: IoU threshold forwarded to ``full_prediction``.
        distance_treshold: Distance threshold forwarded to ``full_prediction``
            (the spelling is part of the public signature).
        percentage_text_dist_thresh: Forwarded to ``mapping_text`` as
            ``percentage_thresh``.

    Side effects:
        Sets ``st.session_state.prediction``, ``st.session_state.text_pred``
        and ``st.session_state.text_mapping``.

    Returns:
        The tuple ``(image, prediction, text_mapping)`` where ``image`` is the
        unmodified input.
    """
    # Unpadded copy for display and OCR; padded RGB tensor for the detectors.
    uploaded_image = prepare_image(image, pad=False)
    img_tensor = F.to_tensor(prepare_image(image.convert('RGB')))

    # Display the original image while the prediction is running.
    # The placeholder is always created: the previous session-state guard left
    # the local name unbound whenever the key happened to exist.
    image_placeholder = st.empty()
    # Previously both branches tested ``is_mobile is False``, so ``width`` was
    # never assigned on mobile. Use the same rule as display_options.
    width = screen_width if is_mobile is True else screen_width // 2
    image_placeholder.image(uploaded_image, caption='Original Image', width=width)

    # Prediction
    _, st.session_state.prediction = full_prediction(
        model_object,
        model_arrow,
        img_tensor,
        score_threshold=score_threshold,
        iou_threshold=iou_threshold,
        distance_treshold=distance_treshold,
    )

    # Perform OCR on the displayed (unpadded) image.
    ocr_results = text_prediction(uploaded_image)

    # Filter the OCR results and map them onto the predicted elements.
    st.session_state.text_pred = filter_text(ocr_results, threshold=0.6)
    st.session_state.text_mapping = mapping_text(
        st.session_state.prediction,
        st.session_state.text_pred,
        print_sentences=False,
        percentage_thresh=percentage_text_dist_thresh,
    )

    # Remove the temporary original-image display.
    image_placeholder.empty()

    # Force garbage collection.
    gc.collect()

    return image, st.session_state.prediction, st.session_state.text_mapping
def get_image(uploaded_file):
    """Open ``uploaded_file`` with PIL and return it converted to RGB mode."""
    opened = Image.open(uploaded_file)
    return opened.convert('RGB')