sprakhil committed
Commit 795b781
1 Parent(s): 45ab2ce

resolving pipeline issue

Files changed (1):
  1. app.py (+34 −13)
app.py CHANGED
@@ -1,36 +1,45 @@
  import streamlit as st
  from PIL import Image
  import torch
- from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, pipeline
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoModelForImageToText
  from colpali_engine.models import ColPali, ColPaliProcessor
  import os

+ # Set device for computation
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+ # Get Hugging Face token from environment variables
  hf_token = os.getenv('HF_TOKEN')
+
+ # Load the processor and image-to-text model directly using Hugging Face token
  try:
-     model = pipeline("image-to-text", model="google/paligemma-3b-mix-448", use_auth_token=hf_token)
+     processor_img_to_text = AutoProcessor.from_pretrained("google/paligemma-3b-mix-448", use_auth_token=hf_token)
+     model_img_to_text = AutoModelForImageToText.from_pretrained("google/paligemma-3b-mix-448", use_auth_token=hf_token).to(device)
  except Exception as e:
      st.error(f"Error loading image-to-text model: {e}")
      st.stop()

+ # Load ColPali model with Hugging Face token
  try:
-     model_colpali = ColPali.from_pretrained("vidore/colpali-v1.2", torch_dtype=torch.bfloat16).to(device)
-     processor_colpali = ColPaliProcessor.from_pretrained("google/paligemma-3b-mix-448")
+     model_colpali = ColPali.from_pretrained("vidore/colpali-v1.2", torch_dtype=torch.bfloat16, use_auth_token=hf_token).to(device)
+     processor_colpali = ColPaliProcessor.from_pretrained("google/paligemma-3b-mix-448", use_auth_token=hf_token)
  except Exception as e:
      st.error(f"Error loading ColPali model or processor: {e}")
      st.stop()

+ # Load Qwen model with Hugging Face token
  try:
-     model_qwen = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct").to(device)
-     processor_qwen = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
+     model_qwen = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", use_auth_token=hf_token).to(device)
+     processor_qwen = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", use_auth_token=hf_token)
  except Exception as e:
      st.error(f"Error loading Qwen model or processor: {e}")
      st.stop()

+ # Streamlit UI
  st.title("OCR and Document Search Web Application")
  st.write("Upload an image containing text in both Hindi and English for OCR processing and keyword search.")

+ # File uploader for the image
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

  if uploaded_file is not None:
@@ -39,21 +48,33 @@ if uploaded_file is not None:
      st.image(image, caption='Uploaded Image.', use_column_width=True)
      st.write("")

+     # Use the image-to-text model to extract text from the image
+     inputs_img_to_text = processor_img_to_text(images=image, return_tensors="pt").to(device)
+     with torch.no_grad():
+         generated_ids_img_to_text = model_img_to_text.generate(**inputs_img_to_text, max_new_tokens=128)
+     output_text_img_to_text = processor_img_to_text.batch_decode(generated_ids_img_to_text, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+
+     st.write("Extracted Text from Image:")
+     st.write(output_text_img_to_text)
+
+     # Prepare input for Qwen model for image description
      conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe this image."}]}]
      text_prompt = processor_qwen.apply_chat_template(conversation, add_generation_prompt=True)
-     inputs = processor_qwen(text=[text_prompt], images=[image], padding=True, return_tensors="pt").to(device)
+     inputs_qwen = processor_qwen(text=[text_prompt], images=[image], padding=True, return_tensors="pt").to(device)

+     # Generate response with Qwen model
      with torch.no_grad():
-         output_ids = model_qwen.generate(**inputs, max_new_tokens=128)
-         generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
-         output_text = processor_qwen.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+         output_ids_qwen = model_qwen.generate(**inputs_qwen, max_new_tokens=128)
+         generated_ids_qwen = [output_ids_qwen[len(input_ids):] for input_ids, output_ids_qwen in zip(inputs_qwen.input_ids, output_ids_qwen)]
+         output_text_qwen = processor_qwen.batch_decode(generated_ids_qwen, skip_special_tokens=True, clean_up_tokenization_spaces=True)

-     st.write("Extracted Text:")
-     st.write(output_text)
+     st.write("Qwen Model Description:")
+     st.write(output_text_qwen)

+     # Keyword search in the extracted text
      keyword = st.text_input("Enter a keyword to search in the extracted text:")
      if keyword:
-         if keyword.lower() in output_text[0].lower():
+         if keyword.lower() in output_text_img_to_text[0].lower():
              st.write(f"Keyword '{keyword}' found in the text.")
          else:
              st.write(f"Keyword '{keyword}' not found in the text.")