|
|
|
|
|
|
|
|
|
|
|
import cv2 |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import streamlit as st |
|
import torch |
|
|
|
from doctr.io import DocumentFile |
|
from doctr.utils.visualization import visualize_page |
|
|
|
from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor |
|
|
|
# Device used for every model forward pass: the first CUDA GPU when one is
# visible to torch, otherwise the CPU.
forward_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
def main(det_archs, reco_archs):
    """Build the Streamlit layout of the docTR demo and run the analysis on demand.

    The page is made of four columns (input page, segmentation heatmap, OCR
    output, page reconstitution) and a sidebar holding the document upload,
    the model selection and the predictor parameters. Nothing is computed
    until the "Analyze page" button is pressed.

    Args:
        det_archs: options offered in the "Text detection model" select box.
        reco_archs: options offered in the "Text recognition model" select box.
    """
    # Must be the first Streamlit call of the script: switch to wide mode.
    st.set_page_config(layout="wide")

    # Page header
    st.title("docTR: Document Text Recognition")
    st.write("\n")
    st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")

    # One column per visualisation
    cols = st.columns((1, 1, 1, 1))
    cols[0].subheader("Input page")
    cols[1].subheader("Segmentation heatmap")
    cols[2].subheader("OCR output")
    cols[3].subheader("Page reconstitution")

    # Sidebar: document selection
    st.sidebar.title("Document selection")
    uploaded_file = st.sidebar.file_uploader("Upload files", type=["pdf", "png", "jpeg", "jpg"])
    if uploaded_file is not None:
        # getvalue() returns the whole upload regardless of the current
        # stream position, unlike read().
        file_bytes = uploaded_file.getvalue()
        # Compare the extension case-insensitively so "SCAN.PDF" is not
        # handed to the image decoder.
        if uploaded_file.name.lower().endswith(".pdf"):
            doc = DocumentFile.from_pdf(file_bytes)
        else:
            doc = DocumentFile.from_images(file_bytes)
        # Pages are shown 1-based to the user and converted back to a 0-based index.
        page_idx = st.sidebar.selectbox("Page selection", list(range(1, len(doc) + 1))) - 1
        page = doc[page_idx]
        cols[0].image(page)

    # Sidebar: model selection
    st.sidebar.title("Model selection")
    det_arch = st.sidebar.selectbox("Text detection model", det_archs)
    reco_arch = st.sidebar.selectbox("Text recognition model", reco_archs)
    st.sidebar.write("\n")

    # Sidebar: predictor parameters
    st.sidebar.title("Parameters")
    assume_straight_pages = st.sidebar.checkbox("Assume straight pages", value=True)
    st.sidebar.write("\n")
    straighten_pages = st.sidebar.checkbox("Straighten pages", value=False)
    st.sidebar.write("\n")
    bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1)
    st.sidebar.write("\n")

    if st.sidebar.button("Analyze page"):
        if uploaded_file is None:
            st.sidebar.write("Please upload a document")
        else:
            with st.spinner("Loading model..."):
                predictor = load_predictor(
                    det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, forward_device
                )

            with st.spinner("Analyzing..."):
                # Forward the page through the model and bring the result back
                # to the page resolution (cv2.resize expects (width, height)).
                # NOTE(review): assumes `page` is an array indexed as
                # (height, width, ...) - confirm against DocumentFile.
                seg_map = forward_image(predictor, page, forward_device)
                seg_map = np.squeeze(seg_map)
                seg_map = cv2.resize(seg_map, (page.shape[1], page.shape[0]), interpolation=cv2.INTER_LINEAR)

                # Plot the raw heatmap
                fig, ax = plt.subplots()
                ax.imshow(seg_map)
                ax.axis("off")
                cols[1].pyplot(fig)
                # pyplot keeps every figure alive until it is closed; release it
                # so figures do not pile up across analyses.
                plt.close(fig)

                # Run the full OCR once and export the page a single time; the
                # export feeds both the visualisation and the JSON view.
                out = predictor([page])
                result_page = out.pages[0]
                page_export = result_page.export()

                # Plot the OCR output
                fig = visualize_page(page_export, result_page.page, interactive=False, add_labels=False)
                cols[2].pyplot(fig)
                plt.close(fig)

                # Page reconstitution, only when the output pages are straight
                # (assumed straight, or straightened by the predictor).
                if assume_straight_pages or straighten_pages:
                    img = result_page.synthesize()
                    cols[3].image(img, clamp=True)

                # Display the JSON export
                st.markdown("\nHere are your analysis results in JSON format:")
                st.json(page_export, expanded=False)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the demo with the detection and recognition
    # architectures imported from backend.pytorch.
    main(DET_ARCHS, RECO_ARCHS)
|
|