Tonic committed
Commit 24b8c6e
1 Parent(s): 2bdacd4

add reference code from vllm

Files changed (2)
  1. .gitignore +1 -0
  2. app.py +34 -74
.gitignore ADDED
@@ -0,0 +1 @@
+notes.py
app.py CHANGED
@@ -14,7 +14,7 @@ import spaces
 import math
 from typing import List, Optional, Tuple
 
-title = "# 🙋🏻‍♂️Welcome to Tonic's Pixtral Model Demo"
+title = "# **WIP / DEMO** 🙋🏻‍♂️Welcome to Tonic's Pixtral Model Demo"
 description = """
 This demo showcases two capabilities of the Pixtral model:
 1. Image-to-Text Generation
@@ -25,6 +25,7 @@ This demo showcases two capabilities of the Pixtral model:
 """
 
 model_path = snapshot_download(repo_id="mistralai/Pixtral-12B-2409")
+
 with open(f'{model_path}/params.json', 'r') as f:
     params = json.load(f)
 
@@ -40,32 +41,16 @@ class RMSNorm(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
 
-def precompute_freqs_cis_2d(
-    dim: int,
-    height: int,
-    width: int,
-    theta: float,
-) -> torch.Tensor:
+def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float) -> torch.Tensor:
     freqs = 1.0 / (theta**(torch.arange(0, dim, 2).float() / dim))
-    h = torch.arange(height, device=freqs.device)
-    w = torch.arange(width, device=freqs.device)
-
+    h = torch.arange(height)
+    w = torch.arange(width)
     freqs_h = torch.outer(h, freqs[::2]).float()
     freqs_w = torch.outer(w, freqs[1::2]).float()
-    freqs_2d = torch.cat(
-        [
-            freqs_h[:, None, :].repeat(1, width, 1),
-            freqs_w[None, :, :].repeat(height, 1, 1),
-        ],
-        dim=-1,
-    )
+    freqs_2d = torch.cat([freqs_h[:, None, :].repeat(1, width, 1), freqs_w[None, :, :].repeat(height, 1, 1)], dim=-1)
     return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
 
-def apply_rotary_emb_vit(
-    xq: torch.Tensor,
-    xk: torch.Tensor,
-    freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+def apply_rotary_emb_vit(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
     freqs_cis = freqs_cis.view(*freqs_cis.shape[:2], 1, freqs_cis.shape[-1])
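A minimal, self-contained shape check for the collapsed `precompute_freqs_cis_2d` above. The values below (`dim=64`, a 4×6 patch grid, `theta=10000.0`) are illustrative, not the demo's configuration; the function body is repeated so the snippet runs on its own:

```python
import torch

def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float) -> torch.Tensor:
    # Half of the frequency bands rotate with the patch row, the other half with the column.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    h = torch.arange(height)
    w = torch.arange(width)
    freqs_h = torch.outer(h, freqs[::2]).float()
    freqs_w = torch.outer(w, freqs[1::2]).float()
    freqs_2d = torch.cat(
        [freqs_h[:, None, :].repeat(1, width, 1), freqs_w[None, :, :].repeat(height, 1, 1)],
        dim=-1,
    )
    # Unit-magnitude complex numbers: one rotation per (row, column, frequency pair).
    return torch.polar(torch.ones_like(freqs_2d), freqs_2d)

freqs_cis = precompute_freqs_cis_2d(dim=64, height=4, width=6, theta=10000.0)
print(freqs_cis.shape)  # torch.Size([4, 6, 32]) -> (height, width, dim // 2)
print(freqs_cis.dtype)  # torch.complex64
```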
@@ -78,7 +63,6 @@ class Attention(nn.Module):
         super().__init__()
         self.n_heads = args['num_attention_heads']
         self.head_dim = args['hidden_size'] // args['num_attention_heads']
-
         self.wq = nn.Linear(args['hidden_size'], args['hidden_size'], bias=False)
         self.wk = nn.Linear(args['hidden_size'], args['hidden_size'], bias=False)
         self.wv = nn.Linear(args['hidden_size'], args['hidden_size'], bias=False)
@@ -86,14 +70,11 @@ class Attention(nn.Module):
 
     def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
         batch, patches, _ = x.shape
-
         q, k, v = self.wq(x), self.wk(x), self.wv(x)
         q = q.reshape(batch, patches, self.n_heads, self.head_dim)
         k = k.reshape(batch, patches, self.n_heads, self.head_dim)
         v = v.reshape(batch, patches, self.n_heads, self.head_dim)
-
         q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
-
         scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
         attn = F.softmax(scores, dim=-1)
         out = torch.matmul(attn, v)
@@ -119,9 +100,9 @@ class TransformerBlock(nn.Module):
         self.ffn_norm = RMSNorm(args['hidden_size'], eps=1e-5)
 
     def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-        r = self.attention.forward(self.attention_norm(x), freqs_cis=freqs_cis)
+        r = self.attention(self.attention_norm(x), freqs_cis=freqs_cis)
         h = x + r
-        r = self.feed_forward.forward(self.ffn_norm(h))
+        r = self.feed_forward(self.ffn_norm(h))
         out = h + r
         return out
 
@@ -129,16 +110,9 @@ class VisionTransformer(nn.Module):
     def __init__(self, args):
         super().__init__()
         self.args = args
-        self.patch_conv = nn.Conv2d(
-            in_channels=args['num_channels'],
-            out_channels=args['hidden_size'],
-            kernel_size=args['patch_size'],
-            stride=args['patch_size'],
-            bias=False,
-        )
+        self.patch_conv = nn.Conv2d(args['num_channels'], args['hidden_size'], kernel_size=args['patch_size'], stride=args['patch_size'], bias=False)
         self.ln_pre = RMSNorm(args['hidden_size'], eps=1e-5)
         self.transformer = nn.ModuleList([TransformerBlock(args) for _ in range(args['num_hidden_layers'])])
-
         self.max_patches_per_side = args['image_size'] // args['patch_size']
         self._freqs_cis = None
 
@@ -157,11 +131,9 @@ class VisionTransformer(nn.Module):
         x = self.patch_conv(x)
         x = x.flatten(2).transpose(1, 2)
         x = self.ln_pre(x)
-
         freqs_cis = self.freqs_cis
         for layer in self.transformer:
             x = layer(x, freqs_cis=freqs_cis)
-
         return x
 
 class VisionLanguageAdapter(nn.Module):
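The patch embedding path is a single `nn.Conv2d` followed by `flatten(2).transpose(1, 2)`. A small sketch of the shapes involved (the 3-channel 64×64 input, 16-pixel patches, and hidden size 32 below are example values, not the demo's config):

```python
import torch
import torch.nn as nn

patch_size, hidden_size = 16, 32  # example values only
patch_conv = nn.Conv2d(3, hidden_size, kernel_size=patch_size, stride=patch_size, bias=False)

image = torch.randn(1, 3, 64, 64)   # (batch, channels, H, W)
x = patch_conv(image)               # (1, 32, 4, 4): one feature column per 16x16 patch
x = x.flatten(2).transpose(1, 2)    # (1, 16, 32): (batch, num_patches, hidden_size)
print(x.shape)
```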
@@ -180,9 +152,7 @@ class PixtralModel(nn.Module):
         self.vision_encoder = VisionTransformer(params['vision_encoder'])
         self.vision_language_adapter = VisionLanguageAdapter(params['vision_encoder'], params['dim'])
         self.language_model = nn.TransformerDecoder(
-            nn.TransformerDecoderLayer(d_model=params['dim'],
-                                       nhead=params['n_heads'],
-                                       dim_feedforward=params['hidden_dim']),
+            nn.TransformerDecoderLayer(d_model=params['dim'], nhead=params['n_heads'], dim_feedforward=params['hidden_dim']),
             num_layers=params['n_layers']
         )
         self.lm_head = nn.Linear(params['dim'], params['vocab_size'], bias=False)
@@ -201,12 +171,10 @@ class PixtralModel(nn.Module):
 
 def load_model(params, model_path):
     model = PixtralModel(params)
-
     with safe_open(f'{model_path}/consolidated.safetensors', framework="pt", device="cpu") as f:
         for name, param in model.named_parameters():
             if name in f.keys():
                 param.data = f.get_tensor(name)
-
     model.eval()
     return model
 
@@ -224,53 +192,45 @@ def preprocess_image(image):
 @spaces.GPU(duration=120)
 def generate_text(image, prompt, max_tokens):
     try:
-        image_tensor = preprocess_image(image).cuda()
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        image_tensor = preprocess_image(image).to(device)
+        model.to(device)
 
         tokenized = tokenizer.encode_chat_completion(
             ChatCompletionRequest(
-                messages=[
-                    UserMessage(
-                        content=[
-                            TextChunk(text=prompt),
-                            ImageChunk(image=image),
-                        ]
-                    )
-                ],
+                messages=[UserMessage(content=[TextChunk(text=prompt), ImageChunk(image=image)])],
                 model="pixtral",
             )
         )
-        input_ids = torch.tensor(tokenized.tokens).unsqueeze(0).cuda()
-
-        with torch.no_grad():
-            model.cuda()
-            for _ in range(max_tokens):
-                logits = model(image_tensor, input_ids)
-                next_token_logits = logits[0, -1, :]
-                next_token = torch.argmax(next_token_logits, dim=-1)
-                input_ids = torch.cat([input_ids, next_token.unsqueeze(0).unsqueeze(0)], dim=-1)
-                if next_token.item() == tokenizer.eos_token_id:
-                    break
-            model.cpu()
+        input_ids = torch.tensor(tokenized.tokens).unsqueeze(0).to(device)
+
+        for _ in range(max_tokens):
+            logits = model(image_tensor, input_ids)
+            next_token_logits = logits[0, -1, :]
+            next_token = torch.argmax(next_token_logits, dim=-1)
+            input_ids = torch.cat([input_ids, next_token.unsqueeze(0).unsqueeze(0)], dim=-1)
+            if next_token.item() == tokenizer.eos_token_id:
+                break
 
         generated_text = tokenizer.decode(input_ids[0].tolist())
-        return generated_text, len(input_ids[0]), 1 # 1 image processed
+        # model.to("cpu")
+        return generated_text, len(input_ids[0]), 1
     except Exception as e:
         return f"Error: {str(e)}", 0, 0
 
 @spaces.GPU(duration=60)
 def calculate_similarity(image1, image2):
     try:
-        tensor1 = preprocess_image(image1).cuda()
-        tensor2 = preprocess_image(image2).cuda()
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        tensor1 = preprocess_image(image1).to(device)
+        tensor2 = preprocess_image(image2).to(device)
+        model.to(device)
 
-        with torch.no_grad():
-            model.cuda()
-            embedding1 = model(tensor1).mean(dim=1) # Average over spatial dimensions
-            embedding2 = model(tensor2).mean(dim=1)
-            model.cpu()
+        embedding1 = model(tensor1).mean(dim=1)
+        embedding2 = model(tensor2).mean(dim=1)
 
         similarity = F.cosine_similarity(embedding1, embedding2).item()
-
+        # model.to("cpu")
         return similarity
     except Exception as e:
         return f"Error: {str(e)}"
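The rewritten `generate_text` resolves a device up front and then runs a plain greedy decoding loop. A stripped-down sketch of that loop with a stand-in model (`TinyLM`, the vocabulary size, and the EOS id are invented for illustration; the sketch keeps `torch.no_grad()`, which the new code drops):

```python
import torch
import torch.nn as nn

class TinyLM(nn.Module):
    """Stand-in for PixtralModel: maps token ids to per-position logits."""
    def __init__(self, vocab_size: int = 32):
        super().__init__()
        self.proj = nn.Embedding(vocab_size, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.proj(input_ids)  # (batch, seq_len, vocab_size)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TinyLM().to(device).eval()
eos_token_id = 2                                        # invented EOS id
input_ids = torch.tensor([[5, 7, 11]], device=device)   # stand-in for the tokenized prompt

with torch.no_grad():
    for _ in range(20):                                  # max_tokens
        logits = model(input_ids)
        next_token = torch.argmax(logits[0, -1, :], dim=-1)            # greedy pick
        input_ids = torch.cat([input_ids, next_token.view(1, 1)], dim=-1)
        if next_token.item() == eos_token_id:
            break

print(input_ids.shape)  # (1, prompt_len + generated_len)
```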
@@ -299,7 +259,7 @@ with gr.Blocks() as demo:
         with gr.Column():
             input_image = gr.Image(type="pil", label="Input Image")
             input_prompt = gr.Textbox(label="Prompt")
-            max_tokens_slider = gr.Slider(minimum=60, maximum=1600, value=90, step=5, label="Max Tokens")
+            max_tokens_slider = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Max Tokens")
             submit_btn = gr.Button("Generate Text")
 
         with gr.Column():
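For the similarity path, the new code mean-pools the encoder's patch embeddings and compares them with cosine similarity. A self-contained sketch of that reduction (random tensors and the 196×1024 shape stand in for real encoder outputs):

```python
import torch
import torch.nn.functional as F

# Stand-ins for model(tensor1) / model(tensor2): (batch, num_patches, hidden_size)
patches1 = torch.randn(1, 196, 1024)
patches2 = torch.randn(1, 196, 1024)

embedding1 = patches1.mean(dim=1)  # mean over the patch axis -> (1, 1024)
embedding2 = patches2.mean(dim=1)

similarity = F.cosine_similarity(embedding1, embedding2).item()
print(f"cosine similarity: {similarity:+.4f}")  # value in [-1, 1]
```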
 