import os
os.environ["USE_TORCH"] = "1"
os.environ["USE_TF"] = "0"
import torch
from torch.utils.data.dataloader import DataLoader
from builder import DocumentBuilder
from trocr import IAMDataset, device, get_processor_model
from doctr.utils.visualization import visualize_page
from doctr.models.predictor.base import _OCRPredictor
from doctr.models.detection.predictor import DetectionPredictor
from doctr.models.preprocessor import PreProcessor
from doctr.models import db_resnet50, db_mobilenet_v3_large
from doctr.io import DocumentFile
import numpy as np
import cv2
import matplotlib.pyplot as plt
import streamlit as st

DET_ARCHS = ["db_resnet50", "db_mobilenet_v3_large"]
RECO_ARCHS = ["microsoft/trocr-large-printed", "microsoft/trocr-large-stage1", "microsoft/trocr-large-handwritten"]
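
# Pipeline: the selected docTR model detects word boxes on the page, the word crops are
# passed to the selected TrOCR checkpoint for recognition, and the local DocumentBuilder
# reassembles boxes and text into the final page output.
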
def main():
    # Wide mode
    st.set_page_config(layout="wide")
    # Designing the interface
    st.title("docTR + TrOCR")
    # For newline
    st.write('\n')
    st.write('Detection with docTR: https://github.com/mindee/doctr')
    # For newline
    st.write('\n')
    st.write('Recognition with TrOCR: https://github.com/microsoft/unilm/tree/master/trocr')
    # For newline
    st.write('\n')
    st.write('For any issue, please send a DM')
    # For newline
    st.write('\n')
    # Instructions
    st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")
    # Set the columns
    cols = st.columns((1, 1, 1))
    cols[0].subheader("Input page")
    cols[1].subheader("Segmentation heatmap")

    # Sidebar
    # File selection
    st.sidebar.title("Document selection")
    # Disabling the file uploader deprecation warning
    st.set_option('deprecation.showfileUploaderEncoding', False)
    # Choose your own image
    uploaded_file = st.sidebar.file_uploader("Upload files", type=['pdf', 'png', 'jpeg', 'jpg'])
    if uploaded_file is not None:
        if uploaded_file.name.endswith('.pdf'):
            doc = DocumentFile.from_pdf(uploaded_file.read()).as_images()
        else:
            doc = DocumentFile.from_images(uploaded_file.read())
        page_idx = st.sidebar.selectbox("Page selection", [idx + 1 for idx in range(len(doc))]) - 1
        cols[0].image(doc[page_idx])
    # Model selection
    st.sidebar.title("Model selection")
    det_arch = st.sidebar.selectbox("Text detection model", DET_ARCHS)
    rec_arch = st.sidebar.selectbox("Text recognition model", RECO_ARCHS)
    # For newline
    st.sidebar.write('\n')
    if st.sidebar.button("Analyze page"):
        if uploaded_file is None:
            st.sidebar.write("Please upload a document")
        else:
            with st.spinner('Loading model...'):
                if det_arch == "db_resnet50":
                    det_model = db_resnet50(pretrained=True)
                else:
                    det_model = db_mobilenet_v3_large(pretrained=True)
                det_predictor = DetectionPredictor(
                    PreProcessor((1024, 1024), batch_size=1, mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287)),
                    det_model,
                )
                rec_processor, rec_model = get_processor_model(rec_arch)
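                # Note: get_processor_model comes from the local trocr helper module; it is
                # assumed here to return a Hugging Face TrOCR processor together with its
                # VisionEncoderDecoderModel already placed on `device`.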
            with st.spinner('Analyzing...'):
                # Forward the image to the detection model
                processed_batches = det_predictor.pre_processor([doc[page_idx]])
                out = det_predictor.model(processed_batches[0], return_model_output=True)
                seg_map = out["out_map"]
                seg_map = torch.squeeze(seg_map[0, ...], dim=0)
                seg_map = cv2.resize(
                    seg_map.detach().numpy(),
                    (doc[page_idx].shape[1], doc[page_idx].shape[0]),
                    interpolation=cv2.INTER_LINEAR,
                )
                # Plot the raw heatmap
                fig, ax = plt.subplots()
                ax.imshow(seg_map)
                ax.axis('off')
                cols[1].pyplot(fig)
                # Plot OCR output
                # Localize text elements
                loc_preds = out["preds"]
                # Check whether crop mode should be switched to channels first
                channels_last = len(doc) == 0 or isinstance(doc[0], np.ndarray)
                # Crop the detected words out of the selected page
                crops, loc_preds = _OCRPredictor._prepare_crops(
                    [doc[page_idx]], loc_preds, channels_last=channels_last, assume_straight_pages=True
                )
                test_dataset = IAMDataset(crops[0], rec_processor)
                test_dataloader = DataLoader(test_dataset, batch_size=16)
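                # Run the word crops through TrOCR in batches and decode the generated
                # token ids back into strings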
                text = []
                with torch.no_grad():
                    for batch in test_dataloader:
                        pixel_values = batch["pixel_values"].to(device)
                        generated_ids = rec_model.generate(pixel_values)
                        generated_text = rec_processor.batch_decode(generated_ids, skip_special_tokens=True)
                        text.extend(generated_text)
                boxes, text_preds = _OCRPredictor._process_predictions(loc_preds, text)
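                # _process_predictions pairs each detected box with its recognized string;
                # the local DocumentBuilder (builder.py) then regroups them into page-level
                # output, which, judging by the usage below, is one dataframe per page with
                # a "word" column.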
                doc_builder = DocumentBuilder()
                out = doc_builder(
                    boxes,
                    text_preds,
                    [
                        page.shape[:2] if channels_last else page.shape[-2:]  # type: ignore[misc]
                        for page in [doc[page_idx]]
                    ],
                )
                for df in out:
                    st.markdown("**Text:**")
                    st.write(" ".join(df["word"].to_list()))
                    st.write('\n')
                    st.markdown("Dataframe output (similar to Tesseract):")
                    st.dataframe(df)


if __name__ == '__main__':
    main()