"""
  This script provides an exmaple to wrap TencentPretrain for generation.
  Given the beginning of a text, language model generates the rest.
"""
import sys
import os
import argparse

import torch
import torch.nn.functional as F

from torchvision import transforms
from PIL import Image

tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(tencentpretrain_dir)

from tencentpretrain.embeddings import *
from tencentpretrain.layers import *
from tencentpretrain.encoders import *
from tencentpretrain.targets import *
from tencentpretrain.utils.constants import *
from tencentpretrain.utils import *
from tencentpretrain.utils.config import load_hyperparam
from tencentpretrain.model_loader import load_model
from tencentpretrain.opts import model_opts, tokenizer_opts
from scripts.generate_lm import top_k_top_p_filtering
from tencentpretrain.utils.image_tokenizer import *

class GenerateLm(torch.nn.Module):
    def __init__(self, args):
        super().__init__()
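        # Assemble the model from the config: a stack of embeddings, an
        # encoder, and an LM head projecting hidden states to vocab logits.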
        self.embedding = Embedding(args)
        for embedding_name in args.embedding:
            tmp_emb = str2embedding[embedding_name](args, len(args.tokenizer.vocab))
            self.embedding.update(tmp_emb, embedding_name)
        self.encoder = str2encoder[args.encoder](args)
        self.target = Target()
        self.target.update(LmTarget(args, len(args.tokenizer.vocab)), "lm")

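    # src: token ids of shape [batch_size, seq_length]; seg: segment ids of the
    # same shape. Returns logits over the joint text + image-code vocabulary.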
    def forward(self, src, seg):
        emb = self.embedding(src, seg)
        output = self.encoder(emb, seg)
        output = self.target.lm.output_layer(output)
        return output


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Model options.
    model_opts(parser)

    # Inference options.
    parser.add_argument("--load_model_path", default=None, type=str,
                        help="Path of the input model.")
    parser.add_argument("--config_path", type=str, required=True,
                    help="Path of the config file.")
    parser.add_argument("--seq_length", type=int, default=48,
                        help="Sequence length.")

    parser.add_argument("--samples_num", type=int, default=10,
                        help="Number of iterations for sampling.")
    parser.add_argument("--prompt", choices=["to_attributes", "to_caption", "to_image"], default="to_attributes",
                        help="Prompt that indicates the output format.")
    parser.add_argument("--text_prefix_path",  type=str, default=None,
                        help="Text prefix for to_attributes and to_image.")
    parser.add_argument("--image_prefix_path",  type=str, default=None,
                        help="Input image path.")
    parser.add_argument("--top_k", type=int, default=70)
    parser.add_argument("--top_p", type=float, default=0)
    parser.add_argument("--temperature", type=float, default=1.0)

    tokenizer_opts(parser)

    args = parser.parse_args()

    args.batch_size = 1

    args = load_hyperparam(args)

    args.tokenizer = str2tokenizer[args.tokenizer](args)
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

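    # Optionally read a single-line text prefix (used for to_attributes / to_image).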
    args.text_prefix = None
    if args.text_prefix_path is not None:
        with open(args.text_prefix_path, "r") as f:
            args.text_prefix = f.readline().strip()

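    # VQGAN expects inputs in [-1, 1]; ToTensor() yields [0, 1], so rescale.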
    def preprocess_vqgan(x):
        return 2. * x - 1.

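    # Force the image to the requested channel count; for 3 channels, drop any
    # alpha channel instead of blending it in.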
    def convert_color(image, channel):
        if channel == 3:
            if image.mode == "RGBA":
                r, g, b, a = image.split()
                image = Image.merge("RGB", (r, g, b))
            elif image.mode != "RGB":
                image = image.convert("RGBA")
                r, g, b, a = image.split()
                image = Image.merge("RGB", (r, g, b))
        elif channel == 1:
            image = image.convert("L")

        return image

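    # Image preprocessing: force RGB, resize to the model's input resolution,
    # convert to a [0, 1] tensor, then rescale to [-1, 1] for VQGAN.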
    transform = transforms.Compose([
        transforms.Lambda(lambda img: convert_color(img, 3)),
        transforms.Resize((args.image_height, args.image_width)),
        transforms.ToTensor(),
        transforms.Lambda(preprocess_vqgan),
    ])

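    # Build the language model, load the checkpoint, and switch to eval mode.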
    model = GenerateLm(args)
    model = load_model(model, args.load_model_path)
    model = model.to(args.device)
    model.eval()

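    # The VQGAN model maps images to and from discrete code sequences.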
    vqgan = build_vqgan_model(args)
    vqgan = vqgan.to(args.device)

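    # Render the prompt name as plain words, e.g. "to_caption" -> "to caption".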
    prompt = " ".join(args.prompt.split("_"))

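    # Hard-coded special-token ids for a BERT-style vocabulary.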
    PAD_ID, CLS_ID, SEP_ID, MASK_ID = 0, 101, 102, 103
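    # Without an image prefix, the input is [CLS] prompt [SEP] text_prefix [SEP]
    # (so --text_prefix_path is effectively required in this mode).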
    if args.image_prefix_path is None:
        src = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(prompt) + [SEP_TOKEN] +
                                                   args.tokenizer.tokenize(args.text_prefix) + [SEP_TOKEN])
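        # Keep the text-only input within a 64-token budget.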
        if len(src) > 64:
            src = src[:64]
        seg = [1] * len(src)
    else:
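        # With an image prefix: tokenize the image with VQGAN and splice the
        # shifted image codes between the prompt and a trailing [SEP].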
        image = Image.open(args.image_prefix_path)
        image = transform(image).to(args.device)
        image_token = image_tokenize(vqgan, image)

        p_src = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(prompt) + [SEP_TOKEN])
        src = p_src + [i + args.tokenizer.vocab_bias for i in image_token] + [SEP_ID]
        seg = [1] * len(src)

        if args.text_prefix is not None:
            attr_prompt_src = args.tokenizer.convert_tokens_to_ids(args.tokenizer.tokenize(args.text_prefix))
            src = src + attr_prompt_src
            seg = seg + [2] * len(attr_prompt_src)

    beginning_length = len(src)

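    # Autoregressive sampling: append one sampled token per step until the
    # sequence reaches seq_length, repeated samples_num times.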
    for r in range(args.samples_num):
        src_tensor = torch.LongTensor([src]).to(args.device)
        seg_tensor = torch.LongTensor([seg]).to(args.device)
        for i in range(args.seq_length - beginning_length):
            with torch.no_grad():
                output = model(src_tensor, seg_tensor)

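            # Temperature-scale the final position's logits, filter with
            # top-k / top-p, then sample the next token id.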
            next_token_logits = output[0][-1] / args.temperature
            filtered_logits = top_k_top_p_filtering(next_token_logits, args.top_k, args.top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)

            src_tensor = torch.cat([src_tensor, next_token.view(1, 1)], dim=1)
            seg_tensor = torch.cat([seg_tensor, torch.tensor([[2]], device=args.device)], dim=1)

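        # An image prefix means the model generated text; otherwise it
        # generated image codes to be decoded back into pixels.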
        if args.image_prefix_path is not None:
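            # Truncate the generated ids at the first [SEP] before detokenizing.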
            text_id = [token_id.item() for token_id in src_tensor[0][beginning_length:]]
            if SEP_ID in text_id:
                text_id = text_id[: text_id.index(SEP_ID)]
            generated_sentence = " ".join(args.tokenizer.convert_ids_to_tokens(text_id))
            print("output " + str(r) + ":" + "\n")
            print(generated_sentence + "\n")
        else:
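            # Shift the generated ids back below the text vocabulary and let
            # VQGAN decode the first full image's worth of codes.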
            image_id = [token_id.item() for token_id in src_tensor[0][beginning_length:]]
            img_length = (args.image_height // args.image_tokenizer["frame_size"]) ** 2
            img_seg = [i - args.tokenizer.vocab_bias for i in image_id[: img_length]]

            image_detokenize(vqgan, img_seg, args.image_tokenizer["image_vocab_size"], False, "output-" + str(r) + ".jpg")