# gauge / aimodel.py
# Page header captured along with the source (not part of the program):
#   tjw's picture | update | f5deb4a | raw | history blame | 16.2 kB
# %%
import matplotlib.style
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import pickle
import torch
from pathlib import Path
from PIL import Image
from PIL import ImageDraw
import numpy as np
from collections import namedtuple
from logging import getLogger
logger = getLogger(__name__)
# %%
class Florence:
    """Thin wrapper around a Florence-2 checkpoint and its processor.

    The model is loaded with ``trust_remote_code=True``, switched to eval
    mode and moved to the CUDA device; ``run`` performs one beam-search
    generation and post-processes the output for the requested task.
    """

    def __init__(self, model_id: str, hack=False):
        """Load the model and processor identified by ``model_id``.

        Args:
            model_id: Checkpoint identifier passed to ``from_pretrained``.
            hack: When true, skip loading entirely. The instance then only
                carries ``model_id``; ``model`` and ``processor`` stay
                ``None`` and ``run`` refuses to execute.
        """
        # Always define the attributes so a "hack" instance is inspectable
        # and fails with a clear message instead of an AttributeError.
        self.model_id = model_id
        self.model = None
        self.processor = None
        if hack:
            return
        # NOTE: ``trust_remote_code`` must be passed exactly once; repeating
        # a keyword argument is a SyntaxError.
        self.model = (
            AutoModelForCausalLM.from_pretrained(
                model_id, trust_remote_code=True, torch_dtype="auto"
            )
            .eval()
            .cuda()
        )
        self.processor = AutoProcessor.from_pretrained(
            model_id, trust_remote_code=True
        )

    def run(self, img: "Image.Image", task_prompt: str, extra_text: "str | None" = None):
        """Run one task prompt on ``img`` and return the parsed answer.

        Args:
            img: Image handed to the processor; its ``width``/``height`` are
                used for post-processing.
            task_prompt: Task token such as ``"<OD>"``.
            extra_text: Optional text appended directly after the task token.

        Returns:
            Whatever ``processor.post_process_generation`` returns for the
            task (callers in this file treat it as a single-entry dict).

        Raises:
            RuntimeError: If the instance was created with ``hack=True`` and
                therefore has no model loaded.
        """
        if self.model is None or self.processor is None:
            raise RuntimeError(
                f"Florence model {self.model_id!r} is not loaded (created with hack=True)"
            )
        logger.debug("run %s %s", task_prompt, extra_text)
        model, processor = self.model, self.processor
        prompt = task_prompt + (extra_text if extra_text else "")
        # NOTE(review): inputs are cast to float16 while the model is loaded
        # with torch_dtype="auto" - confirm these agree for the checkpoints used.
        inputs = processor(text=prompt, images=img, return_tensors="pt").to(
            "cuda", torch.float16
        )
        # Deterministic decoding: beam search without sampling.
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        # Special tokens are kept in the decoded text that is handed to the
        # post-processor.
        generated_text = processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )[0]
        return processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(img.width, img.height),
        )
def model_init(hack=False):
    """Create the two Florence wrappers used by the pipeline.

    Args:
        hack: Forwarded to ``Florence``; when true no weights are loaded.

    Returns:
        ``(base, fine_tuned)`` wrappers for ``microsoft/Florence-2-large``
        and ``microsoft/Florence-2-large-ft``, constructed in that order.
    """
    base_model = Florence("microsoft/Florence-2-large", hack=hack)
    fine_tuned_model = Florence("microsoft/Florence-2-large-ft", hack=hack)
    return base_model, fine_tuned_model
#%%
# Florence-2 task prompt tokens, passed as `task_prompt` to Florence.run().
TASK_OD = "<OD>"
TASK_SEGMENTATION = '<REFERRING_EXPRESSION_SEGMENTATION>'
# NOTE(review): TASK_CAPTION has the same value as TASK_GROUNDING below and is
# not used in the visible code - possibly a copy/paste slip; confirm the
# intended token before relying on it.
TASK_CAPTION = "<CAPTION_TO_PHRASE_GROUNDING>"
TASK_OCR = "<OCR_WITH_REGION>"
TASK_GROUNDING = "<CAPTION_TO_PHRASE_GROUNDING>"
#%%
# Everything the vision models produce for one image. The field order matters:
# get_ai_model_result() builds this tuple positionally.
AIModelResult = namedtuple(
    'AIModelResult',
    'img img2 meter_bbox needle_polygons circle_polygons ocr1 ocr2',
)
# In-memory cache of model results, keyed by image file name.
cached_results: dict[str, AIModelResult] = {}
#%%
def get_meter_bbox(fl:Florence, img:Image):
    """Locate the meter in ``img`` via phrase grounding.

    Runs ``TASK_GROUNDING`` with a fixed caption and expects exactly one
    box labelled ``'a circular meter'`` in the reply.

    Returns:
        The single bounding box from the model reply.

    Raises:
        AssertionError: If the reply does not have the expected shape.
    """
    answer = fl.run(img, TASK_GROUNDING, "a circular meter with white background")
    assert len(answer) == 1
    _, grounding = answer.popitem()
    # Both lists must be present and hold exactly one entry each.
    for field in ('bboxes', 'labels'):
        assert field in grounding
    for field in ('bboxes', 'labels'):
        assert len(grounding[field]) == 1
    assert grounding['labels'][0] == 'a circular meter'
    return grounding['bboxes'][0]
def get_circles(fl:Florence, img2:Image, polygons:list):
    """Segment "a circle" inside the area covered by ``polygons``.

    Everything outside the filled polygons is replaced by white (255)
    before the masked image is sent to ``TASK_SEGMENTATION``.

    Returns:
        The single polygon group returned by the model.

    Raises:
        AssertionError: If the reply does not contain exactly one entry
            with exactly one polygon group.
    """
    # Rasterise the polygons into a black/white mask the size of the image.
    mask = Image.new('L', img2.size, color = 'black')
    pen = ImageDraw.Draw(mask)
    for poly in polygons:
        pen.polygon(poly, outline='white', width=3, fill='white')
    # Keep original pixels where the mask is set, white elsewhere.
    inside = np.array(mask)[:,:,None]>0
    masked_pixels = np.where(inside, np.array(img2), 255)
    masked_img = Image.fromarray(masked_pixels)
    answer = fl.run(masked_img, TASK_SEGMENTATION, "a circle")
    assert len(answer) == 1
    _, segmentation = answer.popitem()
    assert 'polygons' in segmentation
    assert len(segmentation['polygons']) == 1
    return segmentation['polygons'][0]
def get_needle_polygons(fl:Florence, img2:Image):
    """Segment the needle in the cropped meter image.

    Returns:
        The single polygon group returned for the needle prompt.

    Raises:
        AssertionError: If the reply does not contain exactly one entry
            with exactly one polygon group.
    """
    # The prompt text is model input and is kept verbatim.
    needle_prompt = "the long narrow black needle hand pass through the center of the cicular meter"
    answer = fl.run(img2, TASK_SEGMENTATION, needle_prompt)
    assert len(answer) == 1
    _, segmentation = answer.popitem()
    assert 'polygons' in segmentation
    assert len(segmentation['polygons']) == 1
    return segmentation['polygons'][0]
def get_ocr(fl:Florence, img2:Image):
    """Run ``TASK_OCR`` on ``img2`` and return the single result payload.

    Raises:
        AssertionError: If the reply does not contain exactly one entry.
    """
    answer = fl.run(img2, TASK_OCR)
    assert len(answer)==1
    _, ocr_payload = answer.popitem()
    return ocr_payload
def get_ai_model_result(img:Image.Image|Path|str, fl:Florence, fl_ft:Florence):
    """Run the full model pipeline on one image, with file-name caching.

    Args:
        img: An opened image, or a ``Path``/``str`` pointing at one.
        fl: Wrapper used for grounding, segmentation and the first OCR pass.
        fl_ft: Wrapper used for the second OCR pass.

    Returns:
        An ``AIModelResult``. When ``img`` is a path, the result is cached
        under the last path component and reused on later calls; in-memory
        images are never cached.
    """
    # Derive the cache key from the last path component, if there is a path.
    key = None
    if isinstance(img, Path):
        key = img.parts[-1]
    elif isinstance(img, str):
        key = img.split('/')[-1]

    if key is not None and key in cached_results:
        return cached_results[key]

    if isinstance(img, (Path, str)):
        img = Image.open(img)

    # Crop to the meter first; every later step works on the crop.
    meter_bbox = get_meter_bbox(fl, img)
    cropped = img.crop(meter_bbox)
    needle_polygons = get_needle_polygons(fl, cropped)
    circle_polygons = get_circles(fl, cropped, needle_polygons)
    ocr_base = get_ocr(fl, cropped)
    ocr_fine_tuned = get_ocr(fl_ft, cropped)

    result = AIModelResult(
        img, cropped, meter_bbox, needle_polygons,
        circle_polygons, ocr_base, ocr_fine_tuned,
    )
    if key is not None:
        cached_results[key] = result
    return result
#%%
from skimage.measure import regionprops
from skimage.measure import EllipseModel
from skimage.draw import ellipse_perimeter
def get_regionprops(polygons:list) -> regionprops:
    """Rasterise ``polygons`` and return the properties of the first region.

    The polygons are drawn filled onto a blank canvas just large enough to
    hold them (maximum coordinate plus 2), and the resulting binary mask is
    passed to ``skimage.measure.regionprops``.

    Returns:
        The first element of the ``regionprops`` result.
    """
    points = np.concatenate(polygons).reshape(-1, 2)
    canvas_size = tuple( (points.max(axis=0)+2).astype('int') )
    mask = Image.new('L', canvas_size, color = 'black')
    pen = ImageDraw.Draw(mask)
    for poly in polygons:
        pen.polygon(poly, outline='white', width=1, fill='white')
    # regionprops works on an integer label image: 1 inside, 0 outside.
    label_image = (np.array(mask)>0).astype(np.uint8)
    return regionprops(label_image)[0]
def estimate_ellipse(coords, enlarge_factor=1.0):
    """Fit an ellipse to ``coords`` and rasterise its perimeter.

    The two coordinate columns are swapped before fitting, so the fitted
    centre comes back in swapped order as well.

    Args:
        coords: Array with two columns of point coordinates.
        enlarge_factor: Scale applied to both fitted axis lengths.

    Returns:
        ``(em_params, theta, (c, r))`` where ``em_params`` is the rounded
        integer array ``[centre0, centre1, a, b]``, ``theta`` is the fitted
        orientation and ``(c, r)`` is the pair of index arrays returned by
        ``ellipse_perimeter``.
    """
    model = EllipseModel()
    model.estimate(coords[:, ::-1])
    centre_0, centre_1, axis_a, axis_b, theta = model.params
    axis_a = axis_a*enlarge_factor
    axis_b = axis_b*enlarge_factor
    em_params = np.round([centre_0, centre_1, axis_a, axis_b]).astype('int')
    first_idx, second_idx = ellipse_perimeter(*em_params, orientation=-theta)
    return em_params, theta, (first_idx, second_idx)
def estimate_line(coords):
    """Fit a ``LineModelND`` to ``coords`` and return its parameters.

    Returns:
        The model's ``params`` attribute after estimation (callers unpack
        it as ``origin, direction``).
    """
    line_model = LineModelND()
    line_model.estimate(coords)
    fitted_params = line_model.params
    return fitted_params
#%%
#%%
from matplotlib import pyplot as plt
import matplotlib
from skimage.measure import LineModelND, ransac
# Module-level side effect: selects matplotlib's 'dark_background' style when
# this module is imported.
matplotlib.style.use('dark_background')
def rotate_theta(theta):
    """Shift an angle by 3*pi/2, wrap it to one turn and convert to degrees.

    Args:
        theta: Angle(s) in radians (scalar or NumPy array).

    Returns:
        The shifted angle(s) in degrees, in the range [0, 360).
    """
    full_turn = 2*np.pi
    wrapped = (theta + 3*np.pi/2) % full_turn
    return wrapped/full_turn*360
# Numerals looked for on each scale, stored as strings so they can be compared
# directly against OCR label text.
kg_cm2_labels = [str(value) for value in (1, 3, 5, 7, 9, 11)]
psi_labels = [str(value) for value in range(20, 180, 20)]
# lousy decoupling
# Everything read_meter() computes, bundled for the visualization helpers.
MeterResult = namedtuple(
    'MeterResult',
    (
        # model output and the final reading
        'result needle_psi needle_kg_cm2 needle_theta '
        # needle line and meter centre
        'orign direction center '
        # fitted angle-to-psi model and the samples it kept
        'lm inliers '
        # OCR boxes per scale, keyed by the numeral
        'kg_cm2_texts psi_texts '
        # label centres and their angles
        'kg_cm2_centers psi_centers '
        'kg_cm2_theta psi_theta '
        # label values expressed in psi
        'kg_cm2_psi psi'
    ),
)
def read_meter(img:Image.Image|str|Path, fl, fl_ft, *,
               residual_threshold=15, max_trials=2):
    """Read the needle value of a pressure meter from an image.

    Steps: run the vision models, fit a line through the needle polygon,
    take the centroid of the segmented circle as the meter centre, turn the
    needle direction and every recognised scale numeral into an angle, fit
    a robust angle -> psi line with RANSAC and evaluate it at the needle
    angle.

    Args:
        img: Image, or path to an image, forwarded to ``get_ai_model_result``.
        fl: Florence wrapper for grounding/segmentation/first OCR pass.
        fl_ft: Florence wrapper for the second OCR pass.
        residual_threshold: RANSAC inlier threshold (same units as the
            fitted data: degrees / psi). Default matches the previous
            hard-coded value.
        max_trials: Number of RANSAC trials. Default matches the previous
            hard-coded value.

    Returns:
        A ``MeterResult`` with the reading and all intermediate data.

    Raises:
        ValueError: If fewer than two scale numerals were recognised, or if
            RANSAC could not fit a model to them.
    """
    # Conversion factor used for both directions (kg/cm2 <-> psi).
    psi_per_kg_cm2 = 14.223

    # ai model results
    result = get_ai_model_result(img, fl, fl_ft)

    # needle direction: a line fitted through all needle polygon vertices
    coords = np.concatenate(result.needle_polygons).reshape(-1, 2)
    orign, direction = estimate_line(coords)

    # meter center: centroid of the segmented circle, reversed so it is in
    # the same coordinate order as the polygon points
    circle_props = get_regionprops(result.circle_polygons)
    center = circle_props.centroid[::-1]

    # The fitted direction has an arbitrary sign; orient it so that it points
    # from the center towards the line's origin point.
    if (orign - center) @ direction < 0:
        direction = -direction
    needle_theta = rotate_theta(np.arctan2(direction[1], direction[0]))

    # Collect OCR boxes whose text is one of the expected scale numerals.
    # Both OCR passes are merged; a later box replaces an earlier one with
    # the same numeral.
    # NOTE(review): labels are matched verbatim - if the OCR output carries
    # stray whitespace or special tokens those labels are skipped; confirm.
    ocr1, ocr2 = result.ocr1, result.ocr2
    kg_cm2_texts = {}
    psi_texts = {}
    quad_boxes = ocr1['quad_boxes'] + ocr2['quad_boxes']
    labels = ocr1['labels'] + ocr2['labels']
    for qbox, label in zip(quad_boxes, labels):
        if label in kg_cm2_labels:
            kg_cm2_texts[int(label)] = qbox
        if label in psi_labels:
            psi_texts[int(label)] = qbox

    # Center of each label: mean of the four corners of its quad box.
    kg_cm2_centers = np.array(list(kg_cm2_texts.values())).reshape(-1, 4, 2).mean(axis=1)
    psi_centers = np.array(list(psi_texts.values())).reshape(-1, 4, 2).mean(axis=1)

    # Polar angle (in degrees) of every label around the meter center.
    kg_cm2_coords = kg_cm2_centers - center
    kg_cm2_theta = rotate_theta(np.arctan2(kg_cm2_coords[:, 1], kg_cm2_coords[:, 0]))
    psi_coords = psi_centers - center
    psi_theta = rotate_theta(np.arctan2(psi_coords[:, 1], psi_coords[:, 0]))

    # Express both scales in psi so they can share one line model.
    kg_cm2 = np.array(list(kg_cm2_texts.keys()))
    kg_cm2_psi = kg_cm2 * psi_per_kg_cm2
    psi = np.array(list(psi_texts.keys()))
    Y = np.concatenate([kg_cm2_psi, psi])
    X = np.concatenate([kg_cm2_theta, psi_theta])
    data = np.stack([X, Y], axis=1)

    # A line needs at least two samples; fail with a clear message instead
    # of an obscure error from inside the fitting code.
    if len(data) < 2:
        raise ValueError(
            f"need at least 2 recognised scale labels to fit the meter scale, got {len(data)}"
        )

    # run ransac to robustly fit a line model
    lm, inliers = ransac(data, LineModelND, min_samples=2,
                         residual_threshold=residual_threshold,
                         max_trials=max_trials)
    # ransac hands back no model when it finds no consensus set.
    if lm is None:
        raise ValueError("RANSAC could not fit an angle-to-psi line to the scale labels")

    # Evaluate the model at the needle angle; index 1 is the psi coordinate.
    needle_psi = lm.predict(needle_theta)[1]
    needle_kg_cm2 = needle_psi / psi_per_kg_cm2

    return MeterResult(result=result,
                       needle_psi=needle_psi,
                       needle_kg_cm2=needle_kg_cm2,
                       needle_theta=needle_theta,
                       orign=orign,
                       direction=direction,
                       center=center,
                       lm=lm,
                       inliers=data[inliers].T,
                       kg_cm2_texts=kg_cm2_texts,
                       psi_texts=psi_texts,
                       kg_cm2_centers=kg_cm2_centers,
                       psi_centers=psi_centers,
                       kg_cm2_theta=kg_cm2_theta,
                       psi_theta=psi_theta,
                       kg_cm2_psi=kg_cm2_psi,
                       psi=psi,
                       )
def more_visualization_data(meter_result:MeterResult):
    """Derive the extra quantities needed only for plotting/debugging.

    Returns:
        ``(inlier_theta, inlier_psi, predict_theta, predict_psi,
        needle_head, direction)`` where ``needle_head`` is the needle
        polygon vertex farthest from the meter centre, ``direction`` is the
        needle direction scaled to that distance, and the ``predict_*``
        arrays sample the fitted line at 100 angles from 0 to 360 degrees.
    """
    ai_result = meter_result.result
    pivot = meter_result.center

    # Needle head: the polygon vertex with the largest distance to the centre.
    needle_points = np.concatenate(ai_result.needle_polygons).reshape(-1, 2)
    distances = np.linalg.norm(needle_points - pivot,axis=1)
    tip_index = np.argmax(distances)
    needle_head = needle_points[tip_index]
    direction = meter_result.direction * distances[tip_index]

    # Samples RANSAC kept, already split into (theta, psi) rows.
    inlier_theta, inlier_psi = meter_result.inliers

    # Fitted line sampled across the full circle.
    predict_theta = np.linspace(0, 360, 100)
    predict_psi = meter_result.lm.predict(predict_theta)[:, 1]

    return (inlier_theta, inlier_psi, predict_theta, predict_psi,
            needle_head, direction)
def _overlay_label_ellipse(canvas, centers, rgb, outline):
    """Fit an ellipse through ``centers`` and paint it; return the new image.

    A small marker is drawn at the fitted centre and the perimeter pixels
    are set to ``rgb``. Perimeter pixels that fall outside the image are
    dropped instead of raising or wrapping around.
    """
    em_params, _theta, (rows, cols) = estimate_ellipse(centers)
    # estimate_ellipse returns the centre with the first image axis first.
    y, x = em_params[:2]
    ImageDraw.Draw(canvas).ellipse((x-5, y-5, x+5, y+5), outline=outline, width=1)
    pixels = np.array(canvas)
    height, width = pixels.shape[:2]
    inside = (rows >= 0) & (rows < height) & (cols >= 0) & (cols < width)
    pixels[rows[inside], cols[inside]] = rgb
    return Image.fromarray(pixels)


def visualization(meter_result:MeterResult):
    """Render a debug overlay and a fit plot for one meter reading.

    The overlay is drawn on a copy of the cropped meter image: needle
    outline, translucent centre circle, needle direction, label boxes and
    centres, optional fitted ellipses through the labels, and the reading
    as text. The plot is drawn on the current pyplot figure; the caller is
    responsible for clearing it between calls.

    Returns:
        ``(overlay_image, figure)``.
    """
    result = meter_result.result
    center = meter_result.center
    needle_psi, needle_kg_cm2 = meter_result.needle_psi, meter_result.needle_kg_cm2
    (inlier_theta, inlier_psi, predict_theta, predict_psi,
     needle_head, direction) = more_visualization_data(meter_result)

    # draw needle polygons on a copy so the cached image stays untouched
    canvas = result.img2.copy()
    draw = ImageDraw.Draw(canvas)
    for polygon in result.needle_polygons:
        draw.polygon(polygon, outline='red', width=3)

    # draw the centre circle translucently via an RGBA overlay
    base = canvas.convert('RGBA')
    overlay = Image.new('RGBA', base.size, (0,0,0,0))
    overlay_draw = ImageDraw.Draw(overlay)
    for polygon in result.circle_polygons:
        overlay_draw.polygon(polygon, outline='purple', width=1, fill = (255,128,255,100))
    canvas = Image.alpha_composite(base, overlay).convert('RGB')
    draw = ImageDraw.Draw(canvas)

    # draw needle direction
    draw.line((center[0], center[1], center[0]+direction[0], center[1]+direction[1]), fill='yellow', width=3)
    # draw a dot at center
    draw.ellipse((center[0]-5, center[1]-5, center[0]+5, center[1]+5), outline='yellow', width=3)
    # draw a dot at needle_head
    draw.ellipse((needle_head[0]-5, needle_head[1]-5, needle_head[0]+5, needle_head[1]+5), outline='yellow', width=3)

    # label centres and boxes: blue for kg/cm2, green for psi
    for x,y in meter_result.kg_cm2_centers:
        draw.ellipse((x-3, y-3, x+3, y+3), outline='blue', width=3)
    for x,y in meter_result.psi_centers:
        draw.ellipse((x-3, y-3, x+3, y+3), outline='green', width=3)
    for label,quad_box in meter_result.kg_cm2_texts.items():
        draw.polygon(quad_box, outline='blue', width=1)
        draw.text((quad_box[0], quad_box[1]-10), str(label), fill='blue', anchor='ls')
    for label,quad_box in meter_result.psi_texts.items():
        draw.polygon(quad_box, outline='green', width=1)
        draw.text((quad_box[0], quad_box[1]-10), str(label), fill='green', anchor='ls')

    # Ellipses through the labels, currently only for visualization. Each
    # branch uses the centre of its own fit.
    if len(meter_result.kg_cm2_centers) >4:
        canvas = _overlay_label_ellipse(canvas, meter_result.kg_cm2_centers, (0, 0, 255), 'blue')
    if len(meter_result.psi_centers) >4:
        canvas = _overlay_label_ellipse(canvas, meter_result.psi_centers, (0, 255, 0), 'green')

    draw = ImageDraw.Draw(canvas)
    draw.text((needle_head[0]-10, needle_head[1]-10),
              f'psi={needle_psi:.1f} kg_cm2={needle_kg_cm2:.2f}',anchor='ls',
              fill='yellow')

    # fit plot: model line, label samples, inliers and the needle reading
    plt.plot(predict_theta, predict_psi, color='red', alpha=0.5)
    plt.plot(meter_result.kg_cm2_theta, meter_result.kg_cm2_psi, 'o', color='#77F')
    plt.plot(meter_result.psi_theta, meter_result.psi, 'o', color='#7F7')
    plt.plot(inlier_theta, inlier_psi, 'x', color='red', alpha=0.5)
    plt.vlines(meter_result.needle_theta, 0, 160, colors='yellow', alpha=0.5)
    plt.hlines(meter_result.needle_psi, 0, 360, colors='yellow', alpha=0.5)
    plt.text(meter_result.needle_theta-20, meter_result.needle_psi-20,
             f'psi={needle_psi:.1f} kg_cm2={needle_kg_cm2:.2f}', color='yellow')
    plt.xlim(0, 360)
    plt.ylim(0, 160)
    return canvas, plt.gcf()
def clear_cache():
    """Drop every entry from the in-memory result cache (in place)."""
    cached_results.clear()
def save_cache():
    """Pickle the in-memory result cache to ``cached_results.pkl``.

    The file is written relative to the current working directory and is
    closed even if pickling fails.
    """
    with open('cached_results.pkl', 'wb') as cache_file:
        pickle.dump(cached_results, cache_file)
def load_cache():
    """Replace the in-memory result cache with ``cached_results.pkl``.

    The file is read relative to the current working directory and is
    closed after loading. The module-level ``cached_results`` name is
    rebound to the loaded object.
    """
    global cached_results
    # SECURITY: unpickling can execute arbitrary code - only load cache files
    # that this program wrote itself.
    with open('cached_results.pkl', 'rb') as cache_file:
        cached_results = pickle.load(cache_file)
#%%
if __name__ == '__main__':
    # Demo driver: read every sample image and show the overlay and the fit
    # plot through IPython's display().
    from io import BytesIO
    from IPython.display import display

    fl, fl_ft = model_init(hack=False)
    # Start from an empty cache (call load_cache() here instead to reuse a saved one).
    clear_cache()

    image_paths = list(Path('images/good').glob('*.jpg'))
    MAX_W, MAX_H = 640, 480
    for image_path in image_paths:
        print(image_path)
        meter_result = read_meter(image_path, fl, fl_ft)
        overlay, figure = visualization(meter_result)

        # Fit the overlay into MAX_W x MAX_H while keeping the aspect ratio of
        # the cropped meter image.
        src_w, src_h = meter_result.result.img2.size
        if src_w/MAX_W > src_h/MAX_H:
            target_size = (MAX_W, int(src_h*MAX_W/src_w))
        else:
            target_size = (int(src_w*MAX_H/src_h), MAX_H)
        display(overlay.resize(target_size))

        # Render the figure to an in-memory PNG and show it as an image.
        png_buffer = BytesIO()
        figure.savefig(png_buffer, format='png')
        png_buffer.seek(0)
        display(Image.open(png_buffer))

        # Reset the shared pyplot figure for the next image.
        plt.clf()
# %%