ThomasSimonini (HF staff) committed
Commit abff794
1 Parent(s): ab91894

Upload moondream.py

Files changed (1):
  1. moondream.py +3 -184

moondream.py CHANGED
@@ -1,185 +1,3 @@
-"""
-import torch
-from .vision_encoder import VisionEncoder
-from .configuration_moondream import MoondreamConfig
-from transformers import PreTrainedModel, TextIteratorStreamer
-
-from .modeling_phi import PhiForCausalLM
-from .configuration_moondream import PhiConfig
-
-class Moondream(PreTrainedModel):
-    config_class = MoondreamConfig
-    _supports_flash_attn_2 = True
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.vision_encoder = VisionEncoder(
-            use_flash_attn=config._attn_implementation == "flash_attention_2"
-        )
-
-        if type(config.text_config) == dict:
-            phi_config = PhiConfig(
-                **config.text_config, attn_implementation=config._attn_implementation
-            )
-        else:
-            phi_config = config.text_config
-        self.text_model = PhiForCausalLM(phi_config)
-
-    @property
-    def device(self):
-        return self.text_model.device
-
-    def encode_image(self, image):
-        with torch.no_grad():
-            return self.vision_encoder(image)
-
-    def input_embeds(self, prompt, image_embeds, tokenizer):
-        def _tokenize(txt):
-            return tokenizer(
-                txt, return_tensors="pt", add_special_tokens=False
-            ).input_ids.to(self.device)
-
-        text_emb = self.text_model.get_input_embeddings()
-
-        # Add BOS token
-        embeds = []
-        embeds.append(
-            text_emb((torch.tensor([[tokenizer.bos_token_id]], device=self.device)))
-        )
-
-        if "<image>" not in prompt:
-            embeds.append(text_emb(_tokenize(prompt)))
-        else:
-            assert prompt.count("<image>") == 1
-            before, after = prompt.split("<image>")
-            if len(before) > 0:
-                embeds.append(text_emb(_tokenize(before)))
-            embeds.append(image_embeds.to(self.device))
-            if len(after) > 0:
-                embeds.append(text_emb(_tokenize(after)))
-
-        return torch.cat(embeds, dim=1)
-
-    def get_input_embeddings(self):
-        return self.text_model.get_input_embeddings()
-
-    def generate(
-        self,
-        image_embeds,
-        prompt,
-        tokenizer,
-        max_new_tokens=128,
-        **kwargs,
-    ):
-        generate_config = {
-            "eos_token_id": tokenizer.eos_token_id,
-            "bos_token_id": tokenizer.bos_token_id,
-            "pad_token_id": tokenizer.bos_token_id,
-            "max_new_tokens": max_new_tokens,
-            **kwargs,
-        }
-
-        with torch.no_grad():
-            inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
-            streamer = TextIteratorStreamer(tokenizer)
-            output_ids = self.text_model.generate(
-                inputs_embeds=inputs_embeds, streamer=streamer, **generate_config
-            )
-            print("FINISHED")
-
-        return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-
-    def answer_question(
-        self,
-        image_embeds,
-        question,
-        tokenizer,
-        chat_history="",
-        result_queue=None,
-        **kwargs,
-    ):
-        prompt = f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"
-        answer = self.generate(
-            image_embeds,
-            prompt,
-            tokenizer=tokenizer,
-            max_new_tokens=512,
-            **kwargs,
-        )[0]
-        cleaned_answer = answer.strip()
-
-        # Use the result_queue to pass the result if it is provided
-        if result_queue:
-            result_queue.put(cleaned_answer)
-        else:
-            return cleaned_answer
-
-    def batch_answer(
-        self,
-        images,
-        prompts,
-        tokenizer,
-        **kwargs,
-    ):
-        image_embeds = self.encode_image(images)
-
-        templated_prompts = [
-            f"<image>\n\nQuestion: {prompt}\n\nAnswer:" for prompt in prompts
-        ]
-        prompt_embs = [
-            self.input_embeds(prompt, image_embed.unsqueeze(0), tokenizer)[0]
-            for prompt, image_embed in zip(templated_prompts, image_embeds)
-        ]
-
-        bos_emb = prompt_embs[0][0]
-        max_len = max([p.shape[0] for p in prompt_embs])
-
-        inputs_embeds = torch.cat(
-            [
-                torch.cat([bos_emb.repeat(max_len - p.shape[0], 1), p]).unsqueeze(0)
-                for p in prompt_embs
-            ],
-            dim=0,
-        )
-        attention_mask = torch.cat(
-            [
-                torch.cat(
-                    [
-                        torch.zeros(
-                            1,
-                            max_len - p.shape[0],
-                            device=self.device,
-                            dtype=torch.long,
-                        ),
-                        torch.ones(1, p.shape[0], device=self.device, dtype=torch.long),
-                    ],
-                    dim=1,
-                )
-                for p in prompt_embs
-            ],
-            dim=0,
-        )
-
-        generate_config = {
-            "eos_token_id": tokenizer.eos_token_id,
-            "bos_token_id": tokenizer.bos_token_id,
-            "pad_token_id": tokenizer.bos_token_id,
-            "max_new_tokens": 512,
-            **kwargs,
-        }
-
-        with torch.no_grad():
-            output_ids = self.text_model.generate(
-                inputs_embeds=inputs_embeds,
-                attention_mask=attention_mask,
-                **generate_config,
-            )
-
-        return [
-            x.strip()
-            for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-        ]
-"""
 import torch
 from .vision_encoder import VisionEncoder
 from .configuration_moondream import MoondreamConfig
@@ -189,7 +7,7 @@ from .modeling_phi import PhiForCausalLM
 from .configuration_moondream import PhiConfig
 
 from threading import Thread
-from transformers import TextIteratorStreamer
+
 
 
 class Moondream(PreTrainedModel):
@@ -266,7 +84,7 @@ class Moondream(PreTrainedModel):
 
         with torch.no_grad():
             inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
-            streamer = TextIteratorStreamer(tokenizer)
+            streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
             # Start generation in a separate thread
             thread = Thread(target=self.text_model.generate, kwargs={
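For context on the pattern this hunk adjusts: a TextIteratorStreamer pushes decoded text onto an internal queue while generate() runs, so generation has to happen in a background thread while the caller iterates the streamer; skip_prompt=True keeps the prompt tokens out of the stream, and skip_special_tokens=True is forwarded to the tokenizer's decode call. Below is a minimal self-contained sketch of the same pattern with a plain Hugging Face causal LM ("gpt2" and the prompt are placeholders, not part of this commit):

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Moondream is", return_tensors="pt")

# skip_prompt=True: do not echo the prompt tokens into the stream.
# skip_special_tokens=True: passed through to tokenizer.decode() per chunk.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until done, so it runs in a worker thread while the
# main thread drains the streamer's queue.
thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=32),
)
thread.start()

for chunk in streamer:
    print(chunk, end="", flush=True)

thread.join()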
@@ -278,6 +96,7 @@ class Moondream(PreTrainedModel):
 
             # Yield generated text as it becomes available
             for new_text in streamer:
+                print("NEW TEXT" + new_text)
                 yield new_text
 
             thread.join()
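The yield new_text above means the patched method is a generator: callers receive decoded chunks as the background thread produces them, rather than one final string. A hedged usage sketch, assuming the enclosing method is generate with the signature visible in the removed block above, and that model, tokenizer, and image_embeds are already set up (names are illustrative, not from the commit):

prompt = "<image>\n\nQuestion: What is in this image?\n\nAnswer:"

answer = ""
for new_text in model.generate(image_embeds, prompt, tokenizer, max_new_tokens=128):
    answer += new_text                    # accumulate the full reply
    print(new_text, end="", flush=True)   # render each chunk as it arrives
print()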