Spaces:
Runtime error
Runtime error
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# | |
# This work is made available under the Nvidia Source Code License-NC. | |
# To view a copy of this license, check out LICENSE.md | |
import os | |
import requests | |
import torch.distributed as dist | |
import torchvision.utils | |
from imaginaire.utils.distributed import is_master | |
def save_pilimage_in_jpeg(fullname, output_img): | |
r"""Save PIL Image to JPEG. | |
Args: | |
fullname (str): Full save path. | |
output_img (PIL Image): Image to be saved. | |
""" | |
dirname = os.path.dirname(fullname) | |
os.makedirs(dirname, exist_ok=True) | |
output_img.save(fullname, 'JPEG', quality=99) | |
def save_intermediate_training_results( | |
visualization_images, logdir, current_epoch, current_iteration): | |
r"""Save intermediate training results for debugging purpose. | |
Args: | |
visualization_images (tensor): Image where pixel values are in [-1, 1]. | |
logdir (str): Where to save the image. | |
current_epoch (int): Current training epoch. | |
current_iteration (int): Current training iteration. | |
""" | |
visualization_images = (visualization_images + 1) / 2 | |
output_filename = os.path.join( | |
logdir, 'images', | |
'epoch_{:05}iteration{:09}.jpg'.format( | |
current_epoch, current_iteration)) | |
print('Save output images to {}'.format(output_filename)) | |
os.makedirs(os.path.dirname(output_filename), exist_ok=True) | |
image_grid = torchvision.utils.make_grid( | |
visualization_images.data, nrow=1, padding=0, normalize=False) | |
torchvision.utils.save_image(image_grid, output_filename, nrow=1) | |
def download_file_from_google_drive(URL, destination): | |
r"""Download a file from google drive. | |
Args: | |
URL: GDrive file ID. | |
destination: Path to save the file. | |
Returns: | |
""" | |
download_file(f"https://docs.google.com/uc?export=download&id={URL}", destination) | |
def download_file(URL, destination): | |
r"""Download a file from google drive or pbss by using the url. | |
Args: | |
URL: GDrive URL or PBSS pre-signed URL for the checkpoint. | |
destination: Path to save the file. | |
Returns: | |
""" | |
session = requests.Session() | |
response = session.get(URL, stream=True) | |
token = get_confirm_token(response) | |
if token: | |
params = {'confirm': token} | |
response = session.get(URL, params=params, stream=True) | |
save_response_content(response, destination) | |
def get_confirm_token(response): | |
r"""Get confirm token | |
Args: | |
response: Check if the file exists. | |
Returns: | |
""" | |
for key, value in response.cookies.items(): | |
if key.startswith('download_warning'): | |
return value | |
return None | |
def save_response_content(response, destination): | |
r"""Save response content | |
Args: | |
response: | |
destination: Path to save the file. | |
Returns: | |
""" | |
chunk_size = 32768 | |
with open(destination, "wb") as f: | |
for chunk in response.iter_content(chunk_size): | |
if chunk: | |
f.write(chunk) | |
def get_checkpoint(checkpoint_path, url=''): | |
r"""Get the checkpoint path. If it does not exist yet, download it from | |
the url. | |
Args: | |
checkpoint_path (str): Checkpoint path. | |
url (str): URL to download checkpoint. | |
Returns: | |
(str): Full checkpoint path. | |
""" | |
if 'TORCH_HOME' not in os.environ: | |
os.environ['TORCH_HOME'] = os.getcwd() | |
save_dir = os.path.join(os.environ['TORCH_HOME'], 'checkpoints') | |
os.makedirs(save_dir, exist_ok=True) | |
full_checkpoint_path = os.path.join(save_dir, checkpoint_path) | |
if not os.path.exists(full_checkpoint_path): | |
os.makedirs(os.path.dirname(full_checkpoint_path), exist_ok=True) | |
if is_master(): | |
print('Downloading {}'.format(url)) | |
if 'pbss.s8k.io' not in url: | |
url = f"https://docs.google.com/uc?export=download&id={url}" | |
download_file(url, full_checkpoint_path) | |
if dist.is_available() and dist.is_initialized(): | |
dist.barrier() | |
return full_checkpoint_path | |