# Spaces:
# Runtime error
# Runtime error
import textwrap

from matplotlib import pyplot as plt
from PIL import Image
def to_gif(images, path, duration=None):
    """Save a sequence of frames as an endlessly looping animated GIF.

    Args:
        images: Iterable of PIL images. It is materialised into a list first,
            so generators and tuples work as well as lists.
        path: Destination handed straight to the first frame's ``save``.
        duration: Per-frame display time in milliseconds, as interpreted by
            Pillow's GIF writer. Defaults to ``len(images) * 20`` to keep the
            historical behaviour, which means each frame of a longer sequence
            is shown for longer.

    Raises:
        ValueError: If ``images`` yields no frames.
    """
    frames = list(images)
    if not frames:
        raise ValueError("to_gif() requires at least one image")
    if duration is None:
        # Historical default, kept so existing callers see identical timing.
        duration = len(frames) * 20
    # The first frame writes the file and the remaining frames are appended to it.
    # loop=0 asks the GIF to repeat forever.
    frames[0].save(path, save_all=True, append_images=frames[1:],
                   loop=0, duration=duration)
def figure_to_image(figure, dpi=300):
    """Render a Matplotlib figure into an RGB PIL image.

    Args:
        figure: A Matplotlib figure whose canvas provides ``buffer_rgba()``
            (Agg-based canvases do). Its dpi is changed in place.
        dpi: Resolution to render at. Defaults to the previously hard-coded 300.

    Returns:
        A new PIL image in ``'RGB'`` mode.
    """
    figure.set_dpi(dpi)
    canvas = figure.canvas
    canvas.draw()
    # ``tostring_rgb`` was deprecated in Matplotlib 3.8 and later removed;
    # ``buffer_rgba`` is the supported way to read the rendered pixels.
    buffer = canvas.buffer_rgba()
    # Prefer the buffer's own (height, width, 4) shape. On HiDPI canvases the
    # logical get_width_height() can disagree with the rendered pixel size.
    shape = getattr(buffer, 'shape', None)
    if shape is not None and len(shape) >= 2:
        height, width = shape[:2]
    else:
        width, height = canvas.get_width_height()
    # Converting RGBA -> RGB drops the alpha channel, giving the same RGB
    # output the old tostring_rgb path produced.
    return Image.frombytes('RGBA', (width, height), bytes(buffer)).convert('RGB')
def image_grid(images, outpath=None, column_titles=None, row_titles=None):
    """Lay out a 2-D collection of images as a borderless, optionally titled grid.

    Args:
        images: Iterable of rows, where each row is an iterable of images
            accepted by ``Axes.imshow``. Rows may differ in length: the grid
            is as wide as the longest row, and missing cells are left blank.
        outpath: If given, the grid is saved there at 300 dpi with a tight
            bounding box, and ``None`` is returned.
        column_titles: Optional titles, indexed by column and drawn above the
            first row. Each title is wrapped at 12 characters.
        row_titles: Optional labels, indexed by row and drawn horizontally to
            the left of the first column.

    Returns:
        The rendered grid as a PIL image (via ``figure_to_image``) when
        ``outpath`` is ``None``; otherwise ``None``.

    Raises:
        ValueError: If ``images`` contains no images at all.
    """
    rows = [list(row) for row in images]
    n_rows = len(rows)
    n_cols = max((len(row) for row in rows), default=0)
    if n_cols == 0:
        raise ValueError("image_grid() requires at least one image")

    fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols,
                            figsize=(n_cols, n_rows), squeeze=False)
    try:
        for row, row_images in enumerate(rows):
            for column in range(n_cols):
                ax = axs[row][column]
                if column >= len(row_images):
                    # Ragged input: blank out cells past the end of a short row.
                    ax.axis('off')
                    continue
                ax.imshow(row_images[column])
                if column_titles and row == 0:
                    ax.set_title(textwrap.fill(column_titles[column], width=12),
                                 fontsize='x-small')
                if row_titles and column == 0:
                    # Horizontal label whose pad grows with text length so it
                    # clears the image (a length-based heuristic).
                    ax.set_ylabel(row_titles[row], rotation=0, fontsize='x-small',
                                  labelpad=1.6 * len(row_titles[row]))
                ax.set_xticks([])
                ax.set_yticks([])
        fig.subplots_adjust(wspace=0, hspace=0)
        if outpath is not None:
            fig.savefig(outpath, bbox_inches='tight', dpi=300)
            return None
        fig.tight_layout(pad=0)
        return figure_to_image(fig)
    finally:
        # Close the figure even if drawing or saving raised, so pyplot does
        # not keep accumulating open figures.
        plt.close(fig)
def get_module(module, module_name):
    """Return the object reached by following an attribute path from ``module``.

    Args:
        module: Root object to start from (presumably a model/submodule;
            any object with attributes works).
        module_name: Dotted string such as ``"encoder.layer.0"``, or an
            already-split sequence of attribute names. An empty sequence
            returns ``module`` itself.

    Returns:
        The object at the end of the path.

    Raises:
        AttributeError: If any component along the path is missing.
    """
    if isinstance(module_name, str):
        path = module_name.split('.')
    else:
        path = module_name
    found = module
    for attr in path:
        found = getattr(found, attr)
    return found
def set_module(module, module_name, new_module):
    """Assign ``new_module`` to the attribute at the end of a path.

    Every component except the last is resolved with ``getattr``. The final
    component is then set with ``setattr`` on the object reached.

    Args:
        module: Root object to start from.
        module_name: Dotted string such as ``"encoder.proj"``, or an
            already-split sequence of attribute names.
        new_module: Value to store at the final attribute.

    Returns:
        None.

    Raises:
        IndexError: If ``module_name`` is an empty sequence.
        AttributeError: If an intermediate component is missing.
    """
    path = module_name.split('.') if isinstance(module_name, str) else module_name
    leaf = path[-1]
    owner = module
    for attr in path[:-1]:
        owner = getattr(owner, attr)
    setattr(owner, leaf, new_module)
def freeze(module):
    """Stop gradient tracking for every parameter ``module`` yields.

    Args:
        module: Any object with a ``parameters()`` method (presumably a
            ``torch.nn.Module``). Each yielded object gets
            ``requires_grad = False``.
    """
    params = module.parameters()
    for weight in params:
        weight.requires_grad = False
def unfreeze(module):
    """Re-enable gradient tracking for every parameter ``module`` yields.

    Args:
        module: Any object with a ``parameters()`` method (presumably a
            ``torch.nn.Module``). Each yielded object gets
            ``requires_grad = True``.
    """
    params = module.parameters()
    for weight in params:
        weight.requires_grad = True
def get_concat_h(im1, im2):
    """Place ``im2`` to the right of ``im1`` on a new RGB canvas.

    The canvas is as tall as the taller input, so neither image is cropped.
    Previously the height came from ``im1`` alone, which cut off a taller
    ``im2``. Uncovered area keeps ``Image.new``'s default fill (black).
    Equal-height inputs produce exactly the same result as before.

    Args:
        im1: Left PIL image, pasted at the origin.
        im2: Right PIL image, pasted immediately after ``im1``'s width.

    Returns:
        A new ``'RGB'`` PIL image.
    """
    canvas = Image.new('RGB', (im1.width + im2.width, max(im1.height, im2.height)))
    canvas.paste(im1, (0, 0))
    canvas.paste(im2, (im1.width, 0))
    return canvas
def get_concat_v(im1, im2):
    """Place ``im2`` below ``im1`` on a new RGB canvas.

    The canvas is as wide as the wider input, so neither image is cropped.
    Previously the width came from ``im1`` alone, which cut off a wider
    ``im2``. Uncovered area keeps ``Image.new``'s default fill (black).
    Equal-width inputs produce exactly the same result as before.

    Args:
        im1: Top PIL image, pasted at the origin.
        im2: Bottom PIL image, pasted immediately below ``im1``.

    Returns:
        A new ``'RGB'`` PIL image.
    """
    canvas = Image.new('RGB', (max(im1.width, im2.width), im1.height + im2.height))
    canvas.paste(im1, (0, 0))
    canvas.paste(im2, (0, im1.height))
    return canvas