'''
M-LSD
Copyright 2021-present NAVER Corp.

Apache License v2.0
'''
# for demo
import os
from flask import Flask, request, session, json, Response, render_template, abort, send_from_directory
import requests
from urllib.request import urlopen
from io import BytesIO
import uuid
import cv2
import time
import argparse

# for tflite
import numpy as np
from PIL import Image
import tensorflow as tf

# for square detector
from utils import pred_squares

os.environ['CUDA_VISIBLE_DEVICES'] = ''  # CPU mode

# flask
app = Flask(__name__)
logger = app.logger
logger.info('init demo app')

# config
parser = argparse.ArgumentParser()
## model parameters
parser.add_argument('--tflite_path', default='./tflite_models/M-LSD_512_large_fp16.tflite', type=str)
parser.add_argument('--input_size', default=512, type=int, help='The size of input images.')
## LSD parameter
parser.add_argument('--score_thr', default=0.10, type=float,
                    help='Discard center points when the score < score_thr.')
## intersection point parameters
parser.add_argument('--outside_ratio', default=0.10, type=float,
                    help='Discard an intersection point when it is located outside a line segment '
                         'farther than line_length * outside_ratio.')
parser.add_argument('--inside_ratio', default=0.50, type=float,
                    help='Discard an intersection point when it is located inside a line segment '
                         'farther than line_length * inside_ratio.')
## ranking boxes parameters
parser.add_argument('--w_overlap', default=0.0, type=float,
                    help='When increasing w_overlap, the final box tends to overlap with the '
                         'detected line segments as much as possible.')
parser.add_argument('--w_degree', default=1.14, type=float,
                    help='When increasing w_degree, the final box tends to be a parallel quadrilateral '
                         'with reference to the angle of the box.')
parser.add_argument('--w_length', default=0.03, type=float,
                    help='When increasing w_length, the final box tends to be a parallel quadrilateral '
                         'with reference to the length of the box.')
parser.add_argument('--w_area', default=1.84, type=float,
                    help='When increasing w_area, the final box tends to be the largest one out of candidates.')
parser.add_argument('--w_center', default=1.46, type=float,
                    help='When increasing w_center, the final box tends to be located in the center of the input image.')
## flask demo parameter
parser.add_argument('--port', default=5000, type=int, help='flask demo will be running on http://0.0.0.0:port/')


class model_graph:
    def __init__(self, args):
        self.interpreter, self.input_details, self.output_details = self.load_tflite(args.tflite_path)
        self.params = {'score': args.score_thr, 'outside_ratio': args.outside_ratio, 'inside_ratio': args.inside_ratio,
                       'w_overlap': args.w_overlap, 'w_degree': args.w_degree, 'w_length': args.w_length,
                       'w_area': args.w_area, 'w_center': args.w_center}
        self.args = args

    def load_tflite(self, tflite_path):
        interpreter = tf.lite.Interpreter(model_path=tflite_path)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        return interpreter, input_details, output_details

    def pred_tflite(self, image):
        segments, squares, score_array, inter_points = pred_squares(image,
                                                                    self.interpreter,
                                                                    self.input_details,
                                                                    self.output_details,
                                                                    [self.args.input_size, self.args.input_size],
                                                                    params=self.params)
        output = {}
        output['segments'] = segments
        output['squares'] = squares
        output['scores'] = score_array
        output['inter_points'] = inter_points
        return output

    def read_image(self, image_url):
        response = requests.get(image_url, stream=True)
        image = np.asarray(Image.open(BytesIO(response.content)).convert('RGB'))
        max_len = 1024
        h, w, _ = image.shape
        org_shape = [h, w]
        max_idx = np.argmax(org_shape)
        max_val = org_shape[max_idx]
        if max_val > max_len:
            min_idx = (max_idx + 1) % 2
            ratio = max_len / max_val
            new_min = org_shape[min_idx] * ratio
            new_shape = [0, 0]
            new_shape[max_idx] = 1024
            new_shape[min_idx] = new_min
            image = cv2.resize(image, (int(new_shape[1]), int(new_shape[0])), interpolation=cv2.INTER_AREA)
        return image

    def init_resize_image(self, im, maximum_size=1024):
        h, w, _ = im.shape
        size = [h, w]
        max_arg = np.argmax(size)
        max_len = size[max_arg]
        min_arg = max_arg - 1
        min_len = size[min_arg]
        if max_len < maximum_size:
            return im
        else:
            ratio = maximum_size / max_len
            max_len = max_len * ratio
            min_len = min_len * ratio
            size[max_arg] = int(max_len)
            size[min_arg] = int(min_len)
            im = cv2.resize(im, (size[1], size[0]), interpolation=cv2.INTER_AREA)
            return im

    def decode_image(self, session_id, rawimg):
        dirpath = os.path.join('static/results', session_id)
        if not os.path.exists(dirpath):
            os.makedirs(dirpath)
        save_path = os.path.join(dirpath, 'input.png')
        input_image_url = os.path.join(dirpath, 'input.png')
        img = cv2.imdecode(np.frombuffer(rawimg, dtype='uint8'), 1)[:, :, ::-1]
        img = self.init_resize_image(img)
        cv2.imwrite(save_path, img[:, :, ::-1])
        return img, input_image_url

    def draw_output(self, image, output, save_path='test.png'):
        color_dict = {'red': [255, 0, 0], 'green': [0, 255, 0], 'blue': [0, 0, 255],
                      'cyan': [0, 255, 255], 'black': [0, 0, 0],
                      'yellow': [255, 255, 0], 'dark_yellow': [200, 200, 0]}
        line_image = image.copy()
        square_image = image.copy()
        square_candidate_image = image.copy()
        line_thick = 5

        # output > line array
        for line in output['segments']:
            x_start, y_start, x_end, y_end = [int(val) for val in line]
            cv2.line(line_image, (x_start, y_start), (x_end, y_end), color_dict['red'], line_thick)

        inter_image = line_image.copy()
        for pt in output['inter_points']:
            x, y = [int(val) for val in pt]
            cv2.circle(inter_image, (x, y), 10, color_dict['blue'], -1)

        for square in output['squares']:
            cv2.polylines(square_candidate_image, [square.reshape([-1, 1, 2])], True,
                          color_dict['dark_yellow'], line_thick)

        for square in output['squares'][0:1]:
            cv2.polylines(square_image, [square.reshape([-1, 1, 2])], True,
                          color_dict['yellow'], line_thick)
            for pt in square:
                cv2.circle(square_image, (int(pt[0]), int(pt[1])), 10, color_dict['cyan'], -1)

        '''
        square image | square candidates image
        inter image  | line image
        '''
        output_image = self.init_resize_image(square_image, 512)
        output_image = np.concatenate([output_image, self.init_resize_image(square_candidate_image, 512)], axis=1)
        output_image_tmp = np.concatenate([self.init_resize_image(inter_image, 512),
                                           self.init_resize_image(line_image, 512)], axis=1)
        output_image = np.concatenate([output_image, output_image_tmp], axis=0)

        cv2.imwrite(save_path, output_image[:, :, ::-1])
        return output_image

    def save_output(self, session_id, input_image_url, image, output):
        dirpath = os.path.join('static/results', session_id)
        if not os.path.exists(dirpath):
            os.makedirs(dirpath)
        save_path = os.path.join(dirpath, 'output.png')
        self.draw_output(image, output, save_path=save_path)
        output_image_url = os.path.join(dirpath, 'output.png')

        rst = {}
        rst['input_image_url'] = input_image_url
        rst['session_id'] = session_id
        rst['output_image_url'] = output_image_url
        with open(os.path.join(dirpath, 'results.json'), 'w') as f:
            json.dump(rst, f)


def init_worker(args):
    global model
    model = model_graph(args)
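
# A minimal offline sketch (not part of the original demo): it shows how the
# model_graph class can be exercised without the Flask front end. It assumes
# the default CLI arguments (in particular that the default tflite model file
# exists) and that `image_path` points to a local image; both are assumptions,
# and this helper is never called by the web app itself.
def _offline_demo(image_path, save_path='offline_output.png'):
    offline_args = parser.parse_args([])          # use the defaults defined above
    graph = model_graph(offline_args)             # loads the tflite interpreter
    img = np.asarray(Image.open(image_path).convert('RGB'))
    img = graph.init_resize_image(img)            # cap the longest side at 1024 px
    out = graph.pred_tflite(img)                  # segments, squares, scores, inter_points
    return graph.draw_output(img, out, save_path=save_path)
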
@app.route('/')
def index():
    return render_template('index_scan.html', session_id='dummy_session_id')


@app.route('/', methods=['POST'])
def index_post():
    request_start = time.time()
    configs = request.form
    session_id = str(uuid.uuid1())
    image_url = configs['image_url']  # image_url

    if len(image_url) == 0:
        bio = BytesIO()
        request.files['image'].save(bio)
        rawimg = bio.getvalue()
        image, image_url = model.decode_image(session_id, rawimg)
    else:
        image = model.read_image(image_url)

    output = model.pred_tflite(image)
    model.save_output(session_id, image_url, image, output)

    return render_template('index_scan.html', session_id=session_id)


@app.route('/favicon.ico')
def favicon():
    return send_from_directory(os.path.join(app.root_path, 'static'),
                               'favicon.ico', mimetype='image/vnd.microsoft.icon')


if __name__ == '__main__':
    args = parser.parse_args()
    init_worker(args)
    app.run(host='0.0.0.0', port=args.port)
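
# Usage sketch, assuming the default tflite model is present at
# ./tflite_models/M-LSD_512_large_fp16.tflite and that templates/index_scan.html
# ships with the repo (the module file name below is a placeholder):
#
#   python <this_module>.py --port 5000
#
# The POST endpoint accepts either a remote image URL (the example URL is
# hypothetical):
#
#   curl -X POST -F "image_url=https://example.com/document.jpg" http://0.0.0.0:5000/
#
# or a local upload with an empty image_url field:
#
#   curl -X POST -F "image_url=" -F "image=@./document.jpg" http://0.0.0.0:5000/
#
# Results are written to static/results/<session_id>/ as input.png, output.png,
# and results.json.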