""" This script provides an exmaple to wrap TencentPretrain for generation. Given the beginning of a text, language model generates the rest. """ import sys import os import argparse import torch import torch.nn.functional as F from torchvision import transforms from PIL import Image tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.append(tencentpretrain_dir) from tencentpretrain.embeddings import * from tencentpretrain.layers import * from tencentpretrain.encoders import * from tencentpretrain.targets import * from tencentpretrain.utils.constants import * from tencentpretrain.utils import * from tencentpretrain.utils.config import load_hyperparam from tencentpretrain.model_loader import load_model from tencentpretrain.opts import model_opts, tokenizer_opts from scripts.generate_lm import top_k_top_p_filtering from tencentpretrain.utils.image_tokenizer import * class GenerateLm(torch.nn.Module): def __init__(self, args): super(GenerateLm, self).__init__() self.embedding = Embedding(args) for embedding_name in args.embedding: tmp_emb = str2embedding[embedding_name](args, len(args.tokenizer.vocab)) self.embedding.update(tmp_emb, embedding_name) self.encoder = str2encoder[args.encoder](args) self.target = Target() self.target.update(LmTarget(args, len(args.tokenizer.vocab)), "lm") def forward(self, src, seg): emb = self.embedding(src, seg) output = self.encoder(emb, seg) output = self.target.lm.output_layer(output) return output if __name__ == "__main__": parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) # Model options. model_opts(parser) # Inference options. parser.add_argument("--load_model_path", default=None, type=str, help="Path of the input model.") parser.add_argument("--config_path", type=str, required=True, help="Path of the config file.") parser.add_argument("--seq_length", type=int, default=48, help="Sequence length.") parser.add_argument("--samples_num", type=int, default=10, help="Number of iterations for sampling.") parser.add_argument("--prompt", choices=["to_attributes", "to_caption", "to_image"], default="to_attributes", help="Prompt that indicates the output format.") parser.add_argument("--text_prefix_path", type=str, default=None, help="Text prefix for to_attributes and to_image.") parser.add_argument("--image_prefix_path", type=str, default=None, help="Input image path.") parser.add_argument("--top_k", type=int, default=70) parser.add_argument("--top_p", type=float, default=0) parser.add_argument("--temperature", type=float, default=1.0) tokenizer_opts(parser) args = parser.parse_args() args.batch_size = 1 args = load_hyperparam(args) args.tokenizer = str2tokenizer[args.tokenizer](args) args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") args.text_prefix = None if args.text_prefix_path is not None: with open(args.text_prefix_path, "r") as f: args.text_prefix = f.readline() def preprocess_vqgan(x): x = 2.*x - 1. 
        return x

    def convert_color(image, channel):
        # Normalize the image to the requested number of channels.
        if channel == 3:
            if image.mode == "RGBA":
                r, g, b, a = image.split()
                image = Image.merge("RGB", (r, g, b))
            elif image.mode != "RGB":
                image = image.convert("RGBA")
                r, g, b, a = image.split()
                image = Image.merge("RGB", (r, g, b))
        elif channel == 1:
            image = image.convert("L")
        return image

    transform = transforms.Compose([
        transforms.Lambda(lambda img: convert_color(img, 3)),
        transforms.Resize((args.image_height, args.image_width)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: preprocess_vqgan(x)),
    ])

    model = GenerateLm(args)
    model = load_model(model, args.load_model_path)
    model = model.to(args.device)
    model.eval()

    # VQGAN tokenizes input images and decodes generated image tokens.
    vqgan = build_vqgan_model(args)
    vqgan = vqgan.to(args.device)

    # Turn the prompt option into plain words, e.g. "to_caption" -> "to caption".
    prompt = " ".join(args.prompt.split("_"))

    # Special token ids of a BERT-style vocabulary.
    PAD_ID, CLS_ID, SEP_ID, MASK_ID = 0, 101, 102, 103

    if args.image_prefix_path is None:
        # Text-only prefix: [CLS] prompt [SEP] text_prefix [SEP], truncated to 64 ids.
        src = args.tokenizer.convert_tokens_to_ids(
            [CLS_TOKEN] + args.tokenizer.tokenize(prompt) + [SEP_TOKEN]
            + args.tokenizer.tokenize(args.text_prefix) + [SEP_TOKEN]
        )
        if len(src) > 64:
            src = src[:64]
        seg = [1] * len(src)
    else:
        # Image prefix: tokenize the image with VQGAN and shift its codebook ids
        # past the text vocabulary with vocab_bias.
        image = Image.open(args.image_prefix_path)
        image = transform(image).to(args.device)
        image_token = image_tokenize(vqgan, image)
        p_src = args.tokenizer.convert_tokens_to_ids(
            [CLS_TOKEN] + args.tokenizer.tokenize(prompt) + [SEP_TOKEN]
        )
        src = p_src + [i + args.tokenizer.vocab_bias for i in image_token] + [SEP_ID]
        seg = [1] * len(src)
        if args.text_prefix is not None:
            attr_prompt_src = args.tokenizer.convert_tokens_to_ids(args.tokenizer.tokenize(args.text_prefix))
            src = src + attr_prompt_src
            seg = seg + [2] * len(attr_prompt_src)

    beginning_length = len(src)

    for r in range(args.samples_num):
        src_tensor, seg_tensor = torch.LongTensor([src]), torch.LongTensor([seg])
        # Autoregressive sampling: append one token per step until seq_length is reached.
        for i in range(args.seq_length - beginning_length):
            src_tensor = src_tensor.to(args.device)
            seg_tensor = seg_tensor.to(args.device)
            with torch.no_grad():
                output = model(src_tensor, seg_tensor)
            next_token_logits = output[0][-1] / args.temperature
            filtered_logits = top_k_top_p_filtering(next_token_logits, args.top_k, args.top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            src_tensor = torch.cat([src_tensor, next_token.view(1, 1)], dim=1)
            seg_tensor = torch.cat([seg_tensor, torch.tensor([[2]], device=args.device)], dim=1)

        if args.image_prefix_path is not None:
            # The model generated text (a caption or attributes). Keep the sampled
            # ids up to the first [SEP]; cutting at the exact id avoids the substring
            # mismatches that splitting a space-joined id string on "102" can cause.
            generated_ids = [token_id.item() for token_id in src_tensor[0][beginning_length:]]
            if SEP_ID in generated_ids:
                generated_ids = generated_ids[:generated_ids.index(SEP_ID)]
            generated_sentence = " ".join(args.tokenizer.convert_ids_to_tokens(generated_ids))
            print("output " + str(r) + ":" + "\n")
            print(generated_sentence + "\n")
        else:
            # The model generated image tokens: shift them back to codebook ids
            # and decode them into an image with VQGAN.
            image_id = [token_id.item() for token_id in src_tensor[0][beginning_length:]]
            img_length = (args.image_height // args.image_tokenizer["frame_size"]) ** 2
            img_seg = [i - args.tokenizer.vocab_bias for i in image_id[:img_length]]
            image_detokenize(vqgan, img_seg, args.image_tokenizer["image_vocab_size"], False,
                             "output-" + str(r) + ".jpg")
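
# Example invocation (a sketch, not verbatim from the repository: the script
# filename and paths below are illustrative, and model_opts/tokenizer_opts
# register additional flags, e.g. tokenizer and vocabulary options, that your
# setup may require):
#
#   python3 scripts/generate_multimodal.py \
#       --load_model_path models/multimodal_model.bin \
#       --config_path models/multimodal_config.json \
#       --prompt to_caption --image_prefix_path input.jpg \
#       --seq_length 256 --samples_num 4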