|
"""Utils for visual iterative prompting. |
|
|
|
A number of utility functions for VIP. |
|
""" |
|
|
|
import re |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import scipy.spatial.distance as distance |
|
|
|
|
|
def min_dist(coord, coords):
    """Return the smallest Euclidean distance from `coord` to any of `coords`.

    Args:
      coord: Object exposing its position through an `xy` attribute.
      coords: Collection of objects that each expose an `xy` attribute.

    Returns:
      The minimum distance, or `np.inf` when `coords` is empty.
    """
    if not coords:
        return np.inf
    other_positions = np.asarray([other.xy for other in coords])
    offsets = other_positions - np.asarray(coord.xy)
    return np.linalg.norm(offsets, axis=-1).min()
|
|
|
|
|
def coord_outside_image(coord, image, radius):
    """Tell whether `coord` lies too close to (or beyond) the image border.

    A coordinate counts as outside when it is within `2 * radius` of any
    edge of the image.

    Args:
      coord: Object exposing an `xy` attribute holding `(x, y)`.
      image: Array whose `shape` unpacks as `(height, width, channels)`.
      radius: Radius used to derive the border margin.

    Returns:
      A truthy value if the coordinate is outside the allowed region.
    """
    n_rows, n_cols, _ = image.shape
    px, py = coord.xy
    margin = 2 * radius
    beyond_x = px < margin or px > n_cols - margin
    beyond_y = py < margin or py > n_rows - margin
    return beyond_x or beyond_y
|
|
|
|
|
def is_invalid_coord(coord, coords, radius, image):
    """Check whether `coord` overlaps an existing coord or leaves the image.

    Args:
      coord: Candidate coordinate (object with an `xy` attribute).
      coords: Coordinates already accepted.
      radius: Radius used for both the overlap and the border checks.
      image: Image array passed on to the border check.

    Returns:
      A truthy value when the candidate is closer than `1.5 * radius` to any
      entry of `coords`, or when `coord_outside_image` reports it as outside.
    """
    min_separation = 1.5 * radius
    nearest = min_dist(coord, coords)
    overlaps_existing = nearest < min_separation
    # The border check only runs when there is no overlap (short-circuit).
    return overlaps_existing or coord_outside_image(coord, image, radius)
|
|
|
|
|
def angle_mag_2_x_y(angle, mag, arm_coord, is_circle=False, radius=40):
    """Offset `arm_coord` by a polar displacement and return the new (x, y).

    Args:
      angle: Direction of the displacement, in radians.
      mag: Signed length of the displacement.
      arm_coord: Starting `(x, y)` position.
      is_circle: When true, push the point a further `radius` along the same
        direction, following the sign of `mag`.
      radius: Extra offset applied when `is_circle` is set.

    Returns:
      Tuple `(x, y)` of the displaced position.
    """
    x, y = arm_coord
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    # Each term is truncated to an int on its own before being added.
    x += int(cos_a * mag)
    y += int(sin_a * mag)
    if is_circle:
        direction = np.sign(mag)
        x += int(cos_a * radius * direction)
        y += int(sin_a * radius * direction)
    return x, y
|
|
|
|
|
def coord_to_text_coord(coord, arm_coord, radius):
    """Compute a point `radius` beyond `coord`, moving away from `arm_coord`.

    Args:
      coord: Object exposing an `xy` attribute.
      arm_coord: Reference `(x, y)` position the offset points away from.
      radius: Distance to move past `coord` along the arm-to-coord direction.

    Returns:
      A tuple of two ints, or `arm_coord` itself when `coord` coincides with
      it and no direction is defined.
    """
    offset = np.asarray(coord.xy) - arm_coord
    length = np.linalg.norm(offset)
    if length == 0:
        return arm_coord
    text_x = coord.xy[0] + radius * offset[0] / length
    text_y = coord.xy[1] + radius * offset[1] / length
    return (int(text_x), int(text_y))
|
|
|
|
|
def parse_response(response, answer_key='Arrow: ['):
    """Extract the non-negative integers named in a text response.

    When `answer_key` occurs in `response`, only the text between its last
    occurrence and the next `]` is scanned; otherwise every run of digits in
    the whole response is returned. A message naming the path taken is
    printed.

    Args:
      response: Text to parse.
      answer_key: Marker that introduces the bracketed answer list.

    Returns:
      List of ints in order of appearance.
    """
    if answer_key in response:
        print('parse_response from answer_key')
        after_key = response.split(answer_key)[-1]
        searchable = after_key.split(']')[0]
    else:
        print('parse_response for all ints')
        searchable = response
    return [int(digits) for digits in re.findall(r'\d+', searchable)]
|
|
|
|
|
def compute_errors(action, true_action, verbose=False):
    """Compute errors between a predicted action and true action.

    Args:
      action: Predicted action as a 1-D array-like (list, tuple or ndarray).
      true_action: Ground-truth action with the same length as `action`.
      verbose: When true, print both actions and every metric.

    Returns:
      Dict with keys `'l2'`, `'cos_sim'`, `'l2_xy_error'`, `'cos_xy_sim'` and
      `'z_error'`. The `xy` metrics use the last two components and
      `z_error` uses the first component.
      NOTE(review): the component layout is inferred from the key names;
      confirm against the caller's action convention.
    """
    # Coerce so plain lists/tuples support the arithmetic below; ndarray
    # inputs pass through unchanged.
    action = np.asarray(action)
    true_action = np.asarray(true_action)

    action_xy = action[-2:]
    true_xy = true_action[-2:]

    l2_error = np.linalg.norm(action - true_action)
    # scipy returns cosine *distance*; similarity is its complement.
    cos_sim = 1 - distance.cosine(action, true_action)
    l2_xy_error = np.linalg.norm(action_xy - true_xy)
    cos_xy_sim = 1 - distance.cosine(action_xy, true_xy)
    z_error = np.abs(action[0] - true_action[0])
    errors = {
        'l2': l2_error,
        'cos_sim': cos_sim,
        'l2_xy_error': l2_xy_error,
        'cos_xy_sim': cos_xy_sim,
        'z_error': z_error,
    }

    if verbose:
        print('action: \t', [f'{a:.3f}' for a in action])
        print('true_action \t', [f'{a:.3f}' for a in true_action])
        print(f'l2: \t\t{l2_error:.3f}')
        print(f'l2_xy_error: \t{l2_xy_error:.3f}')
        print(f'cos_sim: \t{cos_sim:.3f}')
        print(f'cos_xy_sim: \t{cos_xy_sim:.3f}')
        print(f'z_error: \t{z_error:.3f}')

    return errors
|
|
|
|
|
def plot_errors(all_errors, error_types=None):
    """Plot the per-iteration mean of each error metric and show the figure.

    Args:
      all_errors: Iterable of mappings, each from an iteration key to a dict
        of error values indexed by error-type name.
      error_types: Names of the metrics to plot, one subplot each on a 2x3
        grid. Defaults to the five metrics listed below.
    """
    if error_types is None:
        error_types = ['l2', 'l2_xy_error', 'z_error', 'cos_sim', 'cos_xy_sim']

    _, axs = plt.subplots(2, 3, figsize=(15, 8))
    for idx, error_type in enumerate(error_types):
        # Gather this metric's values from every run, grouped by iteration.
        values_by_iter = {}
        for run_errors in all_errors:
            for itr in run_errors:
                values_by_iter.setdefault(itr, []).append(
                    run_errors[itr][error_type]
                )

        iter_means = [np.mean(values) for values in values_by_iter.values()]

        row, col = divmod(idx, 3)
        panel = axs[row, col]
        panel.plot(values_by_iter.keys(), iter_means)
        panel.set_title(error_type)
    plt.show()
|
|