File size: 9,646 Bytes
81dbae6
8de60e6
 
 
 
 
 
 
 
81dbae6
 
 
8de60e6
 
 
 
 
 
 
 
81dbae6
8de60e6
 
81dbae6
 
8de60e6
 
81dbae6
8de60e6
 
81dbae6
 
 
 
 
 
8de60e6
 
81dbae6
 
8de60e6
81dbae6
 
8de60e6
 
81dbae6
 
 
8de60e6
 
81dbae6
 
8de60e6
81dbae6
 
 
8de60e6
 
 
81dbae6
 
8de60e6
 
 
81dbae6
8de60e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81dbae6
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import os
import subprocess

import gradio as gr
import numpy as np
import onnxruntime as ort
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer

# Hugging Face repository that provides the tokenizer/processor and the ONNX exports.
MODEL_ID = "llava-hf/llava-interleave-qwen-0.5b-hf"
_ONNX_BASE_URL = f"https://huggingface.co/{MODEL_ID}/resolve/main/onnx/"
_ONNX_FILES = (
    "decoder_model_merged_q4f16.onnx",
    "embed_tokens_q4f16.onnx",
    "vision_encoder_q4f16.onnx",
)


def _ensure_onnx_files(base_url=_ONNX_BASE_URL, filenames=_ONNX_FILES):
    """Download each ONNX file into the working directory unless it is already present.

    Each file is fetched into ``<name>.part`` and renamed only after wget
    succeeds. An interrupted download therefore never leaves a truncated file
    that a later start would mistake for a complete one.

    Raises:
        subprocess.CalledProcessError: if wget exits non-zero. The failure then
            surfaces here, instead of as a confusing model-load error below.
    """
    for filename in filenames:
        if os.path.exists(filename):
            # Re-running wget here would re-download and write a `<name>.1` copy.
            continue
        partial = filename + ".part"
        subprocess.run(["wget", "-O", partial, base_url + filename], check=True)
        os.replace(partial, filename)


_ensure_onnx_files()

# Load the tokenizer and processor
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
processor = AutoProcessor.from_pretrained(MODEL_ID)

# One ONNX Runtime session per exported sub-model.
vision_encoder_session = ort.InferenceSession("vision_encoder_q4f16.onnx")
decoder_session = ort.InferenceSession("decoder_model_merged_q4f16.onnx")
embed_tokens_session = ort.InferenceSession("embed_tokens_q4f16.onnx")

def merge_input_ids_with_image_features(image_features, inputs_embeds, input_ids, attention_mask, pad_token_id, special_image_token_id):
    """Splice image patch embeddings into the text embedding sequence.

    NumPy version of the LLaVA merge step. Every ``special_image_token_id``
    in ``input_ids`` is expanded into ``num_image_patches`` rows taken from
    ``image_features``, and the surrounding text tokens shift right to make room.

    Args:
        image_features: (num_images, num_image_patches, embed_dim) array. Images
            are consumed in row-major order, one per placeholder token.
        inputs_embeds: (batch, seq_len, embed_dim) text token embeddings.
        input_ids: (batch, seq_len) token ids.
        attention_mask: (batch, seq_len) mask aligned with ``input_ids``.
        pad_token_id: id used for padding (may be None, meaning no padding).
        special_image_token_id: id of the image placeholder token.

    Returns:
        A tuple ``(final_embedding, final_attention_mask, position_ids)``.
        Their sequence length is ``seq_len + max_placeholders * (num_image_patches - 1)``.
        ``final_embedding`` keeps the dtype of ``inputs_embeds``. The mask and
        position ids are int64.

    Raises:
        ValueError: if the number of image slots implied by the placeholder
            tokens does not equal ``num_images * num_image_patches``.
    """
    num_images, num_image_patches, embed_dim = image_features.shape
    batch_size, sequence_length = input_ids.shape
    # Treat the batch as left-padded unless some row ends in a pad token.
    left_padding = not np.sum(input_ids[:, -1] == pad_token_id)

    # 1. Locate the image placeholder tokens.
    special_image_token_mask = input_ids == special_image_token_id
    num_special_image_tokens = np.sum(special_image_token_mask, axis=-1)
    # Each placeholder grows from one slot to num_image_patches slots.
    max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
    batch_indices, non_image_indices = np.where(input_ids != special_image_token_id)

    # 2. Target position of every original token in the merged sequence.
    # A placeholder advances the running position by num_image_patches, a text
    # token by 1. The trailing -1 converts the cumulative count to a 0-based index.
    new_token_positions = np.cumsum(special_image_token_mask * (num_image_patches - 1) + 1, axis=-1) - 1
    # Rows with fewer placeholders than the longest row need this much extra padding.
    nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
    if left_padding:
        new_token_positions += nb_image_pad[:, None]  # shift text right to keep left padding
    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

    # 3. Output buffers, already sized to the merged length. Embeddings keep the
    #    dtype of inputs_embeds so the first decoder call matches later ones.
    final_embedding = np.zeros((batch_size, max_embed_dim, embed_dim), dtype=inputs_embeds.dtype)
    final_attention_mask = np.zeros((batch_size, max_embed_dim), dtype=np.int64)

    # 4. Copy text embeddings and their mask values into their shifted positions.
    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]

    # 5. Every remaining slot past the per-row padding receives image features (#29835).
    image_to_overwrite = np.full((batch_size, max_embed_dim), True)
    image_to_overwrite[batch_indices, text_to_overwrite] = False
    image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None]

    num_image_slots = int(image_to_overwrite.sum())
    if num_image_slots != num_images * num_image_patches:
        raise ValueError(
            f"Image features and image tokens do not match: {num_image_slots} image slots "
            f"from placeholder tokens but {num_images * num_image_patches} feature rows "
            f"({num_images} images x {num_image_patches} patches)."
        )

    final_embedding[image_to_overwrite] = image_features.reshape(-1, embed_dim)
    final_attention_mask = np.logical_or(final_attention_mask, image_to_overwrite).astype(final_attention_mask.dtype)
    # Positions count attended tokens only; masked slots get a dummy position of 1.
    position_ids = final_attention_mask.cumsum(axis=-1) - 1
    position_ids = np.where(final_attention_mask == 0, 1, position_ids)

    # 6. Zero the embeddings at pad-token positions.
    batch_indices, pad_indices = np.where(input_ids == pad_token_id)
    indices_to_mask = new_token_positions[batch_indices, pad_indices]
    final_embedding[batch_indices, indices_to_mask] = 0

    return final_embedding, final_attention_mask, position_ids

# Inference: prompt + image -> vision encoder, token embedder and greedy decoder loop

def _empty_past_key_values(num_layers, num_heads, head_dim, batch_size=1):
    """Build zero-length float32 key/value cache inputs for every decoder layer."""
    # Keys for all layers first, then values, matching the decoder input names.
    return {
        f"past_key_values.{layer}.{kind}": np.zeros((batch_size, num_heads, 0, head_dim), dtype=np.float32)
        for kind in ("key", "value")
        for layer in range(num_layers)
    }


def _run_decoder(decoder_inputs, input_names, output_names):
    """Run the decoder session, feeding only the inputs its graph declares."""
    feed = {name: decoder_inputs[name] for name in input_names if name in decoder_inputs}
    return decoder_session.run(output_names, feed)


def describe_image(image):
    """Generate a detailed description of ``image`` with the ONNX LLaVA pipeline.

    The image goes through the vision encoder and the prompt through the token
    embedder. The two are merged and then decoded greedily, reusing the
    decoder's key/value cache between steps.

    Args:
        image: PIL image from the Gradio input, or None if nothing was uploaded.

    Returns:
        The decoded description, with special tokens stripped.

    Raises:
        gr.Error: if no image was provided. Gradio shows this message to the user.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if image.mode != 'RGB':
        image = image.convert('RGB')

    conversation = [
        {
            "role": "system",
            "content": "You are a helpful assistant who describes image."
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image in about 200 words and explain each and every element in full detail"},
                {"type": "image"},
            ],
        },
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = processor(images=image, text=prompt, return_tensors="np")

    # Image -> patch features.
    vision_input_name = vision_encoder_session.get_inputs()[0].name
    vision_output_name = vision_encoder_session.get_outputs()[0].name
    vision_features = vision_encoder_session.run([vision_output_name], {vision_input_name: inputs["pixel_values"]})[0]

    input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]

    # Decoder geometry is hard-coded for this checkpoint.
    # NOTE(review): re-check against the model config if the ONNX files change.
    batch_size = 1
    num_layers = 24
    num_heads = 16
    head_dim = 64
    pad_token_id = tokenizer.pad_token_id
    special_image_token_id = 151646  # placeholder id expanded by merge_input_ids_with_image_features

    past_key_values = _empty_past_key_values(num_layers, num_heads, head_dim, batch_size)

    # Prompt tokens -> embeddings, then splice in the image features.
    embed_input_name = embed_tokens_session.get_inputs()[0].name
    embed_output_name = embed_tokens_session.get_outputs()[0].name
    token_embeddings = embed_tokens_session.run([embed_output_name], {embed_input_name: input_ids})[0]

    combined_embeddings, attention_mask, position_ids = merge_input_ids_with_image_features(
        vision_features, token_embeddings, input_ids, attention_mask, pad_token_id, special_image_token_id
    )

    decoder_inputs = {
        "attention_mask": attention_mask,
        "position_ids": position_ids,
        "inputs_embeds": combined_embeddings,
        **past_key_values,
    }
    for name, value in decoder_inputs.items():
        print(f"{name} shape: {value.shape} dtype {value.dtype}")

    # Prefill pass over the whole merged prompt.
    decoder_input_names = [node.name for node in decoder_session.get_inputs()]
    output_names = [node.name for node in decoder_session.get_outputs()]
    outputs = _run_decoder(decoder_inputs, decoder_input_names, output_names)

    # Greedy decoding: outputs[0] is read as logits, outputs[1:] as the present.* cache.
    generated_tokens = []
    eos_token_id = tokenizer.eos_token_id
    max_new_tokens = 2048

    for step in range(max_new_tokens):
        token_id = int(np.argmax(outputs[0][:, -1]))
        if token_id == eos_token_id:
            break
        generated_tokens.append(token_id)
        if step == max_new_tokens - 1:
            break  # budget spent; another decoder pass would only be discarded

        # Explicit int64 so the token-id dtype does not depend on the platform default.
        new_input_embeds = embed_tokens_session.run(
            [embed_output_name], {embed_input_name: np.array([[token_id]], dtype=np.int64)}
        )[0]

        # This step's present.* outputs become the next step's past_key_values.* inputs.
        past_key_values = {
            name.replace("present", "past_key_values"): value
            for name, value in zip(output_names[1:], outputs[1:])
        }
        # Extend the merged mask instead of rebuilding it, so masked prompt slots stay masked.
        attention_mask = np.concatenate(
            [attention_mask, np.ones((batch_size, 1), dtype=attention_mask.dtype)], axis=1
        )
        # Position of the new token, counted the same way as the merged position_ids.
        position_ids = attention_mask.sum(axis=-1, keepdims=True).astype(np.int64) - 1

        decoder_inputs = {
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "inputs_embeds": new_input_embeds,
            **past_key_values,
        }
        outputs = _run_decoder(decoder_inputs, decoder_input_names, output_names)

    print(f"Generated token IDs: {generated_tokens}")
    decoded_tokens = [tokenizer.decode([token]) for token in generated_tokens]
    print(f"Decoded tokens: {decoded_tokens}")

    return tokenizer.decode(generated_tokens, skip_special_tokens=True)

# Create Gradio interface
# Web UI: a single PIL image goes in, and the generated description comes out as text.
interface = gr.Interface(
    fn=describe_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(
        lines=5,
        placeholder="Description will appear here",
    ),
    title="Image Description Generator",
    description="Upload an image to get a detailed description.",
)

# Start serving. These are gradio launch options: a public share link,
# errors surfaced in the UI, and debug mode.
interface.launch(
    debug=True,
    show_error=True,
    share=True,
)