import streamlit as st
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from byaldi import RAGMultiModalModel
from PIL import Image
import time
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

# Download NLTK tokenizer data for BLEU score calculation
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)  # required by word_tokenize on newer NLTK releases

# Load models and processors
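# st.cache_resource keeps these heavyweight objects alive across Streamlit reruns,
# so the models are loaded once per process rather than on every interaction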
@st.cache_resource
def load_models():
    RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
    
    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",  # requires the flash-attn package; remove this line if it is not installed
        device_map="auto",
        trust_remote_code=True
    ).eval()  # device_map="auto" already places the model on GPU; a trailing .cuda() is redundant and can conflict with it
    
    # The processor must come from the same checkpoint as the model (this previously pointed at the 2B model)
    qwen_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)
    
    return RAG, qwen_model, qwen_processor

RAG, qwen_model, qwen_processor = load_models()

# Function to get current CUDA memory usage
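# Note: memory_allocated() counts only tensors PyTorch has allocated on the current
# device; it excludes the caching allocator's reserve and any non-PyTorch GPU usage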
def get_cuda_memory_usage():
    return torch.cuda.memory_allocated() / 1024**2  # Convert to MB

# Define processing functions
def extract_text_with_colpali(image):
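    # Assumption: byaldi's RAGMultiModalModel is documented as an index/search
    # (retrieval) interface; extract_text is assumed to exist here and is not part
    # of the documented byaldi API, so this call may need an OCR fallback instead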
    start_time = time.time()
    start_memory = get_cuda_memory_usage()
    
    extracted_text = RAG.extract_text(image)
    
    end_time = time.time()
    end_memory = get_cuda_memory_usage()
    
    return extracted_text, {
        'time': end_time - start_time,
        'memory': end_memory - start_memory
    }

def process_with_qwen(query, extracted_text, image, extract_mode=False):
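    # extract_mode=True asks Qwen to transcribe the image directly; otherwise the
    # ColPali-extracted text is passed as context alongside the image for the query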
    start_time = time.time()
    start_memory = get_cuda_memory_usage()
    
    if extract_mode:
        instruction = "Extract and list all text visible in this image, including both printed and handwritten text."
    else:
        instruction = f"Context: {extracted_text}\n\nQuery: {query}"
    
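    # Qwen2-VL expects a chat-style message list whose "content" is a sequence of
    # typed blocks (here one text block and one image block), not a flat string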
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": instruction
                },
                {
                    "type": "image",
                    "image": image,
                },
            ],
        }
    ]
    text = qwen_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
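    # process_vision_info (from qwen_vl_utils) pulls the image/video objects out of
    # the message list in the form the processor expects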
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = qwen_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")
    generated_ids = qwen_model.generate(**inputs, max_new_tokens=200)
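    # Slice off the prompt tokens so that only the newly generated text is decoded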
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = qwen_processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    
    end_time = time.time()
    end_memory = get_cuda_memory_usage()
    
    return output_text[0], {
        'time': end_time - start_time,
        'memory': end_memory - start_memory
    }

# Function to calculate BLEU score
def calculate_bleu(reference, hypothesis):
    reference_tokens = nltk.word_tokenize(reference.lower())
    hypothesis_tokens = nltk.word_tokenize(hypothesis.lower())
    # Smoothing keeps short or low-overlap outputs from collapsing to a score of 0.0
    return sentence_bleu([reference_tokens], hypothesis_tokens,
                         smoothing_function=SmoothingFunction().method1)

# Streamlit UI
st.title("Document Processing with ColPali and Qwen")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
query = st.text_input("Enter your query:")

if uploaded_file is not None and query:
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_container_width=True)

    if st.button("Process"):
        with st.spinner("Processing..."):
            # Extract text using ColPali
            colpali_extracted_text, colpali_metrics = extract_text_with_colpali(image)
            
            # Extract text using Qwen
            qwen_extracted_text, qwen_extract_metrics = process_with_qwen("", "", image, extract_mode=True)
            
            # Process the query with Qwen2, using both extracted text and image
            qwen_response, qwen_response_metrics = process_with_qwen(query, colpali_extracted_text, image)
            
            # Calculate BLEU score between ColPali and Qwen extractions
            bleu_score = calculate_bleu(colpali_extracted_text, qwen_extracted_text)

        # Display results
        st.subheader("Results")
        st.write("ColPali Extracted Text:")
        st.write(colpali_extracted_text)
        
        st.write("Qwen Extracted Text:")
        st.write(qwen_extracted_text)
        
        st.write("Qwen Response:")
        st.write(qwen_response)

        # Display metrics
        st.subheader("Metrics")
        
        st.write("ColPali Extraction:")
        st.write(f"Time: {colpali_metrics['time']:.2f} seconds")
        st.write(f"Memory: {colpali_metrics['memory']:.2f} MB")
        
        st.write("Qwen Extraction:")
        st.write(f"Time: {qwen_extract_metrics['time']:.2f} seconds")
        st.write(f"Memory: {qwen_extract_metrics['memory']:.2f} MB")
        
        st.write("Qwen Response:")
        st.write(f"Time: {qwen_response_metrics['time']:.2f} seconds")
        st.write(f"Memory: {qwen_response_metrics['memory']:.2f} MB")
        
        st.write(f"BLEU Score: {bleu_score:.4f}")

st.markdown("""
## How to Use

1. Upload an image containing text or a document.
2. Enter your query about the document.
3. Click 'Process' to see the results.

The app will display:
- Text extracted by ColPali
- Text extracted by Qwen
- Qwen's response to your query
- Performance metrics for each step
- BLEU score comparing ColPali and Qwen extractions
""")