|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import torch |
|
|
|
from doctr.models import ocr_predictor |
|
from doctr.models.predictor import OCRPredictor |
|
|
|
# Detection architecture names (presumably the allowed ``det_arch`` values for
# ``load_predictor`` — NOTE(review): confirm against callers).
DET_ARCHS = [
    "db_resnet50", "db_resnet34", "db_mobilenet_v3_large",
    "linknet_resnet18", "linknet_resnet34", "linknet_resnet50",
]
|
# Recognition architecture names (presumably the allowed ``reco_arch`` values for
# ``load_predictor`` — NOTE(review): confirm against callers).
RECO_ARCHS = [
    "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large",
    "master", "sar_resnet31",
    "vitstr_small", "vitstr_base",
    "parseq",
]
|
|
|
|
|
def load_predictor(
    det_arch: str,
    reco_arch: str,
    assume_straight_pages: bool,
    straighten_pages: bool,
    bin_thresh: float,
    device: torch.device,
) -> OCRPredictor:
    """Build a pretrained OCR predictor with ``doctr.models.ocr_predictor`` and place it on ``device``.

    Args:
    ----
        det_arch: name of the detection architecture
        reco_arch: name of the recognition architecture
        assume_straight_pages: whether pages are assumed to be straight; orientation
            detection is enabled exactly when this is False
        straighten_pages: whether rotated pages should be straightened; also used as the
            value of ``export_as_straight_boxes``
        bin_thresh: binarization threshold written onto the detection model's post-processor
        device: torch.device the predictor is moved to

    Returns:
    -------
        the configured OCRPredictor
    """
    # NOTE(review): export_as_straight_boxes is tied to straighten_pages rather than
    # being exposed as a separate option — confirm this coupling is intended.
    options = dict(
        pretrained=True,
        assume_straight_pages=assume_straight_pages,
        straighten_pages=straighten_pages,
        export_as_straight_boxes=straighten_pages,
        detect_orientation=not assume_straight_pages,
    )
    ocr_model = ocr_predictor(det_arch, reco_arch, **options)
    ocr_model = ocr_model.to(device)

    # ocr_predictor is not given a threshold here, so it is patched onto the
    # detection post-processor after construction.
    post_processor = ocr_model.det_predictor.model.postprocessor
    post_processor.bin_thresh = bin_thresh
    return ocr_model
|
|
|
|
|
def forward_image(predictor: OCRPredictor, image: np.ndarray, device: torch.device) -> np.ndarray:
    """Run only the detection branch of ``predictor`` on one image and return its raw output map.

    The image is wrapped in a one-element list for the detection pre-processor. The
    first resulting batch is moved to ``device`` and passed to the detection model with
    ``return_model_output=True``.

    Args:
    ----
        predictor: instance of OCRPredictor whose ``det_predictor`` is used
        image: image to process
        device: torch.device the pre-processed batch is moved to before the forward pass

    Returns:
    -------
        the model output's ``"out_map"`` entry, moved to CPU and converted to a numpy array
        (its exact shape depends on the detection model — NOTE(review): likely keeps a
        leading batch dimension, confirm before indexing)
    """
    detector = predictor.det_predictor
    # Gradient tracking is disabled for the whole pass, pre-processing included.
    with torch.no_grad():
        first_batch = detector.pre_processor([image])[0]
        first_batch = first_batch.to(device)
        model_output = detector.model(first_batch, return_model_output=True)
        out_map = model_output["out_map"]
        seg_map = out_map.to("cpu").numpy()
    return seg_map
|
|