keithhon committed on
Commit
7931c5f
1 Parent(s): 917ff2b

Upload dalle/utils/sampling.py with huggingface_hub

Files changed (1)
  1. dalle/utils/sampling.py +152 -0
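The commit message indicates the file was pushed with the huggingface_hub client. As a point of reference, a minimal sketch of how such an upload is typically done with `HfApi.upload_file` (the repo id below is a placeholder, not taken from this commit, and an access token must already be configured, e.g. via `huggingface-cli login`):

```python
from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="dalle/utils/sampling.py",   # local file to upload
    path_in_repo="dalle/utils/sampling.py",      # destination path inside the repo
    repo_id="keithhon/your-repo",                # placeholder repo id (assumption)
    commit_message="Upload dalle/utils/sampling.py with huggingface_hub",
)
```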
dalle/utils/sampling.py ADDED
@@ -0,0 +1,152 @@
+ # ------------------------------------------------------------------------------------
+ # Minimal DALL-E
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ # ------------------------------------------------------------------------------------
+
+ import torch
+ from typing import Optional
+ from tqdm import tqdm
+ from torch.nn import functional as F
+
+
+ def cutoff_topk_logits(logits: torch.FloatTensor, k: int) -> torch.FloatTensor:
+     if k is None:
+         return logits
+     else:
+         v, ix = torch.topk(logits, k)
+         out = logits.clone()
+         out[out < v[:, [-1]]] = -float('Inf')
+         return out
+
+
+ def cutoff_topp_probs(probs: torch.FloatTensor, p: float) -> torch.FloatTensor:
+     if p is None:
+         return probs
+     else:
+         sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
+         cum_probs = torch.cumsum(sorted_probs, dim=-1)
+
+         sorted_idx_remove_cond = cum_probs >= p
+
+         sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone()
+         sorted_idx_remove_cond[..., 0] = 0
+
+         indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond)
+         probs = probs.masked_fill(indices_to_remove, 0.0)
+         norm_probs = probs / torch.sum(probs, dim=-1, keepdim=True)
+         return norm_probs
+
+
+ def get_positional_encoding(inputs: torch.LongTensor, mode: str = '1d') -> torch.LongTensor:
+     device = inputs.device
+     if mode == '1d':
+         B, N = inputs.shape
+         xs_pos = torch.arange(N, device=device).repeat((B, 1))
+     elif mode == '2d':
+         B, H, W = inputs.shape
+         xs_pos_h = torch.arange(H, device=device).repeat(B, W, 1).transpose(1, 2)
+         xs_pos_w = torch.arange(W, device=device).repeat(B, H, 1)
+         xs_pos = (xs_pos_h, xs_pos_w)
+     else:
+         raise ValueError('%s positional encoding invalid' % mode)
+     return xs_pos
+
+
+ @torch.no_grad()
+ def sampling(model: torch.nn.Module,
+              tokens: torch.LongTensor,
+              top_k: Optional[float] = None,
+              top_p: Optional[float] = None,
+              softmax_temperature: float = 1.0,
+              is_tqdm: bool = True,
+              use_fp16: bool = True,
+              max_seq_len: int = 256) -> torch.LongTensor:
+     code = None
+     past = None
+
+     pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
+     pos_enc_tokens = get_positional_encoding(tokens, mode='1d')
+
+     for cnt, h in enumerate(pbar):
+         if code is None:
+             code_ = None
+             pos_enc_code_ = None
+         else:
+             code_ = code.clone().detach()
+             pos_enc_code_ = get_positional_encoding(code_, mode='1d')
+             code_ = code_[:, cnt-1].unsqueeze(-1)
+             pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)
+
+         logits, present = model.sampling(images=code_,
+                                          texts=tokens,
+                                          pos_images=pos_enc_code_,
+                                          pos_texts=pos_enc_tokens,
+                                          use_fp16=use_fp16,
+                                          past=past)
+         logits = logits.to(dtype=torch.float32)
+         logits = logits / softmax_temperature
+
+         present = torch.stack(present).clone().detach()
+         if past is None:
+             past = [present]
+         else:
+             past.append(present)
+
+         logits = cutoff_topk_logits(logits, top_k)
+         probs = F.softmax(logits, dim=-1)
+         probs = cutoff_topp_probs(probs, top_p)
+
+         idx = torch.multinomial(probs, num_samples=1).clone().detach()
+         code = idx if code is None else torch.cat([code, idx], axis=1)
+
+     del past
+     return code
+
+
+ @torch.no_grad()
+ def sampling_igpt(model: torch.nn.Module,
+                   sos: torch.FloatTensor,
+                   top_k: Optional[float] = None,
+                   top_p: Optional[float] = None,
+                   softmax_temperature: float = 1.0,
+                   is_tqdm: bool = True,
+                   use_fp16: bool = True,
+                   max_seq_len: int = 256) -> torch.LongTensor:
+     code = None
+     past = None
+     pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
+
+     for cnt, h in enumerate(pbar):
+         if code is None:
+             code_ = None
+             pos_enc_code_ = None
+         else:
+             code_ = code.clone().detach()
+             pos_enc_code_ = get_positional_encoding(code_, mode='1d')
+             code_ = code_[:, cnt-1].unsqueeze(-1)
+             pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)
+
+         logits, present = model.sampling(sos=sos,
+                                          codes=code_,
+                                          pos_codes=pos_enc_code_,
+                                          use_fp16=use_fp16,
+                                          past=past)
+         logits = logits.to(dtype=torch.float32)
+         logits = logits / softmax_temperature
+
+         present = torch.stack(present).clone().detach()
+         if past is None:
+             past = [present]
+         else:
+             past.append(present)
+
+         logits = cutoff_topk_logits(logits, top_k)
+         probs = F.softmax(logits, dim=-1)
+         probs = cutoff_topp_probs(probs, top_p)
+
+         idx = torch.multinomial(probs, num_samples=1).clone().detach()
+         code = idx if code is None else torch.cat([code, idx], axis=1)
+
+     del past
+     return code
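For reference, here is a small, self-contained sketch of how the top-k and nucleus (top-p) helpers in this file behave on dummy logits. The import path assumes the file sits at dalle/utils/sampling.py on the Python path, as in this commit; the shapes and cutoff values are arbitrary illustration choices, not taken from the repository.

```python
import torch
from torch.nn import functional as F

from dalle.utils.sampling import cutoff_topk_logits, cutoff_topp_probs

torch.manual_seed(0)
logits = torch.randn(2, 16)                     # dummy (batch, vocab) logits

logits = cutoff_topk_logits(logits, k=5)        # keep only the 5 largest logits per row
probs = F.softmax(logits, dim=-1)
probs = cutoff_topp_probs(probs, p=0.9)         # drop the tail beyond 90% cumulative mass, renormalize

idx = torch.multinomial(probs, num_samples=1)   # one sampled token id per row
print(idx.shape)                                # torch.Size([2, 1])
```

This mirrors the per-step filtering inside `sampling` and `sampling_igpt`, which additionally rely on the model exposing a `model.sampling(...)` method that returns per-step logits and key/value `present` tensors for caching.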