# Source: Hugging Face Space file view (page status: "Runtime error"), file size 6,971 bytes, commit 1ba3df3.
from pathlib import Path
from typing import Dict, List, Union, Tuple
from omegaconf import OmegaConf
import numpy as np
import torch
from torch import nn
from PIL import Image, ImageDraw, ImageFont
import models
# Key prefix of the generator's weights inside checkpoint["state_dict"].
GENERATOR_PREFIX = "networks.g."
# Maximum 8-bit pixel intensity; pixel values are divided by this to land in [0, 1].
WHITE = 255
# Glyphs rendered from the style font and fed to the model as reference shots.
EXAMPLE_CHARACTERS = list("ABCDE")
class InferenceServicer:
    """Inference wrapper around ``models.Generator`` for few-shot glyph generation.

    Loads generator weights from a checkpoint, indexes content glyph images
    (``*.png``) by file stem, and renders a content character in the style of
    a given font file.
    """

    def __init__(self, hp, checkpoint_path, content_image_dir, imsize=64, gpu_id='0') -> None:
        """Build the generator, load its weights and index the content glyphs.

        Args:
            hp: Config object; ``hp.models.G`` is passed to ``models.Generator``.
            checkpoint_path: Checkpoint whose ``'state_dict'`` entry holds the
                generator weights under the ``GENERATOR_PREFIX`` key prefix.
            content_image_dir: Directory searched recursively for ``*.png``
                content glyphs, keyed by file stem (4+ digit uppercase hex).
            imsize: Side length in pixels of the square model inputs.
            gpu_id: CUDA device index. ``None`` selects ``cuda:0``. When CUDA
                is unavailable the model runs on CPU regardless of this value.
        """
        self.hp = hp
        self.imsize = imsize
        self.device = self._select_device(gpu_id)

        self.model: nn.Module = models.Generator(self.hp.models.G)
        # NOTE(review): torch.load unpickles arbitrary Python objects; only load
        # checkpoints from trusted sources.
        model_state_dict_pl = torch.load(checkpoint_path, map_location='cpu')
        self.model.load_state_dict(self.convert_generator_state_dict(model_state_dict_pl))
        self.model.to(device=self.device)
        self.model.eval()

        self.content_character_dict = self.load_content_character_dict(Path(content_image_dir))

    @staticmethod
    def _select_device(gpu_id) -> torch.device:
        """Return ``cuda:<gpu_id>`` (``cuda:0`` if None) when CUDA exists, else CPU."""
        if not torch.cuda.is_available():
            # Previously any non-None gpu_id (including the default '0') demanded
            # CUDA, which crashes on CPU-only hosts.
            return torch.device('cpu')
        return torch.device(f'cuda:{0 if gpu_id is None else gpu_id}')

    @staticmethod
    def convert_generator_state_dict(model_state_dict_pl):
        """Extract the generator's weights from a checkpoint dictionary.

        Keeps only the entries of ``model_state_dict_pl['state_dict']`` whose
        key starts with ``GENERATOR_PREFIX``, with that prefix stripped, so the
        result can be passed to the generator's ``load_state_dict``.
        """
        prefix_len = len(GENERATOR_PREFIX)
        return {
            module_name[prefix_len:]: module_state
            for module_name, module_state in model_state_dict_pl['state_dict'].items()
            if module_name.startswith(GENERATOR_PREFIX)
        }

    @staticmethod
    def load_content_character_dict(content_image_dir: Path) -> Dict[str, Path]:
        """Map the stem of every ``*.png`` found recursively under *content_image_dir* to its path.

        Files are visited in sorted order so that, when two files share a stem,
        the one that wins (the last in sorted order) is deterministic.
        """
        return {filepath.stem: filepath for filepath in sorted(content_image_dir.glob("**/*.png"))}

    @staticmethod
    def center_align(bg_img: Image.Image, item_img: Image.Image, fit=False) -> Image.Image:
        """Paste *item_img* centred onto a copy of *bg_img*; the inputs are not modified.

        Args:
            bg_img: Background canvas.
            item_img: Image to paste.
            fit: If True, first scale *item_img* (keeping its aspect ratio) so
                that it spans the background's limiting dimension.

        Returns:
            A new image containing the centred item on the background.
        """
        bg_img = bg_img.copy()
        item_w, item_h = item_img.size
        W, H = bg_img.size
        if fit:
            if W / H > item_w / item_h:
                resize_ratio = H / item_h  # background is relatively wider: fit height
            else:
                resize_ratio = W / item_w  # fit width
            # round() avoids losing a pixel to float truncation in the fitted
            # dimension; max(1, ...) keeps very thin items from becoming 0-sized.
            item_img = item_img.resize((
                max(1, round(item_w * resize_ratio)),
                max(1, round(item_h * resize_ratio)),
            ))
            item_w, item_h = item_img.size
        bg_img.paste(item_img, ((W - item_w) // 2, (H - item_h) // 2))
        return bg_img

    def set_image(self, image: Union[str, Path, Image.Image]) -> Image.Image:
        """Return *image* fitted and centred on a white ``imsize`` x ``imsize`` RGB canvas.

        Args:
            image: A PIL image, or a path to an image file.

        Raises:
            TypeError: If *image* is neither a path nor a PIL image.
        """
        if isinstance(image, (str, Path)):
            # Read the pixels and close the file right away; a bare Image.open
            # is lazy and keeps the file handle open.
            with Image.open(image) as opened:
                image = opened.copy()
        if not isinstance(image, Image.Image):
            raise TypeError(f"Expected a path or PIL.Image.Image, got {type(image).__name__}")
        bg_img = Image.new('RGB', (self.imsize, self.imsize), color='white')
        return self.center_align(bg_img, image, fit=True)

    @staticmethod
    def pil_image_to_array(blend_img: Image.Image) -> np.ndarray:
        """Convert an image to a float32 H x W array scaled to [0, 1] by ``WHITE``.

        Multi-channel images are reduced to one channel by averaging their
        channels; single-channel images (e.g. mode 'L') are used as-is.
        """
        array = np.asarray(blend_img, dtype=np.float32)
        if array.ndim == 3:
            array = array.mean(axis=-1)
        return array / WHITE

    def get_images_from_fontfile(
        self,
        font_file_path: Path,
        imgmode: str = 'RGB',
        position: tuple = (0, 0),
        font_size: int = 128,
        padding: int = 100,
    ) -> List[Image.Image]:
        """Render every character in ``EXAMPLE_CHARACTERS`` with the given font file.

        Each glyph is drawn in black at *position* on a white canvas whose size
        is the glyph's extent (measured from the origin) plus *padding*.

        Returns:
            One image of mode *imgmode* per example character, in order.
        """
        imagefont = ImageFont.truetype(str(font_file_path), size=font_size)
        font_images: List[Image.Image] = []
        for character in EXAMPLE_CHARACTERS:
            # ImageDraw.textsize was removed in Pillow 10. The right/bottom
            # edges of the origin-anchored bbox give the width/height that
            # textsize used to report for a single line of text.
            _, _, right, bottom = imagefont.getbbox(character)
            img = Image.new(imgmode, (int(right) + padding, int(bottom) + padding), color='white')
            ImageDraw.Draw(img).text(position, text=character, font=imagefont, fill='black')
            font_images.append(img)
        return font_images

    @staticmethod
    def get_hex_from_char(char: str) -> str:
        """Return the code point of *char* as uppercase hex, zero-padded to at least 4 digits.

        Raises:
            ValueError: If *char* is not exactly one character.
        """
        if len(char) != 1:
            raise ValueError(f"Expected exactly one character, got {char!r}")
        return f"{ord(char):04X}"

    @torch.no_grad()
    def inference(self, content_char: str, style_font: Union[str, Path]) -> Tuple[Image.Image, List[Image.Image], Image.Image]:
        """Generate *content_char* in the style of the font file *style_font*.

        Args:
            content_char: Character to generate; only its first character is used.
            style_font: Path to the font supplying the style reference glyphs.

        Returns:
            ``(content_image, style_images, generated_image)``: the preprocessed
            content glyph, the preprocessed style reference glyphs, and the
            generated image in mode 'L'.

        Raises:
            ValueError: If *content_char* is empty or has no content glyph image.
        """
        if not content_char:
            raise ValueError("content_char must contain at least one character")
        content_char = content_char[0]  # only the first character is used
        char_hex = self.get_hex_from_char(content_char)
        if char_hex not in self.content_character_dict:
            raise ValueError(f"The character {content_char} (hex: {char_hex}) is not supported in this model!")

        content_image = self.set_image(self.content_character_dict[char_hex])
        style_images: List[Image.Image] = [
            self.set_image(image) for image in self.get_images_from_fontfile(Path(style_font))
        ]

        # 1 x 1 x H x W
        content_array = self.pil_image_to_array(content_image)[np.newaxis, np.newaxis, ...]
        # 1 x K x H x W: the K style shots are stacked along the channel axis
        style_array = np.stack([self.pil_image_to_array(image) for image in style_images])[np.newaxis, ...]
        content_input_tensor = torch.from_numpy(content_array).to(self.device)
        style_input_tensor = torch.from_numpy(style_array).to(self.device)

        generated_images: torch.Tensor = self.model((content_input_tensor, style_input_tensor))
        generated_images = torch.clip(generated_images, 0, 1)
        assert generated_images.size(0) == 1
        # First batch item, first channel -> H x W; round instead of truncating
        # so that quantization is not biased towards darker values.
        generated_image_numpy = (generated_images[0, 0].cpu().numpy() * 255).round().astype(np.uint8)
        return content_image, style_images, Image.fromarray(generated_image_numpy, mode='L')
if __name__ == '__main__':
    # Example run: generate the glyph "7" in the style of MaShanZheng and save
    # the content glyph, the style reference glyphs and the generated image.
    hp = OmegaConf.load("config/models/google-font.yaml")
    checkpoint_path = "epoch=199-step=257400.ckpt"
    content_image_dir = "../DATA/NotoSans"
    servicer = InferenceServicer(hp, checkpoint_path, content_image_dir)
    style_font = "example_fonts/MaShanZheng-Regular.ttf"
    content_image, style_images, result = servicer.inference("7", style_font)
    content_image.save("result_content.png")
    for idx, style_image in enumerate(style_images):
        style_image.save(f"result_style_{idx:02d}.png")
    result.save("result_generated.png")