venite's picture
initial
f670afc
raw
history blame
4.12 kB
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import os
import requests
import torch.distributed as dist
import torchvision.utils
from imaginaire.utils.distributed import is_master
def save_pilimage_in_jpeg(fullname, output_img):
    r"""Save PIL Image to JPEG.

    Args:
        fullname (str): Full save path. A bare filename (no directory
            component) is accepted; in that case no directory is created.
        output_img (PIL Image): Image to be saved.
    """
    dirname = os.path.dirname(fullname)
    # os.makedirs('') raises FileNotFoundError, so only create the parent
    # directory when the path actually contains one.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    output_img.save(fullname, 'JPEG', quality=99)
def save_intermediate_training_results(
        visualization_images, logdir, current_epoch, current_iteration):
    r"""Dump the current visualization images to disk for debugging.

    The images are arranged into a single grid and written to
    ``<logdir>/images/epoch_<epoch>iteration<iteration>.jpg``.

    Args:
        visualization_images (tensor): Image where pixel values are in [-1, 1].
        logdir (str): Where to save the image.
        current_epoch (int): Current training epoch.
        current_iteration (int): Current training iteration.
    """
    # Shift and scale the documented [-1, 1] input range to [0, 1].
    rescaled = (visualization_images + 1) / 2

    basename = 'epoch_{:05}iteration{:09}.jpg'.format(
        current_epoch, current_iteration)
    output_filename = os.path.join(logdir, 'images', basename)
    print('Save output images to {}'.format(output_filename))

    target_dir = os.path.dirname(output_filename)
    os.makedirs(target_dir, exist_ok=True)

    # nrow=1 and padding=0 are passed through to make_grid unchanged;
    # normalize=False because the values were already rescaled above.
    grid = torchvision.utils.make_grid(
        rescaled.data, nrow=1, padding=0, normalize=False)
    torchvision.utils.save_image(grid, output_filename, nrow=1)
def download_file_from_google_drive(URL, destination):
    r"""Download a file from google drive given its file ID.

    Builds the ``docs.google.com`` download link for the ID and hands it to
    ``download_file``.

    Args:
        URL: GDrive file ID (not a full URL, despite the parameter name).
        destination: Path to save the file.
    """
    gdrive_link = f"https://docs.google.com/uc?export=download&id={URL}"
    download_file(gdrive_link, destination)
def download_file(URL, destination):
    r"""Download a file from google drive or pbss by using the url.

    The URL is requested once; if the response carries a ``download_warning``
    cookie (see ``get_confirm_token``), the request is repeated with the
    cookie value passed as the ``confirm`` query parameter. The body of the
    final response is streamed to ``destination``.

    Args:
        URL: GDrive URL or PBSS pre-signed URL for the checkpoint.
        destination: Path to save the file.

    Raises:
        requests.HTTPError: If the final response has a 4xx/5xx status, so
            that an error page is never written to ``destination``.
    """
    # Use the session as a context manager so its connections are released
    # even if the download fails part-way through.
    with requests.Session() as session:
        response = session.get(URL, stream=True)
        token = get_confirm_token(response)
        if token:
            # The first streamed response is not read; release it before
            # issuing the confirmed request.
            response.close()
            params = {'confirm': token}
            response = session.get(URL, params=params, stream=True)
        # Fail loudly instead of saving an HTTP error body as the file.
        response.raise_for_status()
        save_response_content(response, destination)
def get_confirm_token(response):
    r"""Get confirm token from the response cookies.

    Args:
        response: Response object whose ``cookies`` mapping is inspected.

    Returns:
        The value of the first cookie whose name starts with
        ``download_warning``, or ``None`` if there is no such cookie.
    """
    matching_values = (
        cookie_value
        for cookie_name, cookie_value in response.cookies.items()
        if cookie_name.startswith('download_warning')
    )
    return next(matching_values, None)
def save_response_content(response, destination):
    r"""Save response content to a file.

    The body is read through ``response.iter_content`` in chunks of 32768
    bytes and written to ``destination`` in binary mode. Empty chunks are
    skipped.

    Args:
        response: Response object providing ``iter_content(chunk_size)``.
        destination: Path to save the file.
    """
    chunk_size = 32768
    with open(destination, "wb") as sink:
        # filter(None, ...) drops falsy (empty) chunks.
        sink.writelines(filter(None, response.iter_content(chunk_size)))
def get_checkpoint(checkpoint_path, url=''):
    r"""Get the checkpoint path. If it does not exist yet, download it from
    the url.

    Checkpoints live under ``$TORCH_HOME/checkpoints``. If ``TORCH_HOME`` is
    not set, it is set (in ``os.environ``) to the current working directory.
    Only the master process downloads; when ``torch.distributed`` is
    initialized, all processes then meet at a barrier before returning.

    Args:
        checkpoint_path (str): Checkpoint path, relative to the checkpoint
            directory.
        url (str): URL to download checkpoint. Anything that does not
            contain ``pbss.s8k.io`` is treated as a Google Drive file ID.

    Returns:
        (str): Full checkpoint path.
    """
    if 'TORCH_HOME' not in os.environ:
        os.environ['TORCH_HOME'] = os.getcwd()
    save_dir = os.path.join(os.environ['TORCH_HOME'], 'checkpoints')
    os.makedirs(save_dir, exist_ok=True)
    full_checkpoint_path = os.path.join(save_dir, checkpoint_path)
    if not os.path.exists(full_checkpoint_path):
        os.makedirs(os.path.dirname(full_checkpoint_path), exist_ok=True)
        if is_master():
            print('Downloading {}'.format(url))
            if 'pbss.s8k.io' not in url:
                url = f"https://docs.google.com/uc?export=download&id={url}"
            # Download to a temporary name and rename on success, so an
            # interrupted download never leaves a truncated file that would
            # pass the os.path.exists() check on the next run.
            partial_path = full_checkpoint_path + '.part'
            try:
                download_file(url, partial_path)
                os.replace(partial_path, full_checkpoint_path)
            finally:
                # After a successful replace the partial file is gone; this
                # only cleans up after a failed download.
                if os.path.exists(partial_path):
                    os.remove(partial_path)
    # Kept outside the existence check so every rank reaches the barrier,
    # whatever it observed on disk.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    return full_checkpoint_path