# python3.7
"""Contains the base class for generator."""

import os
import sys
import logging
import numpy as np

import torch

from . import model_settings

__all__ = ['BaseGenerator']


def get_temp_logger(logger_name='logger'):
    """Gets a temporary logger.

    When this function installs the handler, the logger prints messages of all
    levels to stdout.

    Args:
        logger_name: Name of the logger. (default: `logger`)

    Returns:
        A `logging.Logger`.

    Raises:
        ValueError: If the input `logger_name` is empty.
    """
    if not logger_name:
        raise ValueError('Input `logger_name` should not be empty!')

    logger = logging.getLogger(logger_name)
    # NOTE: `Logger.hasHandlers()` also checks ancestor loggers (e.g. root).
    # If an ancestor already has a handler, this logger is returned as-is,
    # without the DEBUG level / stdout handler configured below.
    if not logger.hasHandlers():
        logger.setLevel(logging.DEBUG)
        formatter = logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s')
        handler = logging.StreamHandler(stream=sys.stdout)
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


class BaseGenerator(object):
    """Base class for generator used in GAN variants.

    NOTE: The model should be defined with pytorch, and only used for inference.
    """

    def __init__(self, model_name, logger=None):
        """Initializes with specific settings.

        The model should be registered in `model_settings.py` with proper
        settings first. Every registered setting becomes an instance attribute.

        Required settings:
        (1) gan_type: Type of the GAN model.
        (2) latent_space_dim: Dimension of the latent space. Should be a tuple.
        (3) resolution: Resolution of the synthesis.

        Optional settings:
        (4) min_val: Minimum value of the raw output. (default -1.0)
        (5) max_val: Maximum value of the raw output. (default 1.0)
        (6) output_channels: Number of output channels. (default 3)
        (7) channel_order: Channel order of the output image, `RGB` or `BGR`
            (case-insensitive). (default: `RGB`)

        If a `model_path` setting points to an existing file, `load()` is
        called. Otherwise, if `tf_model_path` points to an existing file,
        `convert_tf_model()` is called. If neither does, a warning is logged.

        Args:
            model_name: Name with which the model is registered.
            logger: Logger for recording log messages. If set as `None`, a
                default logger, which prints messages from all levels to
                screen, will be created. (default: None)

        Raises:
            AttributeError: If some necessary attributes are missing.
            AssertionError: If `channel_order` is neither `RGB` nor `BGR`, or
                if `build()` leaves `self.model` as `None`.
        """
        self.model_name = model_name
        # Expose every registered setting as an instance attribute.
        for key, val in model_settings.MODEL_POOL[model_name].items():
            setattr(self, key, val)
        self.use_cuda = model_settings.USE_CUDA
        self.batch_size = model_settings.MAX_IMAGES_ON_DEVICE
        self.logger = logger or get_temp_logger(model_name + '_generator')
        self.model = None
        self.run_device = 'cuda' if self.use_cuda else 'cpu'
        self.cpu_device = 'cpu'

        # Check necessary settings.
        self.check_attr('gan_type')
        self.check_attr('latent_space_dim')
        self.check_attr('resolution')
        # Optional settings fall back to defaults when not registered.
        self.min_val = getattr(self, 'min_val', -1.0)
        self.max_val = getattr(self, 'max_val', 1.0)
        self.output_channels = getattr(self, 'output_channels', 3)
        self.channel_order = getattr(self, 'channel_order', 'RGB').upper()
        assert self.channel_order in ['RGB', 'BGR'], (
            f'Unsupported channel order `{self.channel_order}`!')

        # Build model and load pre-trained weights.
        self.build()
        if os.path.isfile(getattr(self, 'model_path', '')):
            self.load()
        elif os.path.isfile(getattr(self, 'tf_model_path', '')):
            self.convert_tf_model()
        else:
            self.logger.warning('No pre-trained model will be loaded!')

        # `build()` must have assigned `self.model`. Compare against `None`
        # instead of relying on truthiness: container modules such as
        # `torch.nn.Sequential` define `__len__`, so an empty one is falsy.
        assert self.model is not None, (
            f'`build()` did not set `model` for `{self.model_name}`!')
        # Change to inference mode and GPU mode if needed.
        self.model.eval().to(self.run_device)

    def check_attr(self, attr_name):
        """Checks the existence of a particular attribute.

        Args:
            attr_name: Name of the attribute to check.

        Raises:
            AttributeError: If the target attribute is missing.
        """
        if not hasattr(self, attr_name):
            raise AttributeError(
                f'`{attr_name}` is missing for model `{self.model_name}`!')

    def build(self):
        """Builds the graph. Must assign `self.model`."""
        raise NotImplementedError('Should be implemented in derived class!')

    def load(self):
        """Loads pre-trained weights."""
        raise NotImplementedError('Should be implemented in derived class!')

    def convert_tf_model(self, test_num=10):
        """Converts models weights from tensorflow version.

        Args:
            test_num: Number of images to generate for testing whether the
                conversion is done correctly. `0` means skipping the test.
                (default 10)
        """
        raise NotImplementedError('Should be implemented in derived class!')

    def sample(self, num):
        """Samples latent codes randomly.

        Args:
            num: Number of latent codes to sample. Should be positive.

        Returns:
            A `numpy.ndarray` as sampled latent codes.
        """
        raise NotImplementedError('Should be implemented in derived class!')

    def preprocess(self, latent_codes):
        """Preprocesses the input latent code if needed.

        Args:
            latent_codes: The input latent codes for preprocessing.

        Returns:
            The preprocessed latent codes which can be used as final input for
            the generator.
        """
        raise NotImplementedError('Should be implemented in derived class!')

    def easy_sample(self, num):
        """Wraps functions `sample()` and `preprocess()` together."""
        return self.preprocess(self.sample(num))

    def synthesize(self, latent_codes, **kwargs):
        """Synthesizes images with given latent codes.

        NOTE: The latent codes should have already been preprocessed.

        Args:
            latent_codes: Input latent codes for image synthesis.
            **kwargs: Extra model-specific options. Accepted here so the base
                signature matches how `easy_synthesize()` calls this method.

        Returns:
            A dictionary whose values are raw outputs from the generator.
        """
        raise NotImplementedError('Should be implemented in derived class!')

    def get_value(self, tensor):
        """Gets value of a `torch.Tensor`.

        Args:
            tensor: The input tensor to get value from. A `numpy.ndarray` is
                returned unchanged.

        Returns:
            A `numpy.ndarray`.

        Raises:
            ValueError: If the tensor is neither a `torch.Tensor` nor a
                `numpy.ndarray`.
        """
        if isinstance(tensor, np.ndarray):
            return tensor
        if isinstance(tensor, torch.Tensor):
            # Detach first so the device copy is not recorded by autograd.
            return tensor.detach().to(self.cpu_device).numpy()
        raise ValueError(f'Unsupported input type `{type(tensor)}`!')

    def postprocess(self, images):
        """Postprocesses the output images if needed.

        This function assumes the input numpy array is with shape [batch_size,
        channel, height, width]. Here, `channel = 3` for color image and
        `channel = 1` for grayscale image. The returned images are with shape
        [batch_size, height, width, channel].

        NOTE: The channel order of output image will always be `RGB`.

        Args:
            images: The raw output from the generator.

        Returns:
            The postprocessed images with dtype `numpy.uint8` and range
            [0, 255].

        Raises:
            ValueError: If the input `images` are not with type `numpy.ndarray`
                or not with shape [batch_size, channel, height, width].
        """
        if not isinstance(images, np.ndarray):
            raise ValueError('Images should be with type `numpy.ndarray`!')

        # NOTE(review): the shape check is skipped when the model name contains
        # `stylegan2` or `stylegan3`; the reason is not visible in this file.
        if ('stylegan3' not in self.model_name and
                'stylegan2' not in self.model_name):
            images_shape = images.shape
            if len(images_shape) != 4 or images_shape[1] not in [1, 3]:
                raise ValueError('Input should be with shape [batch_size, channel, '
                                 'height, width], where channel equals to 1 or 3. '
                                 f'But {images_shape} is received!')

        # Linearly map [min_val, max_val] onto [0, 255]. Adding 0.5 before the
        # truncating uint8 cast rounds each clipped value to the nearest integer.
        images = (images - self.min_val) * 255 / (self.max_val - self.min_val)
        images = np.clip(images + 0.5, 0, 255).astype(np.uint8)
        # [N, C, H, W] -> [N, H, W, C].
        images = images.transpose(0, 2, 3, 1)
        if self.channel_order == 'BGR':
            # Reverse the channel axis so the output is always RGB.
            images = images[:, :, :, ::-1]

        return images

    def easy_synthesize(self, latent_codes, **kwargs):
        """Wraps functions `synthesize()` and `postprocess()` together.

        Only the `image` entry of the returned dictionary (if present) is
        postprocessed; other entries are returned as produced.
        """
        outputs = self.synthesize(latent_codes, **kwargs)
        if 'image' in outputs:
            outputs['image'] = self.postprocess(outputs['image'])

        return outputs

    def get_batch_inputs(self, latent_codes):
        """Gets batch inputs from a collection of latent codes.

        This function will yield at most `self.batch_size` latent_codes at a
        time.

        Args:
            latent_codes: The input latent codes for generation. First dimension
                should be the total number.

        Yields:
            Consecutive slices of `latent_codes` along the first dimension.
        """
        total_num = latent_codes.shape[0]
        for i in range(0, total_num, self.batch_size):
            yield latent_codes[i:i + self.batch_size]