Sana1207 committed
Commit e38a780
1 Parent(s): 5522918

Update app.py

Files changed (1):
  1. app.py +111 -92
app.py CHANGED
@@ -1,100 +1,119 @@
 
import streamlit as st
- import base64
- from byaldi import RAGMultiModalModel
- from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from PIL import Image
- from io import BytesIO
import torch
import re

@st.cache_resource
- def load_models():
-     RAG = RAGMultiModalModel.from_pretrained("vidore/colpali", verbose=10)
-     model = Qwen2VLForConditionalGeneration.from_pretrained(
-         "Qwen/Qwen2-VL-2B-Instruct",
-         torch_dtype=torch.float16,
-         device_map="auto",
-     )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
-     return RAG, model, processor
-
- RAG, model, processor = load_models()
-
- st.title("Multimodal Image Search and Text Extraction App")
-
- uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
-
- if uploaded_file is not None:
-     image = Image.open(uploaded_file)
-     st.image(image, caption='Uploaded Image', use_column_width=True)
-
-     temp_image_path = "uploaded_image.jpeg"
-     image.save(temp_image_path)
-
-     @st.cache_data
-     def create_rag_index(image_path):
-         RAG.index(
-             input_path=image_path,
-             index_name="image_index",
-             store_collection_with_index=True,
-             overwrite=True,
-         )
-
-     create_rag_index(temp_image_path)
-
-     text_query = st.text_input("Enter your text query")
-
-     if st.button("Search and Extract Text"):
-         if text_query:
-             results = RAG.search(text_query, k=1, return_base64_results=True)
-
-             image_data = base64.b64decode(results[0].base64)
-             image = Image.open(BytesIO(image_data))
-             st.image(image, caption="Result Image", use_column_width=True)
-
-             messages = [
-                 {
-                     "role": "user",
-                     "content": [
-                         {"type": "image"},
-                         {"type": "text", "text": "Run OCR on the image"}
-                     ]
-                 }
-             ]
-
-             text_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
-
-             inputs = processor(
-                 text=[text_prompt],
-                 images=[image],
-                 padding=True,
-                 return_tensors="pt"
-             )
-
-             inputs = inputs.to(model.device)
-
-             with torch.no_grad():
-                 output_ids = model.generate(**inputs, max_new_tokens=1024)
-
-             generated_ids = output_ids[:, inputs.input_ids.shape[1]:]
-
-             output_text = processor.batch_decode(
-                 generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
-             )[0]
-
-             # Highlight the queried text
-             def highlight_text(text, query):
-                 highlighted_text = text
-                 for word in query.split():
-                     pattern = re.compile(re.escape(word), re.IGNORECASE)
-                     highlighted_text = pattern.sub(lambda m: f'<span style="background-color: yellow;">{m.group()}</span>', highlighted_text)
-                 return highlighted_text
-
-             highlighted_output = highlight_text(output_text, text_query)
-
-             st.subheader("Extracted Text (with query highlighted):")
-             st.markdown(highlighted_output, unsafe_allow_html=True)
        else:
-             st.warning("Please enter a query.")
- else:
-     st.info("Upload an image to get started.")
+ from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor
import streamlit as st
+ import os
from PIL import Image
+ import requests
import torch
+ import json
+ from torchvision import io
+ from typing import Dict
import re

@st.cache_resource
+ def init_model():
+     tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
+     model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
+     model = model.eval()
+     return model, tokenizer
+
+ def init_gpu_model():
+     tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
+     model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
+     model = model.eval().cuda()
+     return model, tokenizer
+
+ def init_qwen_model():
+     model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", device_map="cpu", torch_dtype=torch.float16)
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
+     return model, processor
+
+ def get_quen_op(image_file, model, processor):
+     try:
+         image = Image.open(image_file).convert('RGB')
+         conversation = [
+             {
+                 "role":"user",
+                 "content":[
+                     {
+                         "type":"image",
+                     },
+                     {
+                         "type":"text",
+                         "text":"Extract text from this image."
+                     }
+                 ]
+             }
+         ]
+         text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+         inputs = processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt")
+         inputs = {k: v.to(torch.float32) if torch.is_floating_point(v) else v for k, v in inputs.items()}
+
+         generation_config = {
+             "max_new_tokens": 32,
+             "do_sample": False,
+             "top_k": 20,
+             "top_p": 0.90,
+             "temperature": 0.4,
+             "num_return_sequences": 1,
+             "pad_token_id": processor.tokenizer.pad_token_id,
+             "eos_token_id": processor.tokenizer.eos_token_id,
+         }
+
+         output_ids = model.generate(**inputs, **generation_config)
+         if 'input_ids' in inputs:
+             generated_ids = output_ids[:, inputs['input_ids'].shape[1]:]
        else:
+             generated_ids = output_ids
+
+         output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+
+         return output_text[:] if output_text else "No text extracted from the image."
+
+     except Exception as e:
+         return f"An error occurred: {str(e)}"
+
+ @st.cache_data
+ def get_text(image_file, _model, _tokenizer):
+     res = _model.chat(_tokenizer, image_file, ocr_type='ocr')
+     return res
+
+ def highlight_text(text, search_term):
+     if not search_term:
+         return text
+     pattern = re.compile(re.escape(search_term), re.IGNORECASE)
+     return pattern.sub(lambda m: f'<span style="background-color: grey;">{m.group()}</span>', text)
+
+ def save_text_to_json(file_name, text_data):
+     """Save the extracted text into a JSON file."""
+     with open(file_name, 'w') as json_file:
+         json.dump({"extracted_text": text_data}, json_file, indent=4)
+     st.success(f"Text saved to {file_name}")
+
+ st.title("Extract text from the image using - GOT-OCR2.0 and search keyword")
+ st.write("Upload an image")
+
+ MODEL, PROCESSOR = init_model()
+
+ image_file = st.file_uploader("Upload Image", type=['jpg', 'png', 'jpeg'])
+
+ if image_file:
+     if not os.path.exists("images"):
+         os.makedirs("images")
+     with open(f"images/{image_file.name}", "wb") as f:
+         f.write(image_file.getbuffer())
+
+     image_file = f"images/{image_file.name}"
+
+     text = get_text(image_file, MODEL, PROCESSOR)
+
+     print(text)
+
+     # Add search functionality
+     search_term = st.text_input("Enter a word or phrase to search:")
+     highlighted_text = highlight_text(text, search_term)
+
+     st.markdown(highlighted_text, unsafe_allow_html=True)
+
+     # Save the extracted text in JSON
+     json_file_path = f"{image_file}_extracted.json"
+     save_text_to_json(json_file_path, text)
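
After this change, app.py no longer builds a ColPali index with byaldi; it loads the GOT-OCR2.0 CPU checkpoint once and runs its chat() helper on each uploaded image. Below is a minimal sketch of that path outside Streamlit, assuming the srimanth-d/GOT_CPU checkpoint's remote code exposes the same chat() interface used in get_text() above; the image path is a placeholder, not a file from this repo.

# Minimal sketch: OCR a single image with the GOT-OCR2.0 CPU checkpoint,
# mirroring init_model() and get_text() from the updated app.py.
# 'images/sample.jpg' is a placeholder path.
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
model = AutoModel.from_pretrained(
    'srimanth-d/GOT_CPU',
    trust_remote_code=True,
    use_safetensors=True,
    pad_token_id=tokenizer.eos_token_id,
).eval()

text = model.chat(tokenizer, 'images/sample.jpg', ocr_type='ocr')  # plain-text OCR
print(text)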