# Hugging Face Space page header (Space status at capture time: Sleeping)
import gradio as gr | |
from transformers import TrOCRProcessor, VisionEncoderDecoderModel | |
import requests | |
from PIL import Image | |
import numpy as np | |
import cv2 | |
# Hugging Face Hub checkpoints. The stock handwritten-TrOCR processor supplies
# image preprocessing and the tokenizer; the model checkpoint is, judging by
# its name, a TrOCR fine-tuned on MNIST (not verifiable from this file).
PROCESSOR_CHECKPOINT = "microsoft/trocr-base-handwritten"
MODEL_CHECKPOINT = "aico/TrOCR-MNIST"

processor = TrOCRProcessor.from_pretrained(PROCESSOR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_CHECKPOINT)
def _group_rectangles(rec): | |
""" | |
Uion intersecting rectangles. | |
Args: | |
rec - list of rectangles in form [x, y, w, h] | |
Return: | |
list of grouped ractangles | |
""" | |
tested = [False for i in range(len(rec))] | |
final = [] | |
i = 0 | |
while i < len(rec): | |
if not tested[i]: | |
j = i+1 | |
while j < len(rec): | |
if not tested[j] and intersect_area(rec[i], rec[j]): | |
rec[i] = union(rec[i], rec[j]) | |
tested[j] = True | |
j = i | |
j += 1 | |
final += [rec[i]] | |
i += 1 | |
return final | |
def _digit_boxes(thresh):
    """Return merged stroke bounding boxes found in the binary image thresh, sorted left to right."""
    found = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    # OpenCV 3 returns (image, contours, hierarchy); OpenCV 4 returns (contours, hierarchy).
    cnts = found[0] if len(found) == 2 else found[1]
    height, width = thresh.shape[:2]
    # A contour spanning the whole canvas is the background, not a digit.
    full_frame = (0, 0, width, height)
    boxes = [box for box in map(cv2.boundingRect, cnts) if tuple(box) != full_frame]
    # Sort after grouping so merged boxes are ordered by their final left edge.
    return sorted(_group_rectangles(boxes), key=lambda box: box[0])


def _digit_crop(thresh, box, size=(28, 28), pad=30):
    """Cut one (x, y, w, h) box out of thresh, pad it, resize it and return an RGB PIL image."""
    x, y, w, h = box
    # thresh is THRESH_BINARY_INV, so bitwise_not restores the polarity of the
    # binarised input image. NOTE(review): the black (0) border assumes the
    # sketchpad draws light strokes on a dark background (MNIST style) —
    # confirm against the Gradio input and what aico/TrOCR-MNIST expects.
    digit = cv2.bitwise_not(thresh[y:y + h, x:x + w])
    padded = cv2.copyMakeBorder(
        digit, pad, pad, pad, pad, cv2.BORDER_CONSTANT, value=[0, 0, 0]
    )
    resized = cv2.resize(padded, size, interpolation=cv2.INTER_AREA)
    return Image.fromarray(resized.astype("uint8")).convert("RGB")


def _recognize(images):
    """Run TrOCR on a list of PIL images in one batch; return one decoded string per image."""
    pixel_values = processor(images, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)


def process_image(image):
    """Recognise every hand-drawn digit on a sketchpad image, left to right.

    The image is binarised with Otsu's threshold. Each contour's bounding box
    becomes a candidate, and overlapping boxes are merged into one digit.
    Every digit crop is padded, resized to 28x28 and recognised by the TrOCR
    model.

    Args:
        image: array delivered by the Gradio "sketchpad" input, or None when
            nothing was drawn. cv2.threshold with THRESH_OTSU only accepts
            8-bit single-channel input; 3/4-channel arrays are converted to
            grayscale here. Presumably the array is already uint8 — TODO
            confirm against the installed Gradio version.

    Returns:
        The recognised text of all digits concatenated left to right, or ""
        when no strokes are found.
    """
    if image is None:
        return ""
    gray = np.asarray(image)
    if gray.ndim == 3:
        code = cv2.COLOR_RGBA2GRAY if gray.shape[2] == 4 else cv2.COLOR_RGB2GRAY
        gray = cv2.cvtColor(gray, code)
    thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
    digits = [_digit_crop(thresh, box) for box in _digit_boxes(thresh)]
    if not digits:
        return ""
    return "".join(_recognize(digits))
# Gradio UI: draw digits on the sketchpad, get the recognised text back as a label.
title = "Interactive demo: Single Digits MNIST"
description = "Aico - University Utrecht"

iface = gr.Interface(
    fn=process_image,
    inputs="sketchpad",
    outputs="label",
    title=title,
    description=description,
)
# debug=True: see the Gradio `launch` documentation (blocks the main thread and prints errors).
iface.launch(debug=True)