# Captured from a Hugging Face Spaces page; the Space's status banner read "Runtime error".
import io

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
import torch
import torch.nn.functional as F
from easyocr import Reader
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import (
    LayoutLMv3FeatureExtractor,
    LayoutLMv3TokenizerFast,
    LayoutLMv3Processor,
    LayoutLMv3ForSequenceClassification,
)
# Run inference on the first CUDA GPU when one is present, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

# Base checkpoint whose tokenizer vocabulary is used to encode OCR words.
MICROSOFT_MODEL_NAME = "microsoft/layoutlmv3-base"
# Fine-tuned sequence-classification checkpoint (weights + id2label mapping).
MODEL_NAME = "curiousily/layoutlmv3-financial-document-classification"
def creat_bounding_box(bbox_data, width_scale: float, height_scale: float):
    """Convert an OCR polygon into an axis-aligned LayoutLMv3 box.

    The name keeps its historical spelling ("creat") so existing callers work.

    Args:
        bbox_data: Iterable of ``(x, y)`` points in pixel coordinates.
        width_scale: Factor mapping pixel x values onto the 0-1000 grid.
        height_scale: Factor mapping pixel y values onto the 0-1000 grid.

    Returns:
        ``[left, top, right, bottom]`` as ints, each clamped to ``[0, 1000]``.

    Raises:
        ValueError: If ``bbox_data`` contains no points.
    """
    points = list(bbox_data)
    if not points:
        raise ValueError("bbox_data must contain at least one (x, y) point")

    xs = [x for x, _ in points]
    ys = [y for _, y in points]

    def to_grid(value, scale):
        # LayoutLMv3 only accepts coordinates in 0-1000; an OCR polygon may
        # extend past the image edge, so clamp after scaling.
        return min(max(int(value * scale), 0), 1000)

    return [
        to_grid(min(xs), width_scale),
        to_grid(min(ys), height_scale),
        to_grid(max(xs), width_scale),
        to_grid(max(ys), height_scale),
    ]
def create_ocr_reader():
    """Return an EasyOCR ``Reader`` configured for English text."""
    languages = ["en"]
    return Reader(languages)
def create_processor():
    """Build the LayoutLMv3 processor that encodes (image, words, boxes).

    The image feature extractor is created with ``apply_ocr=False`` because
    the caller supplies its own words and boxes. The tokenizer comes from the
    ``MICROSOFT_MODEL_NAME`` base checkpoint.
    """
    tokenizer = LayoutLMv3TokenizerFast.from_pretrained(MICROSOFT_MODEL_NAME)
    image_features = LayoutLMv3FeatureExtractor(apply_ocr=False)
    return LayoutLMv3Processor(image_features, tokenizer)
def create_model():
    """Load the fine-tuned classifier named by ``MODEL_NAME``.

    The model is put into evaluation mode and moved to ``DEVICE`` before it
    is returned.
    """
    classifier = LayoutLMv3ForSequenceClassification.from_pretrained(MODEL_NAME)
    classifier.eval()
    classifier = classifier.to(DEVICE)
    return classifier
def predict(
    image: Image.Image,
    reader: Reader,
    processor: LayoutLMv3Processor,
    model: LayoutLMv3ForSequenceClassification
):
    """Classify a document image.

    Words and their polygons are extracted with EasyOCR and normalised onto
    LayoutLMv3's 0-1000 coordinate grid. They are encoded together with the
    image by ``processor`` (truncated/padded to 512 tokens) and scored by
    ``model``.

    Args:
        image: Uploaded document as a PIL image, in any mode.
        reader: EasyOCR reader used to find words and their polygons.
        processor: LayoutLMv3 processor; expected to have OCR disabled, since
            words and boxes are supplied here.
        model: Sequence-classification model; it must live on ``DEVICE``
            because the encoded inputs are moved there.

    Returns:
        ``(predicted_class, probabilities)``: the index of the highest-scoring
        class, and the softmax probability of every class in class-index
        order.
    """
    # The image processor normalises exactly three channels, so RGBA,
    # palette and greyscale uploads (common for PNG) are converted first.
    rgb_image = image.convert("RGB")
    # EasyOCR's readtext takes a path, URL, bytes or numpy array. Its PIL
    # support covers JPEG images only, so pass an explicit array.
    ocr_result = reader.readtext(np.array(rgb_image))

    width, height = rgb_image.size
    width_scale = 1000 / width
    height_scale = 1000 / height

    # Each OCR result is (polygon, text, confidence); confidence is unused.
    words = [word for _, word, _ in ocr_result]
    boxes = [
        creat_bounding_box(bbox, width_scale, height_scale)
        for bbox, _, _ in ocr_result
    ]

    encoding = processor(
        rgb_image,
        words,
        boxes=boxes,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    with torch.inference_mode():
        output = model(
            input_ids=encoding["input_ids"].to(DEVICE),
            attention_mask=encoding["attention_mask"].to(DEVICE),
            bbox=encoding["bbox"].to(DEVICE),
            pixel_values=encoding["pixel_values"].to(DEVICE),
        )

    logits = output.logits
    # Batch size is 1, so the per-row argmax holds a single class index.
    predicted_class = logits.argmax(dim=-1)
    probabilities = F.softmax(logits, dim=-1).flatten().tolist()
    return predicted_class.item(), probabilities
@st.cache_resource
def _load_resources():
    """Create the OCR reader, processor and model once and reuse them.

    Streamlit re-executes this script on every widget interaction. Caching
    the resources stops the EasyOCR and LayoutLMv3 weights from being
    reloaded on each rerun.
    """
    return create_ocr_reader(), create_processor(), create_model()


reader, processor, model = _load_resources()

uploaded_file = st.file_uploader("Upload Document Image", ["jpg", "jpeg", "png"])

if uploaded_file is not None:
    bytes_data = io.BytesIO(uploaded_file.getvalue())
    image = Image.open(bytes_data)
    st.image(image, "Your Document")

    predicted_class, probabilities = predict(image, reader, processor, model)
    id2label = model.config.id2label
    predicted_label = id2label[predicted_class]
    st.markdown(f"Predicted document type: **{predicted_label}**")

    # probabilities are in class-index order. Look labels up by index so
    # each name lines up with its score, whatever id2label's iteration order.
    labels = [id2label[index] for index in range(len(probabilities))]
    df_predictions = pd.DataFrame({"Document": labels, "Confidence": probabilities})
    fig = px.bar(df_predictions, x="Document", y="Confidence")
    st.plotly_chart(fig, use_container_width=True)