File size: 16,165 Bytes
ba529ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6261698
ba529ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05496fd
ba529ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a68f143
ba529ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
# %%
import matplotlib.style
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import pickle
import torch
from pathlib import Path
from PIL import Image
from PIL import ImageDraw
import numpy as np
from collections import namedtuple
from logging import getLogger
logger = getLogger(__name__)
# %%
class Florence:
    """Thin wrapper around one Florence-2 checkpoint loaded through ``transformers``.

    The model is loaded once (eval mode, moved to CUDA) and :meth:`run` executes a
    single Florence-2 task prompt against a single image.
    """

    def __init__(self, model_id:str, hack=False):
        """Load the model and processor for *model_id*.

        Args:
            model_id: Hugging Face hub id. Loaded with ``trust_remote_code=True``,
                i.e. Python code shipped in the hub repo is executed.
            hack: when true, skip loading entirely (e.g. when only cached results
                are used). ``model`` and ``processor`` are then ``None`` and
                :meth:`run` must not be called.
        """
        self.model_id = model_id
        if hack:
            # Keep the attributes defined so a misuse fails with a clear message in run().
            self.model = None
            self.processor = None
            return
        self.model = (
            AutoModelForCausalLM.from_pretrained(
                # "auto" keeps the dtype recorded in the checkpoint config
                model_id, trust_remote_code=True, torch_dtype="auto"
            )
            .eval()
            .cuda()
        )
        self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

    def run(self, img:Image.Image, task_prompt:str, extra_text:str|None=None):
        """Run one Florence-2 task on *img* and return the post-processed answer.

        Args:
            img: input PIL image; its size is used to rescale coordinates.
            task_prompt: Florence-2 task token such as ``"<OD>"``.
            extra_text: optional text appended to the task token (e.g. the phrase
                to ground or segment).
        Returns:
            The result of ``processor.post_process_generation``; callers in this
            file expect a single-entry dict.
        Raises:
            RuntimeError: if the instance was created with ``hack=True``.
        """
        logger.debug("run %s %s", task_prompt, extra_text)
        model, processor = self.model, self.processor
        if model is None or processor is None:
            raise RuntimeError(f"Florence({self.model_id!r}) was created with hack=True; no model loaded")
        prompt = task_prompt + (extra_text or "")
        # Move/cast inputs to wherever the model actually lives instead of hard-coding
        # "cuda"/float16; BatchFeature.to only casts floating tensors, so input_ids stay integer.
        inputs = processor(text=prompt, images=img, return_tensors="pt").to(
            model.device, model.dtype
        )
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,  # deterministic beam search
            num_beams=3,
        )
        # keep special tokens: post_process_generation parses the location tokens
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(img.width, img.height),
        )
        return parsed_answer
def model_init(hack=False):
    """Create the two Florence-2 wrappers used by the pipeline.

    Returns:
        ``(base, finetuned)``: wrappers for ``Florence-2-large`` and
        ``Florence-2-large-ft``, in that order. *hack* is forwarded to both.
    """
    checkpoints = ("microsoft/Florence-2-large", "microsoft/Florence-2-large-ft")
    base, finetuned = (Florence(checkpoint, hack=hack) for checkpoint in checkpoints)
    return base, finetuned
#%%
# Florence-2 task tokens passed as the `task_prompt` of Florence.run
TASK_OD = "<OD>"
TASK_SEGMENTATION = '<REFERRING_EXPRESSION_SEGMENTATION>'
# was a copy of the grounding token; the Florence-2 captioning task is "<CAPTION>"
TASK_CAPTION = "<CAPTION>"
TASK_OCR = "<OCR_WITH_REGION>"
TASK_GROUNDING = "<CAPTION_TO_PHRASE_GROUNDING>"
#%%
# Raw model outputs for one image, as assembled by get_ai_model_result.
AIModelResult = namedtuple(
    'AIModelResult',
    [
        'img',              # full input image
        'img2',             # image cropped to meter_bbox
        'meter_bbox',       # grounding bbox of the meter in img
        'needle_polygons',  # needle segmentation on img2
        'circle_polygons',  # circle segmentation on img2
        'ocr1',             # OCR-with-region result from the base model
        'ocr2',             # OCR-with-region result from the fine-tuned model
    ],
)
# Memo of AIModelResult values keyed by image file name.
cached_results: dict[str, AIModelResult] = {}

#%%
def get_meter_bbox(fl:Florence, img:Image):
    """Locate the meter in *img* with phrase grounding and return its bbox.

    Asserts that the model returned exactly one box labelled
    ``'a circular meter'``.
    """
    answer = fl.run(img, TASK_GROUNDING, "a circular meter with white background")
    assert len(answer) == 1
    _task, grounding = answer.popitem()
    assert 'bboxes' in grounding
    assert 'labels' in grounding
    bboxes, labels = grounding['bboxes'], grounding['labels']
    assert len(bboxes) == 1
    assert len(labels) == 1
    assert labels[0] == 'a circular meter'
    return bboxes[0]

def get_circles(fl:Florence, img2:Image.Image, polygons:list):
    """Segment "a circle" inside the region covered by *polygons* on *img2*.

    Pixels outside *polygons* are painted white, so the model only sees that
    region. get_ai_model_result passes the needle polygons here, and read_meter
    uses the centroid of the returned polygon as the dial centre.

    Returns:
        The single segmentation polygon produced by the model.
    """
    mask_img = Image.new('L', img2.size, color='black')
    mask_draw = ImageDraw.Draw(mask_img)
    for polygon in polygons:
        mask_draw.polygon(polygon, outline='white', width=3, fill='white')
    keep = np.array(mask_img)[:, :, None] > 0
    # Force three channels: for 'L'/'P' crops np.array() is 2-D and would not
    # broadcast against the (H, W, 1) mask.
    pixels = np.array(img2.convert('RGB'))
    masked = Image.fromarray(np.where(keep, pixels, 255).astype(np.uint8))
    parsed_answer = fl.run(masked, TASK_SEGMENTATION, "a circle")
    assert len(parsed_answer) == 1, f"expected one task result, got {len(parsed_answer)}"
    _task, segmentation = parsed_answer.popitem()
    assert 'polygons' in segmentation, "segmentation result has no 'polygons'"
    assert len(segmentation['polygons']) == 1, "expected exactly one circle"
    return segmentation['polygons'][0]

def get_needle_polygons(fl:Florence, img2:Image):
    """Segment the needle in the cropped meter image *img2*.

    Returns the single segmentation polygon produced by the model. The prompt
    text (including its spelling) is passed to the model verbatim; editing it
    changes the model output.
    """
    answer = fl.run(img2, TASK_SEGMENTATION, "the long narrow black needle hand pass through the center of the cicular meter")
    assert len(answer) == 1
    _task, segmentation = answer.popitem()
    assert 'polygons' in segmentation
    assert len(segmentation['polygons']) == 1
    return segmentation['polygons'][0]

def get_ocr(fl:Florence, img2:Image):
    """Run OCR-with-region on *img2* and return the single task result.

    read_meter reads ``'quad_boxes'`` and ``'labels'`` from the returned value.
    """
    answer = fl.run(img2, TASK_OCR)
    assert len(answer) == 1
    return answer.popitem()[1]

def get_ai_model_result(img:Image.Image|Path|str, fl:Florence, fl_ft:Florence):
    """Run the full Florence-2 pipeline on one meter image, memoised by file name.

    Steps: ground the meter bbox, crop it, then segment the needle and the
    circle, and OCR the crop with both the base and the fine-tuned model.

    Args:
        img: a PIL image (never cached) or a path/str to an image file.
        fl: base Florence-2 wrapper (grounding, segmentation, first OCR).
        fl_ft: fine-tuned Florence-2 wrapper (second OCR).
    Returns:
        AIModelResult for the image.
    """
    # NOTE: the cache key is the bare file name, so files with the same name in
    # different directories share one cache entry.
    key = Path(img).name if isinstance(img, (Path, str)) else None
    if key is not None and key in cached_results:
        return cached_results[key]
    if key is not None:
        img = Image.open(img)
        # Image.open is lazy; load() decodes the pixels now, and for
        # single-frame formats such as JPEG Pillow then closes the file handle.
        img.load()
    meter_bbox = get_meter_bbox(fl, img)
    img2 = img.crop(meter_bbox)
    needle_polygons = get_needle_polygons(fl, img2)
    result = AIModelResult(img, img2, meter_bbox, needle_polygons,
                           get_circles(fl, img2, needle_polygons),
                           get_ocr(fl, img2),
                           get_ocr(fl_ft, img2),
                           )
    if key is not None:
        cached_results[key] = result
    return result
#%%
from skimage.measure import regionprops
from skimage.measure import EllipseModel
from skimage.draw import ellipse_perimeter
def get_regionprops(polygons:list):
    """Rasterise *polygons* into one binary mask and return its region properties.

    All polygons are filled into a single mask, and the properties of label 1
    (the union of every filled polygon) are returned. read_meter uses
    ``.centroid``, which skimage reports in (row, col) order.

    Args:
        polygons: (x, y) coordinate sequences (flattened and reshaped to
            ``(-1, 2)``), as returned by Florence-2 segmentation.
    Returns:
        skimage ``RegionProperties`` of the filled mask.
    Raises:
        ValueError: if the polygons fill no pixel.
    """
    coords = np.concatenate(polygons).reshape(-1, 2)
    # +2 margin keeps the maximum coordinate (and the 1px outline) on the canvas
    width, height = (coords.max(axis=0) + 2).astype('int')
    mask_img = Image.new('L', (int(width), int(height)), color='black')
    draw = ImageDraw.Draw(mask_img)
    for polygon in polygons:
        draw.polygon(polygon, outline='white', width=1, fill='white')
    mask = (np.array(mask_img) > 0).astype(np.uint8)
    regions = regionprops(mask)
    if not regions:
        raise ValueError("polygons do not cover any pixel")
    return regions[0]
def estimate_ellipse(coords, enlarge_factor=1.0):
    """Fit an ellipse to (x, y) *coords* and rasterise its perimeter.

    The columns are swapped to (y, x) before fitting, so the first centre
    parameter is the row (y) coordinate. Both axes are scaled by
    *enlarge_factor*.

    Returns:
        ``(em_params, theta, (rows, cols))``: ``em_params`` holds the rounded
        integers ``[y_centre, x_centre, a, b]``, ``theta`` is the fitted
        orientation, and ``rows, cols`` are the perimeter pixel indices from
        ``ellipse_perimeter``.
    """
    model = EllipseModel()
    model.estimate(coords[:, ::-1])
    y_centre, x_centre, semi_a, semi_b, theta = model.params
    scaled = [y_centre, x_centre, semi_a*enlarge_factor, semi_b*enlarge_factor]
    em_params = np.round(scaled).astype('int')
    rows, cols = ellipse_perimeter(*em_params, orientation=-theta)
    return em_params, theta, (rows, cols)
def estimate_line(coords):
    """Fit a line through *coords* with skimage ``LineModelND``.

    Returns the model's ``params``; read_meter unpacks them as
    ``(origin, direction)``.
    """
    model = LineModelND()
    model.estimate(coords)
    return model.params
#%%
#%%
from matplotlib import pyplot as plt
import matplotlib
from skimage.measure import LineModelND, ransac
matplotlib.style.use('dark_background')
def rotate_theta(theta):
    """Turn an ``arctan2`` angle (radians) into dial degrees in ``[0, 360)``.

    The angle is shifted by 3*pi/2, wrapped into ``[0, 2*pi)`` and scaled to
    degrees: 0 rad -> 270, pi -> 90, -pi/2 -> 180. For image coordinates
    (y pointing down) this measures the angle clockwise from straight down.
    Works element-wise on numpy arrays.
    """
    wrapped = (theta + 3*np.pi/2) % (2*np.pi)
    return wrapped / (2*np.pi) * 360
# OCR texts recognised as dial labels: odd kg/cm^2 marks 1..11 and psi marks 20..160.
kg_cm2_labels = [str(value) for value in range(1, 12, 2)]
psi_labels = [str(value) for value in range(20, 180, 20)]

# Everything read_meter computes, bundled so the visualisation helpers can
# reuse the intermediate values (loose coupling between read-out and plotting).
MeterResult = namedtuple('MeterResult', [
    # read-out
    'result',          # AIModelResult the reading was derived from
    'needle_psi',
    'needle_kg_cm2',
    'needle_theta',    # needle angle in dial degrees
    # needle / dial geometry
    'orign',           # origin of the fitted needle line
    'direction',       # needle direction, oriented away from the centre
    'center',          # dial centre (x, y)
    # theta -> psi calibration fit
    'lm',
    'inliers',
    # recognised scale labels
    'kg_cm2_texts',
    'psi_texts',
    'kg_cm2_centers',
    'psi_centers',
    'kg_cm2_theta',
    'psi_theta',
    'kg_cm2_psi',
    'psi',
])

# kgf/cm^2 -> psi factor used to put both dial scales on one psi axis.
_KG_CM2_TO_PSI = 14.223


def _collect_label_boxes(ocr_results, wanted):
    """Map ``int(label) -> quad_box`` for every OCR label whose text is in *wanted*.

    Results are scanned in order, so a label seen again later overwrites the
    earlier box.
    """
    boxes = {}
    for ocr in ocr_results:
        for qbox, label in zip(ocr['quad_boxes'], ocr['labels']):
            # NOTE(review): Florence-2 OCR output has been observed to carry a
            # leading '</s>' token on a label; drop it and surrounding whitespace.
            text = label.replace('</s>', '').strip()
            if text in wanted:
                boxes[int(text)] = qbox
    return boxes


def _label_centers_and_thetas(texts, center):
    """Return the quad-box centres of *texts* and their dial angles around *center*."""
    centers = np.array(list(texts.values())).reshape(-1, 4, 2).mean(axis=1)
    offsets = centers - center
    return centers, rotate_theta(np.arctan2(offsets[:, 1], offsets[:, 0]))


def read_meter(img:Image.Image|str|Path, fl, fl_ft, *, ransac_max_trials:int=100):
    """Read the needle value of a dual-scale (psi and kg/cm^2) pressure gauge.

    The needle angle is measured around the dial centre. Recognised scale labels
    are converted to (angle, psi) pairs, a line is fitted through them with
    RANSAC, and that line is evaluated at the needle angle.

    Args:
        img: image or image path (see get_ai_model_result).
        fl, fl_ft: base and fine-tuned Florence wrappers.
        ransac_max_trials: RANSAC iteration budget for the angle -> psi fit.
    Returns:
        MeterResult with the reading and all intermediate values.
    Raises:
        ValueError: if fewer than 3 scale labels were recognised.
        RuntimeError: if RANSAC finds no model.
    """
    result = get_ai_model_result(img, fl, fl_ft)

    # needle axis: line fitted through all needle polygon vertices
    coords = np.concatenate(result.needle_polygons).reshape(-1, 2)
    orign, direction = estimate_line(coords)

    # dial centre: centroid of the circle polygons, converted from (row, col) to (x, y)
    center = get_regionprops(result.circle_polygons).centroid[::-1]

    # Orient the direction from the centre towards the fitted line origin,
    # which is assumed to lie on the needle-head side of the pivot.
    if (orign - center) @ direction < 0:
        direction = -direction
    needle_theta = rotate_theta(np.arctan2(direction[1], direction[0]))

    # find the kg/cm^2 and psi scale labels among both OCR results
    ocr_results = (result.ocr1, result.ocr2)
    kg_cm2_texts = _collect_label_boxes(ocr_results, kg_cm2_labels)
    psi_texts = _collect_label_boxes(ocr_results, psi_labels)

    # label positions as angles (degrees) around the dial centre
    kg_cm2_centers, kg_cm2_theta = _label_centers_and_thetas(kg_cm2_texts, center)
    psi_centers, psi_theta = _label_centers_and_thetas(psi_texts, center)

    # express kg/cm^2 labels in psi so both scales feed one angle -> psi fit
    kg_cm2 = np.array(list(kg_cm2_texts.keys()))
    kg_cm2_psi = kg_cm2 * _KG_CM2_TO_PSI
    psi = np.array(list(psi_texts.keys()))
    Y = np.concatenate([kg_cm2_psi, psi])
    X = np.concatenate([kg_cm2_theta, psi_theta])
    data = np.stack([X, Y], axis=1)
    if len(data) < 3:
        raise ValueError(f"need at least 3 recognised scale labels to calibrate, got {len(data)}")

    # RANSAC keeps mis-read / mis-located labels out of the fit; a handful of
    # trials is not enough to reliably draw an all-inlier sample.
    lm, inliers = ransac(data, LineModelND, min_samples=3,
                         residual_threshold=15,
                         max_trials=ransac_max_trials)
    if lm is None:
        raise RuntimeError("RANSAC could not fit an angle -> psi line to the scale labels")

    needle_psi = lm.predict(needle_theta)[1]
    needle_kg_cm2 = needle_psi / _KG_CM2_TO_PSI

    return MeterResult(result=result,
                       needle_psi=needle_psi,
                       needle_kg_cm2=needle_kg_cm2,
                       needle_theta=needle_theta,
                       orign=orign,
                       direction=direction,
                       center=center,
                       lm=lm,
                       inliers=data[inliers].T,
                       kg_cm2_texts=kg_cm2_texts,
                       psi_texts=psi_texts,
                       kg_cm2_centers=kg_cm2_centers,
                       psi_centers=psi_centers,
                       kg_cm2_theta=kg_cm2_theta,
                       psi_theta=psi_theta,
                       kg_cm2_psi=kg_cm2_psi,
                       psi=psi,
                       )


def more_visualization_data(meter_result:MeterResult):
    """Derive plotting/debug data from a MeterResult.

    Returns:
        ``(inlier_theta, inlier_psi, predict_theta, predict_psi, needle_head,
        direction)``. ``needle_head`` is the needle vertex farthest from the
        dial centre. ``direction`` is the needle direction scaled to that
        distance. ``predict_*`` sample the fitted line over 0..360 degrees.
    """
    centre = meter_result.center
    vertices = np.concatenate(meter_result.result.needle_polygons).reshape(-1, 2)
    distances = np.linalg.norm(vertices - centre, axis=1)
    head_index = np.argmax(distances)
    needle_head = vertices[head_index]
    scaled_direction = meter_result.direction * distances[head_index]

    inlier_theta, inlier_psi = meter_result.inliers

    predict_theta = np.linspace(0, 360, 100)
    predict_psi = meter_result.lm.predict(predict_theta)[:, 1]
    return inlier_theta, inlier_psi, predict_theta, predict_psi, needle_head, scaled_direction

def _draw_label_ellipse(canvas, centers, outline, rgb):
    """Fit an ellipse through label *centers*, mark its centre and paint its perimeter.

    The centre marker is drawn on *canvas*. The perimeter is painted on an array
    copy, and that copy is returned as a new image.
    """
    em_params, _theta, (rows, cols) = estimate_ellipse(centers)
    y, x = em_params[:2]
    ImageDraw.Draw(canvas).ellipse((x-5, y-5, x+5, y+5), outline=outline, width=1)
    pixels = np.array(canvas)
    height, width = pixels.shape[:2]
    # The fitted ellipse may extend past the crop: drop those perimeter pixels
    # instead of raising IndexError (or wrapping around for negative indices).
    inside = (rows >= 0) & (rows < height) & (cols >= 0) & (cols < width)
    pixels[rows[inside], cols[inside]] = rgb
    return Image.fromarray(pixels)


def visualization(meter_result:MeterResult):
    """Render an annotated meter crop and a theta -> psi calibration plot.

    Returns:
        ``(image, figure)``: the annotated RGB crop and the current matplotlib
        figure (drawn with the pyplot state API).
    """
    result = meter_result.result
    center = meter_result.center
    needle_psi, needle_kg_cm2 = meter_result.needle_psi, meter_result.needle_kg_cm2
    inlier_theta, inlier_psi, predict_theta, predict_psi, needle_head, direction = more_visualization_data(meter_result)

    # needle polygons on a copy of the cropped meter
    canvas = result.img2.copy()
    draw = ImageDraw.Draw(canvas)
    for polygon in result.needle_polygons:
        draw.polygon(polygon, outline='red', width=3)

    # translucent fill of the circle polygons, composited over the crop
    base = canvas.convert('RGBA')
    overlay = Image.new('RGBA', base.size, (0, 0, 0, 0))
    overlay_draw = ImageDraw.Draw(overlay)
    for polygon in result.circle_polygons:
        overlay_draw.polygon(polygon, outline='purple', width=1, fill=(255, 128, 255, 100))
    canvas = Image.alpha_composite(base, overlay).convert('RGB')
    draw = ImageDraw.Draw(canvas)

    # needle direction, dial centre and needle head
    draw.line((center[0], center[1], center[0]+direction[0], center[1]+direction[1]), fill='yellow', width=3)
    draw.ellipse((center[0]-5, center[1]-5, center[0]+5, center[1]+5), outline='yellow', width=3)
    draw.ellipse((needle_head[0]-5, needle_head[1]-5, needle_head[0]+5, needle_head[1]+5), outline='yellow', width=3)

    # recognised scale labels: kg/cm^2 in blue, psi in green
    for x, y in meter_result.kg_cm2_centers:
        draw.ellipse((x-3, y-3, x+3, y+3), outline='blue', width=3)
    for x, y in meter_result.psi_centers:
        draw.ellipse((x-3, y-3, x+3, y+3), outline='green', width=3)
    for label, quad_box in meter_result.kg_cm2_texts.items():
        draw.polygon(quad_box, outline='blue', width=1)
        draw.text((quad_box[0], quad_box[1]-10), str(label), fill='blue', anchor='ls')
    for label, quad_box in meter_result.psi_texts.items():
        draw.polygon(quad_box, outline='green', width=1)
        draw.text((quad_box[0], quad_box[1]-10), str(label), fill='green', anchor='ls')

    # Ellipses through the label centres (visualisation only); EllipseModel needs >= 5 points.
    if len(meter_result.kg_cm2_centers) > 4:
        canvas = _draw_label_ellipse(canvas, meter_result.kg_cm2_centers, 'blue', (0, 0, 255))
    if len(meter_result.psi_centers) > 4:
        # centre marker now uses the psi ellipse's own centre (it previously
        # reused the kg/cm^2 x, y or failed when that branch had not run)
        canvas = _draw_label_ellipse(canvas, meter_result.psi_centers, 'green', (0, 255, 0))
    draw = ImageDraw.Draw(canvas)
    draw.text((needle_head[0]-10, needle_head[1]-10),
              f'psi={needle_psi:.1f} kg_cm2={needle_kg_cm2:.2f}', anchor='ls',
              fill='yellow')

    # calibration plot: fitted line, both label sets, RANSAC inliers, needle reading
    plt.plot(predict_theta, predict_psi, color='red', alpha=0.5)
    plt.plot(meter_result.kg_cm2_theta, meter_result.kg_cm2_psi, 'o', color='#77F')
    plt.plot(meter_result.psi_theta, meter_result.psi, 'o', color='#7F7')
    plt.plot(inlier_theta, inlier_psi, 'x', color='red', alpha=0.5)
    plt.vlines(meter_result.needle_theta, 0, 160, colors='yellow', alpha=0.5)
    plt.hlines(meter_result.needle_psi, 0, 360, colors='yellow', alpha=0.5)

    plt.text(meter_result.needle_theta-20, meter_result.needle_psi-20,
             f'psi={needle_psi:.1f} kg_cm2={needle_kg_cm2:.2f}', color='yellow')
    plt.xlim(0, 360)
    plt.ylim(0, 160)
    return canvas, plt.gcf()

def clear_cache():
    """Drop every memoised AIModelResult."""
    cached_results.clear()


def save_cache(path='cached_results.pkl'):
    """Pickle the in-memory result cache to *path*."""
    with open(path, 'wb') as f:
        pickle.dump(cached_results, f)


def load_cache(path='cached_results.pkl'):
    """Replace the in-memory result cache with the pickle stored at *path*.

    SECURITY: pickle.load can execute arbitrary code; only load cache files
    this program wrote itself.
    """
    with open(path, 'rb') as f:
        loaded = pickle.load(f)
    # Update in place rather than rebinding the global, so code that imported
    # `cached_results` by name sees the loaded entries as well.
    cached_results.clear()
    cached_results.update(loaded)
#%%
if __name__ == '__main__':
    # Notebook-style demo: read every image in images/good and display the
    # annotated crop plus the calibration plot.
    from io import BytesIO
    from IPython.display import display

    fl, fl_ft = model_init(hack=False)
    clear_cache()  # swap for load_cache() to reuse results from a previous run
    MAX_W, MAX_H = 640, 480
    for img_fn in list(Path('images/good').glob('*.jpg')):
        print(img_fn)
        meter_result = read_meter(img_fn, fl, fl_ft)
        img, fig = visualization(meter_result)

        # fit the annotated crop into MAX_W x MAX_H, keeping its aspect ratio
        crop_w, crop_h = meter_result.result.img2.size
        if crop_w / MAX_W > crop_h / MAX_H:
            display_size = (MAX_W, int(crop_h * MAX_W / crop_w))
        else:
            display_size = (int(crop_w * MAX_H / crop_h), MAX_H)
        display(img.resize(display_size))

        # render the figure to an in-memory PNG so it displays as an image
        png = BytesIO()
        fig.savefig(png, format='png')
        png.seek(0)
        display(Image.open(png))

        plt.clf()  # start the next image with an empty figure






# %%