add reference code from vllm
- .gitignore +1 -0
- app.py +34 -74

.gitignore ADDED
@@ -0,0 +1 @@
+notes.py

app.py CHANGED
@@ -14,7 +14,7 @@ import spaces
 import math
 from typing import List, Optional, Tuple

-title = "# 🙋🏻♂️Welcome to Tonic's Pixtral Model Demo"
+title = "# **WIP / DEMO** 🙋🏻♂️Welcome to Tonic's Pixtral Model Demo"
 description = """
 This demo showcases two capabilities of the Pixtral model:
 1. Image-to-Text Generation
@@ -25,6 +25,7 @@ This demo showcases two capabilities of the Pixtral model:
 """

 model_path = snapshot_download(repo_id="mistralai/Pixtral-12B-2409")
+
 with open(f'{model_path}/params.json', 'r') as f:
     params = json.load(f)

@@ -40,32 +41,16 @@ class RMSNorm(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

-def precompute_freqs_cis_2d(
-    dim: int,
-    height: int,
-    width: int,
-    theta: float,
-) -> torch.Tensor:
+def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float) -> torch.Tensor:
     freqs = 1.0 / (theta**(torch.arange(0, dim, 2).float() / dim))
-    h = torch.arange(height
-    w = torch.arange(width
-
+    h = torch.arange(height)
+    w = torch.arange(width)
     freqs_h = torch.outer(h, freqs[::2]).float()
     freqs_w = torch.outer(w, freqs[1::2]).float()
-    freqs_2d = torch.cat(
-        [
-            freqs_h[:, None, :].repeat(1, width, 1),
-            freqs_w[None, :, :].repeat(height, 1, 1),
-        ],
-        dim=-1,
-    )
+    freqs_2d = torch.cat([freqs_h[:, None, :].repeat(1, width, 1), freqs_w[None, :, :].repeat(height, 1, 1)], dim=-1)
     return torch.polar(torch.ones_like(freqs_2d), freqs_2d)

-def apply_rotary_emb_vit(
-    xq: torch.Tensor,
-    xk: torch.Tensor,
-    freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+def apply_rotary_emb_vit(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
     freqs_cis = freqs_cis.view(*freqs_cis.shape[:2], 1, freqs_cis.shape[-1])
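
As a sanity check on the 2D rotary frequencies above, here is a minimal shape probe; the dim/height/width values are illustrative only and are not taken from the Pixtral config:

import torch  # already imported in app.py; repeated so the snippet stands alone

# For dim=64 and a 4x6 patch grid, freqs_h is (4, 16) and freqs_w is (6, 16),
# so the concatenated grid is (4, 6, 32) and torch.polar returns complex values.
freqs_cis = precompute_freqs_cis_2d(dim=64, height=4, width=6, theta=10000.0)
print(freqs_cis.shape, freqs_cis.dtype)  # torch.Size([4, 6, 32]) torch.complex64
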
@@ -78,7 +63,6 @@ class Attention(nn.Module):
         super().__init__()
         self.n_heads = args['num_attention_heads']
         self.head_dim = args['hidden_size'] // args['num_attention_heads']
-
         self.wq = nn.Linear(args['hidden_size'], args['hidden_size'], bias=False)
         self.wk = nn.Linear(args['hidden_size'], args['hidden_size'], bias=False)
         self.wv = nn.Linear(args['hidden_size'], args['hidden_size'], bias=False)
@@ -86,14 +70,11 @@

     def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
         batch, patches, _ = x.shape
-
         q, k, v = self.wq(x), self.wk(x), self.wv(x)
         q = q.reshape(batch, patches, self.n_heads, self.head_dim)
         k = k.reshape(batch, patches, self.n_heads, self.head_dim)
         v = v.reshape(batch, patches, self.n_heads, self.head_dim)
-
         q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
-
         scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
         attn = F.softmax(scores, dim=-1)
         out = torch.matmul(attn, v)
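
For comparison, the usual multi-head attention layout moves the head axis in front of the patch axis before the matmul, so the softmax normalizes over patches; a minimal sketch of that conventional pattern follows (illustrative only, not code from this commit):

import math
import torch

def reference_mha(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q, k, v: (batch, patches, n_heads, head_dim), as produced by the reshapes above.
    batch, patches, n_heads, head_dim = q.shape
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))         # (batch, n_heads, patches, head_dim)
    scores = q @ k.transpose(-1, -2) / math.sqrt(head_dim)   # (batch, n_heads, patches, patches)
    attn = scores.softmax(dim=-1)                            # attention weights over patches
    out = attn @ v                                            # (batch, n_heads, patches, head_dim)
    return out.transpose(1, 2).reshape(batch, patches, n_heads * head_dim)
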
@@ -119,9 +100,9 @@ class TransformerBlock(nn.Module):
         self.ffn_norm = RMSNorm(args['hidden_size'], eps=1e-5)

     def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-        r = self.attention
+        r = self.attention(self.attention_norm(x), freqs_cis=freqs_cis)
         h = x + r
-        r = self.feed_forward
+        r = self.feed_forward(self.ffn_norm(h))
         out = h + r
         return out

@@ -129,16 +110,9 @@ class VisionTransformer(nn.Module):
     def __init__(self, args):
         super().__init__()
         self.args = args
-        self.patch_conv = nn.Conv2d(
-            in_channels=args['num_channels'],
-            out_channels=args['hidden_size'],
-            kernel_size=args['patch_size'],
-            stride=args['patch_size'],
-            bias=False,
-        )
+        self.patch_conv = nn.Conv2d(args['num_channels'], args['hidden_size'], kernel_size=args['patch_size'], stride=args['patch_size'], bias=False)
         self.ln_pre = RMSNorm(args['hidden_size'], eps=1e-5)
         self.transformer = nn.ModuleList([TransformerBlock(args) for _ in range(args['num_hidden_layers'])])
-
         self.max_patches_per_side = args['image_size'] // args['patch_size']
         self._freqs_cis = None

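
As a small worked example of the patch grid this encoder operates on (the numbers below are plausible Pixtral-style values chosen for illustration, not read from params.json here):

# With a 1024-pixel image and 16-pixel patches, the conv produces a 64x64 grid:
image_size, patch_size = 1024, 16
max_patches_per_side = image_size // patch_size   # 64
num_patches = max_patches_per_side ** 2           # 4096 patch tokens at full resolution
print(max_patches_per_side, num_patches)
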
@@ -157,11 +131,9 @@ class VisionTransformer(nn.Module):
         x = self.patch_conv(x)
         x = x.flatten(2).transpose(1, 2)
         x = self.ln_pre(x)
-
         freqs_cis = self.freqs_cis
         for layer in self.transformer:
             x = layer(x, freqs_cis=freqs_cis)
-
         return x

 class VisionLanguageAdapter(nn.Module):
@@ -180,9 +152,7 @@ class PixtralModel(nn.Module):
         self.vision_encoder = VisionTransformer(params['vision_encoder'])
         self.vision_language_adapter = VisionLanguageAdapter(params['vision_encoder'], params['dim'])
         self.language_model = nn.TransformerDecoder(
-            nn.TransformerDecoderLayer(d_model=params['dim'],
-                                       nhead=params['n_heads'],
-                                       dim_feedforward=params['hidden_dim']),
+            nn.TransformerDecoderLayer(d_model=params['dim'], nhead=params['n_heads'], dim_feedforward=params['hidden_dim']),
             num_layers=params['n_layers']
         )
         self.lm_head = nn.Linear(params['dim'], params['vocab_size'], bias=False)
@@ -201,12 +171,10 @@ class PixtralModel(nn.Module):

 def load_model(params, model_path):
     model = PixtralModel(params)
-
     with safe_open(f'{model_path}/consolidated.safetensors', framework="pt", device="cpu") as f:
         for name, param in model.named_parameters():
             if name in f.keys():
                 param.data = f.get_tensor(name)
-
     model.eval()
     return model

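
For context, a minimal sketch of how this loader is used; `params` and `model_path` come from the top of the file, while the tokenizer lines are an assumption about the unchanged remainder of app.py and are not shown in this diff:

model = load_model(params, model_path)

# Assumed, not shown in the diff: generate_text() below expects a module-level
# `tokenizer` with encode_chat_completion/decode, e.g. something along the lines of
# from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
# tokenizer = MistralTokenizer.from_file(f"{model_path}/tekken.json")
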
@@ -224,53 +192,45 @@ def preprocess_image(image):
 @spaces.GPU(duration=120)
 def generate_text(image, prompt, max_tokens):
     try:
-
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        image_tensor = preprocess_image(image).to(device)
+        model.to(device)

         tokenized = tokenizer.encode_chat_completion(
             ChatCompletionRequest(
-                messages=[
-                    UserMessage(
-                        content=[
-                            TextChunk(text=prompt),
-                            ImageChunk(image=image),
-                        ]
-                    )
-                ],
+                messages=[UserMessage(content=[TextChunk(text=prompt), ImageChunk(image=image)])],
                 model="pixtral",
             )
         )
-        input_ids = torch.tensor(tokenized.tokens).unsqueeze(0).
-
-        model
-
-            if next_token.item() == tokenizer.eos_token_id:
-                break
-        model.cpu()
+        input_ids = torch.tensor(tokenized.tokens).unsqueeze(0).to(device)
+
+        for _ in range(max_tokens):
+            logits = model(image_tensor, input_ids)
+            next_token_logits = logits[0, -1, :]
+            next_token = torch.argmax(next_token_logits, dim=-1)
+            input_ids = torch.cat([input_ids, next_token.unsqueeze(0).unsqueeze(0)], dim=-1)
+            if next_token.item() == tokenizer.eos_token_id:
+                break

         generated_text = tokenizer.decode(input_ids[0].tolist())
-
+        # model.to("cpu")
+        return generated_text, len(input_ids[0]), 1
     except Exception as e:
         return f"Error: {str(e)}", 0, 0

 @spaces.GPU(duration=60)
 def calculate_similarity(image1, image2):
     try:
-
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        tensor1 = preprocess_image(image1).to(device)
+        tensor2 = preprocess_image(image2).to(device)
+        model.to(device)

-        embedding1 = model(tensor1).mean(dim=1)  # Average over spatial dimensions
-        embedding2 = model(tensor2).mean(dim=1)
-        model.cpu()
+        embedding1 = model(tensor1).mean(dim=1)
+        embedding2 = model(tensor2).mean(dim=1)

         similarity = F.cosine_similarity(embedding1, embedding2).item()
-
+        # model.to("cpu")
         return similarity
     except Exception as e:
         return f"Error: {str(e)}"
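
A minimal smoke test of the two callbacks outside Gradio, assuming the module-level `model`, `tokenizer`, and `preprocess_image` are set up as above; the blank test image is made up for illustration:

from PIL import Image

test_image = Image.new("RGB", (256, 256), color="white")

text, n_tokens, _ = generate_text(test_image, "Describe this image.", max_tokens=32)
print(n_tokens, text[:80])

# Identical inputs should yield a cosine similarity close to 1.0.
print(calculate_similarity(test_image, test_image))
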
@@ -299,7 +259,7 @@ with gr.Blocks() as demo:
         with gr.Column():
             input_image = gr.Image(type="pil", label="Input Image")
             input_prompt = gr.Textbox(label="Prompt")
-            max_tokens_slider = gr.Slider(minimum=
+            max_tokens_slider = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Max Tokens")
            submit_btn = gr.Button("Generate Text")

         with gr.Column():
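
The diff ends before the event wiring; a sketch of how the button is typically connected follows. This is an assumption about the unchanged remainder of app.py, and the output component names here are invented:

# Inside the same `with gr.Blocks() as demo:` block, after the columns above:
output_text = gr.Textbox(label="Generated Text")
output_tokens = gr.Number(label="Tokens")
output_images = gr.Number(label="Images Processed")

submit_btn.click(
    fn=generate_text,
    inputs=[input_image, input_prompt, max_tokens_slider],
    outputs=[output_text, output_tokens, output_images],
)
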