# %%
import spaces  # Hugging Face ZeroGPU helper; kept as the first import, as in the original script
import matplotlib.style
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image, ImageDraw
import torch
from pathlib import Path
from IPython.display import display
import numpy as np
from collections import namedtuple
import sys

print(sys.version_info)


#%%
class Florence:
    """Thin wrapper around a Florence-2 checkpoint and its processor.

    Args:
        model_id: Hugging Face model id passed to ``from_pretrained``.
        hack: when True, skip loading entirely; the instance then has no
            ``model``/``processor``/``model_id`` and ``run`` cannot be used.
    """

    def __init__(self, model_id: str, hack=False):
        if hack:
            return
        self.model = (
            AutoModelForCausalLM.from_pretrained(
                model_id, trust_remote_code=True, torch_dtype="auto"
            )
            .eval()
            .cuda()
        )
        self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        self.model_id = model_id

    def run(self, img: Image.Image, task_prompt: str, extra_text: str | None = None):
        """Run one Florence-2 task on ``img`` and return the parsed answer.

        Args:
            img: input image.
            task_prompt: Florence-2 task token (one of the ``TASK_*`` constants);
                also passed as ``task`` to ``post_process_generation``.
            extra_text: optional text appended to the task token (e.g. the
                phrase to segment).

        Returns:
            The dict produced by ``processor.post_process_generation``.
        """
        model, processor = self.model, self.processor
        prompt = task_prompt + (extra_text or "")
        # Cast floating-point inputs to the model's own device/dtype instead of
        # hard-coding float16: torch_dtype="auto" may resolve to another dtype.
        inputs = processor(text=prompt, images=img, return_tensors="pt").to(
            model.device, model.dtype
        )
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        # Special tokens are kept: Florence-2 post-processing presumably parses
        # its location tokens out of the raw decoded text.
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(img.width, img.height),
        )
        return parsed_answer


def model_init():
    """Load the base and fine-tuned Florence-2-large models (both on CUDA)."""
    fl = Florence("microsoft/Florence-2-large", hack=False)
    fl_ft = Florence("microsoft/Florence-2-large-ft", hack=False)
    return fl, fl_ft


# florence-2 tasks: angle-bracket task tokens from the Florence-2 model card;
# post_process_generation dispatches on them, so they must not be empty.
TASK_OD = "<OD>"
TASK_SEGMENTATION = "<REFERRING_EXPRESSION_SEGMENTATION>"  # text prompt -> 'polygons'
TASK_CAPTION = "<CAPTION>"  # NOTE(review): unused here; caption variant is an assumption
TASK_OCR = "<OCR_WITH_REGION>"  # get_ocr reads 'quad_boxes', returned by this task
TASK_GROUNDING = "<CAPTION_TO_PHRASE_GROUNDING>"  # NOTE(review): unused here

#%%
from skimage.measure import LineModelND, ransac


def get_polygons(fl: Florence, img2: Image.Image, prompt):
    """Segment the single object described by ``prompt``.

    Returns:
        The polygon list of that one instance (read_meter concatenates them
        and reshapes to (-1, 2) x/y pairs).

    Raises:
        ValueError: if the answer does not hold exactly one segmented instance.
    """
    parsed_answer = fl.run(img2, TASK_SEGMENTATION, prompt)
    if len(parsed_answer) != 1:
        raise ValueError(f"expected one task result, got keys {list(parsed_answer)}")
    _, result = parsed_answer.popitem()
    if 'polygons' not in result or len(result['polygons']) != 1:
        raise ValueError(f"expected exactly one segmented instance for {prompt!r}")
    return result['polygons'][0]


def get_ocr(fl: Florence, img2: Image.Image):
    """Run OCR (``TASK_OCR``) and return the per-task result dict.

    read_meter reads its parallel 'quad_boxes' and 'labels' lists.

    Raises:
        ValueError: if the answer does not hold exactly one task result.
    """
    parsed_answer = fl.run(img2, TASK_OCR)
    if len(parsed_answer) != 1:
        raise ValueError(f"expected one task result, got keys {list(parsed_answer)}")
    _, result = parsed_answer.popitem()
    return result


# Sorted so the images are processed in a deterministic order.
imgs = sorted(Path('images/other').glob('*.jpg'))
# Scale labels printed on the meter, as OCR text: '0', '100', ..., '500'.
meter_labels = list(map(str, range(0, 600, 100)))


def read_meter(img, fl: Florence, fl_ft: Florence):
    """Estimate the value indicated by the red pointer on a linear meter scale.

    The scale labels in ``meter_labels`` are located by OCR (both models), a
    straight line is fitted through their centers, and a second line maps a
    label value to a position along that scale line. The segmented red pointer
    is projected onto the scale line and converted back to a value.

    Args:
        img: a PIL image, or a path (str/Path) to open.
        fl: base Florence-2 wrapper (OCR and pointer segmentation).
        fl_ft: fine-tuned Florence-2 wrapper (OCR only).

    Returns:
        (value, annotated): the reading clipped to the labelled range, and a
        copy of the image with OCR boxes (green), fitted scale (yellow) and
        pointer position (red) drawn on it. The input image is not modified.

    Raises:
        ValueError: if fewer than two distinct scale labels are recognised.
    """
    if isinstance(img, (str, Path)):
        print(img)
        # copy() loads the pixel data, so the file handle can be closed.
        with Image.open(img) as opened:
            img = opened.copy()

    ticks = [int(label) for label in meter_labels]
    lo, hi = min(ticks), max(ticks)

    red_polygons = get_polygons(fl, img, 'red triangle pointer')

    # Map each recognised scale value to its OCR quad box. Each model's boxes
    # and labels are zipped separately so pairs stay aligned; fine-tuned
    # results come last and win when both models find the same label.
    ocr_text = {}
    for result in (get_ocr(fl, img), get_ocr(fl_ft, img)):
        for quad_box, label in zip(result['quad_boxes'], result['labels']):
            # Decoding keeps special tokens, so a label may carry a '</s>' prefix.
            label = label.replace('</s>', '').strip()
            if label in meter_labels:
                ocr_text[int(label)] = quad_box
    if len(ocr_text) < 2:
        raise ValueError(f"need at least two scale labels, found {sorted(ocr_text)}")

    annotated = img.copy()
    draw = ImageDraw.Draw(annotated)
    for value, quad_box in ocr_text.items():
        draw.polygon(quad_box, outline='green', width=3)
        draw.text((quad_box[0], quad_box[1] - 10), str(value), fill='green', anchor='ls')

    # Each quad box is 4 (x, y) corners; their mean is the label center.
    centers = np.array(list(ocr_text.values())).reshape(-1, 4, 2).mean(axis=1)
    scale_line = LineModelND()
    scale_line.estimate(centers)
    origin, direction = scale_line.params

    def to_xy(pos):
        """Convert a 1-D position along the scale line back to image x, y."""
        return pos * direction + origin

    # 1-D coordinate of every label center along the scale line.
    positions = (centers - origin) @ direction

    # Linear relation between label value (column 0) and position (column 1);
    # dict iteration order keeps keys aligned with ``positions``.
    values = np.array(list(ocr_text.keys()))
    value_to_pos = LineModelND()
    value_to_pos.estimate(np.stack([values, positions], axis=1))
    tick_positions = value_to_pos.predict(ticks)[:, 1]

    x0, y0 = to_xy(tick_positions[0])
    x1, y1 = to_xy(tick_positions[-1])
    draw.line((x0, y0, x1, y1), fill='yellow', width=3)
    for pos in tick_positions:
        x, y = to_xy(pos)
        draw.ellipse((x - 5, y - 5, x + 5, y + 5), outline='yellow', width=3)

    # Project every pointer polygon vertex onto the scale line and average.
    red_coords = np.concatenate(red_polygons).reshape(-1, 2)
    red_pos = ((red_coords - origin) @ direction).mean()
    red_value = np.clip(value_to_pos.predict_x([red_pos]), lo, hi)
    red_pos = value_to_pos.predict_y(red_value)[0]
    rx, ry = to_xy(red_pos)
    draw.ellipse((rx - 5, ry - 5, rx + 5, ry + 5), outline='red', width=3)
    return red_value[0], annotated


@spaces.GPU
def main():
    """Load both models, read every image in ``imgs`` and display the annotations."""
    fl, fl_ft = model_init()
    for img_fn in imgs:
        print(img_fn)
        with Image.open(img_fn) as img:
            red_i, img2 = read_meter(img, fl, fl_ft)
        print(red_i)
        display(img2)


if __name__ == '__main__':
    main()

#%%