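"""Inference utilities for the text-conditioned diffusion demo.

Helpers to load a Diffuser checkpoint and a CLIP text encoder from a Hydra config,
encode text prompts, and assemble single-sample batches from the demo split of a
MultimodalDataset.
"""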
from typing import Any, Dict, List, Optional, Tuple

import clip
from hydra import compose, initialize
from hydra.utils import instantiate
from omegaconf import OmegaConf
import torch
from torchtyping import TensorType
from torch.utils.data import DataLoader
import torch.nn.functional as F

from src.diffuser import Diffuser
from src.datasets.multimodal_dataset import MultimodalDataset

# ------------------------------------------------------------------------------------- #

batch_size, context_length = None, None  # placeholder names for torchtyping dimensions
collate_fn = DataLoader([]).collate_fn  # reuse PyTorch's default collate function

# ------------------------------------------------------------------------------------- #


def to_device(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.to(device)
    return batch


def load_clip_model(version: str, device: str) -> clip.model.CLIP:
    model, _ = clip.load(version, device=device, jit=False)
    model.eval()
    for p in model.parameters():
        p.requires_grad = False
    return model


def encode_text(
    caption_raws: List[str],  # batch_size
    clip_model: clip.model.CLIP,
    max_token_length: Optional[int],
    device: str,
) -> Tuple[List[torch.Tensor], TensorType["batch_size", "d_model"]]:
    if max_token_length is not None:
        default_context_length = 77
        context_length = max_token_length + 2  # start_token + max_token_length + end_token
        assert context_length < default_context_length
        # [bs, context_length] # if n_tokens > context_length -> will truncate
        texts = clip.tokenize(
            caption_raws, context_length=context_length, truncate=True
        )
        zero_pad = torch.zeros(
            [texts.shape[0], default_context_length - context_length],
            dtype=texts.dtype,
            device=texts.device,
        )
        texts = torch.cat([texts, zero_pad], dim=1)
    else:
        # [bs, context_length] # if n_tokens > 77 -> will truncate
        texts = clip.tokenize(caption_raws, truncate=True)

    # [batch_size, n_ctx, d_model]
    x = clip_model.token_embedding(texts.to(device)).type(clip_model.dtype)
    x = x + clip_model.positional_embedding.type(clip_model.dtype)
    x = x.permute(1, 0, 2)  # NLD -> LND
    x = clip_model.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD
    x = clip_model.ln_final(x).type(clip_model.dtype)
    # x.shape = [batch_size, n_ctx, transformer.width]
    # take features from the EOT embedding (the EOT token has the highest id in each sequence)
    x_tokens = x[torch.arange(x.shape[0]), texts.argmax(dim=-1)].float()
    x_seq = [x[k, : (m + 1)].float() for k, m in enumerate(texts.argmax(dim=-1))]

    return x_seq, x_tokens


def get_batch(
    prompt: str,
    sample_id: str,
    clip_model: clip.model.CLIP,
    dataset: MultimodalDataset,
    seq_feat: bool,
    device: torch.device,
) -> Dict[str, Any]:
    # Get base batch
    sample_index = dataset.root_filenames.index(sample_id)
    raw_batch = dataset[sample_index]
    batch = collate_fn([to_device(raw_batch, device)])

    # Encode text
    caption_seq, caption_tokens = encode_text([prompt], clip_model, None, device)

    if seq_feat:
        caption_feat = caption_seq[0]
        caption_feat = F.pad(caption_feat, (0, 0, 0, 77 - caption_feat.shape[0]))
        caption_feat = caption_feat.unsqueeze(0).permute(0, 2, 1)
    else:
        caption_feat = caption_tokens

    # Update batch
    batch["caption_raw"] = [prompt]
    batch["caption_feat"] = caption_feat

    return batch


def init(
    config_name: str,
) -> Tuple[Diffuser, clip.model.CLIP, MultimodalDataset, torch.device]:
    with initialize(version_base="1.3", config_path="../configs"):
        config = compose(config_name=config_name)

    OmegaConf.register_new_resolver("eval", eval)

    # Initialize model
    device = torch.device(config.compnode.device)
    diffuser = instantiate(config.diffuser)
    state_dict = torch.load(config.checkpoint_path, map_location=device)["state_dict"]
    state_dict["ema.initted"] = diffuser.ema.initted
    state_dict["ema.step"] = diffuser.ema.step
    diffuser.load_state_dict(state_dict, strict=False)
    diffuser.to(device).eval()

    # Initialize CLIP model
    clip_model = load_clip_model("ViT-B/32", device)

    # Initialize dataset
    config.dataset.char.load_vertices = True
    config.batch_size = 1
    dataset = instantiate(config.dataset)
    dataset.set_split("demo")
    diffuser.modalities = list(dataset.modality_datasets.keys())
    diffuser.get_matrix = dataset.get_matrix
    diffuser.v_get_matrix = dataset.get_matrix

    return diffuser, clip_model, dataset, device
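

# ------------------------------------------------------------------------------------- #
# Minimal usage sketch (illustrative only): the config name, prompt, and sample id below
# are placeholders that depend on your local setup, not values shipped with this file.
if __name__ == "__main__":
    diffuser, clip_model, dataset, device = init(config_name="config")  # hypothetical config name
    sample_id = dataset.root_filenames[0]  # pick any id present in the demo split
    batch = get_batch(
        prompt="<your text prompt>",  # example caption
        sample_id=sample_id,
        clip_model=clip_model,
        dataset=dataset,
        seq_feat=True,  # assumes the diffuser expects per-token CLIP features
        device=device,
    )
    print({k: tuple(v.shape) for k, v in batch.items() if isinstance(v, torch.Tensor)})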