zliang committed on
Commit
24d96d7
1 Parent(s): 0ce080b

Update pdfextract_fun.py

Browse files
Files changed (1) hide show
  1. pdfextract_fun.py +172 -61
pdfextract_fun.py CHANGED
@@ -1,82 +1,193 @@
1
- import os
2
- import re
3
  import warnings
 
 
 
 
 
 
4
  import cv2
 
5
  import fitz # PyMuPDF
6
  import numpy as np
 
7
  import pytesseract
8
  import torch
9
  from PIL import Image
10
- from tqdm import tqdm
 
 
 
 
11
  from detectron2.config import get_cfg
 
12
  from detectron2.data import MetadataCatalog
13
  from detectron2.engine import DefaultPredictor
14
- from detectron2.utils.visualizer import ColorMode, Visualizer
15
- from unilm.dit.object_detection.ditod import add_vit_config
16
 
17
- # Filter specific warnings
18
- warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
19
- warnings.filterwarnings("ignore", message="torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument.")
20
 
21
- # Configuration setup
22
- def setup_config():
23
- cfg = get_cfg()
24
- add_vit_config(cfg)
25
- cfg.merge_from_file("cascade_dit_base.yml")
26
- cfg.MODEL.WEIGHTS = "publaynet_dit-b_cascade.pth"
27
- cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
28
- return cfg
29
-
30
- # Analyze image
31
- def analyze_image(img, cfg):
32
- """Analyze an image and return the result image, output, and visualizer."""
33
- md = MetadataCatalog.get(cfg.DATASETS.TEST[0])
34
- thing_classes = ["table"] if cfg.DATASETS.TEST[0] == 'icdar2019_test' else ["text", "title", "list", "table", "figure"]
35
- md.set(thing_classes=thing_classes)
36
 
37
- output = DefaultPredictor(cfg)(img)["instances"]
38
- v = Visualizer(img[:, :, ::-1], metadata=md, scale=1.0, instance_mode=ColorMode.SEGMENTATION)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  result = v.draw_instance_predictions(output.to("cpu"))
40
- return result.get_image()[:, :, ::-1], output, v
 
 
 
 
41
 
42
- # PDF to JPEG conversion
43
  def convert_pdf_to_jpg(pdf_path, output_folder, zoom_factor=2):
44
- """Convert PDF file to JPEG images, saved in the specified output folder."""
45
  doc = fitz.open(pdf_path)
46
- for page_num, page in enumerate(doc):
47
- mat = fitz.Matrix(zoom_factor, zoom_factor)
48
- pix = page.get_pixmap(matrix=mat)
49
- output_file = os.path.join(output_folder, f"page_{page_num}.jpg")
 
 
 
 
50
  pix.save(output_file)
51
 
52
- # Process JPEG images in a folder
53
- def process_jpeg_images(output_folder, cfg):
54
- """Process each JPEG image in the output folder."""
55
- for page_num in tqdm(range(len(os.listdir(output_folder))), desc="Processing the pdf"):
56
- file_path = os.path.join(output_folder, f"page_{page_num}.jpg")
57
- img = cv2.imread(file_path)
58
- if img is None:
59
- print(f"Failed to read {file_path}. Skipping.")
60
- continue
61
- result_image, output, v = analyze_image(img, cfg)
62
- save_extracted_instances(img, output, page_num, output_folder)
63
-
64
- # Save extracted instances
 
 
 
65
  def save_extracted_instances(img, output, page_num, dest_folder, confidence_threshold=0.8):
66
- """Save instances extracted from an image to the destination folder."""
67
- class_names = {0: "text", 1: "title", 2: "list", 3: "table", 4: "figure"}
 
 
 
 
 
 
 
 
 
68
  instances = output.to("cpu")
69
- for i, (box, class_id, score) in enumerate(zip(instances.pred_boxes.tensor.numpy(), instances.pred_classes.tolist(), instances.scores.tolist())):
70
- if score >= confidence_threshold and class_names.get(class_id) in ["figure", "table", "text"]:
71
- x1, y1, x2, y2 = map(int, box)
72
- cropped_image = img[y1:y2, x1:x2]
73
- if np.std(cropped_image) > 0 and (y2 - y1) > 0: # Replace with actual thresholds if needed
74
- save_path = os.path.join(dest_folder, f"page_{page_num}_{class_names[class_id]}_{i + 1}.jpg")
75
- cv2.imwrite(save_path, cropped_image)
76
-
77
- # Additional functions like delete_files_in_folder, rename_files_sequentially, ocr_folder, and ocr_image can be included as is, assuming they were satisfactory before.
78
-
79
- cfg = setup_config()
80
- # Example usage
81
- convert_pdf_to_jpg("sample.pdf", "output_folder")
82
- process_jpeg_images("output_folder", cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import warnings
2
+ from concurrent.futures import ThreadPoolExecutor, as_completed
3
+ import time
4
+ # Filter warnings about inputs not requiring gradients
5
+ warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
6
+ warnings.filterwarnings("ignore", message="torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument.")
7
+
8
  import cv2
9
+ import os
10
  import fitz # PyMuPDF
11
  import numpy as np
12
+ import re
13
  import pytesseract
14
  import torch
15
  from PIL import Image
16
+ from tqdm import tqdm
17
+
18
+ from unilm.dit.object_detection.ditod import add_vit_config
19
+
20
+ from detectron2.config import CfgNode as CN
21
  from detectron2.config import get_cfg
22
+ from detectron2.utils.visualizer import ColorMode, Visualizer
23
  from detectron2.data import MetadataCatalog
24
  from detectron2.engine import DefaultPredictor
 
 
25
 
 
 
 
26
 
27
# Build the detectron2 config: start from the library defaults, let
# add_vit_config() extend it, then overlay the project YAML file.
cfg = get_cfg()
add_vit_config(cfg)
cfg.merge_from_file("cascade_dit_base.yml")

# Model weights file to load.
cfg.MODEL.WEIGHTS = "publaynet_dit-b_cascade.pth"

# Use the GPU when torch reports one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    cfg.MODEL.DEVICE = "cuda"
else:
    cfg.MODEL.DEVICE = "cpu"

# Single predictor built once at import time; analyze_image() reuses it.
predictor = DefaultPredictor(cfg)
41
+
42
def analyze_image(img):
    """Run the module-level predictor on an image and draw its detections.

    Args:
        img: Image array handed straight to ``predictor``. Its last axis is
            reversed before drawing (presumably BGR as loaded by OpenCV --
            confirm against callers).

    Returns:
        A tuple ``(result_image, output, v)``: the drawn image with its last
        axis reversed back, the predictor's ``"instances"`` output, and the
        ``Visualizer`` that produced the drawing.
    """
    dataset_name = cfg.DATASETS.TEST[0]
    metadata = MetadataCatalog.get(dataset_name)

    # Only a "table" class is registered for 'icdar2019_test'; every other
    # dataset name gets the five layout classes.
    thing_classes = (
        ["table"]
        if dataset_name == 'icdar2019_test'
        else ["text", "title", "list", "table", "figure"]
    )
    metadata.set(thing_classes=thing_classes)

    output = predictor(img)["instances"]
    v = Visualizer(
        img[:, :, ::-1],
        metadata,
        scale=1.0,
        instance_mode=ColorMode.SEGMENTATION,
    )
    drawn = v.draw_instance_predictions(output.to("cpu"))
    return drawn.get_image()[:, :, ::-1], output, v
59
+
60
+
61
 
 
def convert_pdf_to_jpg(pdf_path, output_folder, zoom_factor=2):
    """Render every page of a PDF to ``{output_folder}/page_{n}.jpg``.

    Args:
        pdf_path: Path of the PDF to open with PyMuPDF.
        output_folder: Directory that receives the page images; it is created
            if it does not exist yet.
        zoom_factor: Scale applied on both axes when rendering; larger values
            give higher-resolution images.

    Pages are numbered from 0 in document order.
    """
    os.makedirs(output_folder, exist_ok=True)

    # The same scaling matrix applies to every page, so build it once.
    mat = fitz.Matrix(zoom_factor, zoom_factor)

    # The context manager closes the document even if rendering a page fails.
    with fitz.open(pdf_path) as doc:
        for page_num, page in enumerate(doc):
            pix = page.get_pixmap(matrix=mat)
            pix.save(f"{output_folder}/page_{page_num}.jpg")
73
 
74
+
75
+
76
def process_jpeg_images(output_folder):
    """Run layout analysis on every ``page_{n}.jpg`` in a folder and save crops.

    Only files whose whole name is ``page_{n}.jpg`` are processed, in
    ascending page order. Other entries in the folder (crops from an earlier
    run, sub-folders, ...) are ignored instead of being counted as pages.

    Args:
        output_folder: Directory holding the page images; the extracted
            instances are written back into the same directory.
    """
    page_pattern = re.compile(r"page_(\d+)\.jpg")
    pages = sorted(
        (int(match.group(1)), match.group(0))
        for match in map(page_pattern.fullmatch, os.listdir(output_folder))
        if match
    )

    for page_num, filename in tqdm(pages, desc="Processing the pdf"):
        file_path = f"{output_folder}/{filename}"
        img = cv2.imread(file_path)
        if img is None:
            # cv2.imread signals an unreadable file by returning None.
            print(f"Failed to read {file_path}. Skipping.")
            continue

        # Only the raw instances are needed here; the drawn image and the
        # visualizer returned by analyze_image() are discarded.
        _, output, _ = analyze_image(img)
        save_extracted_instances(img, output, page_num, output_folder)
87
+
88
+
89
+
def save_extracted_instances(img, output, page_num, dest_folder, confidence_threshold=0.8):
    """Crop confident text, table and figure detections and save them as JPEGs.

    Args:
        img: Page image array the boxes refer to, indexed as ``img[y, x]``.
        output: Predictor instances; moved to the CPU before boxes, class ids
            and scores are read from it.
        page_num: Page index embedded in every output filename.
        dest_folder: Directory the crops are written into.
        confidence_threshold: Minimum score a detection needs to be saved.

    Crops are written as ``page_{page_num}_{class_name}_{n}.jpg``, where ``n``
    is one counter shared by all classes, starting at 1 and advanced only
    when a crop is actually saved.
    """
    label_for_id = {0: "text", 1: "title", 2: "list", 3: "table", 4: "figure"}
    kept_labels = ("figure", "table", "text")

    # With both limits at 0, only completely uniform or zero-height crops are
    # rejected.
    min_std = 0
    min_rows = 0

    cpu_instances = output.to("cpu")
    detections = zip(
        cpu_instances.pred_boxes.tensor.numpy(),
        cpu_instances.pred_classes.tolist(),
        cpu_instances.scores.tolist(),
    )

    next_index = 1
    for box, class_id, score in detections:
        if not score >= confidence_threshold:
            continue
        label = label_for_id.get(class_id, "unknown")
        if label not in kept_labels:
            continue

        left, top, right, bottom = map(int, box)
        crop = img[top:bottom, left:right]
        if not (np.std(crop) > min_std and (bottom - top) > min_rows):
            continue

        target = os.path.join(dest_folder, f"page_{page_num}_{label}_{next_index}.jpg")
        cv2.imwrite(target, crop)
        next_index += 1
122
+
123
+
124
def delete_files_in_folder(folder_path):
    """Delete every regular file directly inside ``folder_path``.

    Sub-directories and their contents are left alone.

    Args:
        folder_path: Directory whose top-level files are removed.
    """
    entries = (os.path.join(folder_path, name) for name in os.listdir(folder_path))
    for path in filter(os.path.isfile, entries):
        os.remove(path)
129
+
130
+
131
+
132
def rename_files_sequentially(folder_path):
    """Rename extracted crops to ``{class_name}_{n}.jpg`` with per-class numbering.

    Files named ``page_{page_num}_{class_name}_{image_counter}.jpg`` (matched
    case-insensitively) are ordered by page number, lower-cased class name
    and image counter, then renamed in that order so that each class is
    numbered 1, 2, 3, ... across the whole folder. All other files are left
    untouched.

    Args:
        folder_path: Directory containing the crops; files are renamed in place.
    """
    # The dot is escaped and the whole name must match, so look-alikes such
    # as 'page_0_text_1.jpg.bak' or 'page_0_text_1xjpg' are not picked up.
    pattern = re.compile(r'page_(\d+)_(\w+)_(\d+)\.jpg', re.IGNORECASE)

    # Parse each name once: (page number, class name, counter, original name).
    parsed = []
    for filename in os.listdir(folder_path):
        match = pattern.fullmatch(filename)
        if match:
            page_num, class_name, counter = match.groups()
            parsed.append((int(page_num), class_name.lower(), int(counter), filename))
    parsed.sort(key=lambda item: item[:3])

    # Next sequential number handed out so far, per class.
    counters = {}
    for _, class_name, _, filename in parsed:
        sequence = counters.get(class_name, 0) + 1
        counters[class_name] = sequence
        os.rename(
            os.path.join(folder_path, filename),
            os.path.join(folder_path, f"{class_name}_{sequence}.jpg"),
        )
166
+
167
+
168
def ocr_folder(folder_path):
    """OCR every ``text_{number}.jpg`` in a folder into ``ocr_results/``.

    Each matching image (the name is matched case-insensitively and must
    match in full) is passed to ``ocr_image``. The recognised text is written
    to ``ocr_results/text_{number}.txt`` inside ``folder_path``, encoded as
    UTF-8. The ``ocr_results`` sub-folder is created when missing.

    Args:
        folder_path: Directory containing the text crops.
    """
    pattern = re.compile(r'text_\d+\.jpg', re.IGNORECASE)

    ocr_text_folder = os.path.join(folder_path, "ocr_results")
    os.makedirs(ocr_text_folder, exist_ok=True)

    for filename in os.listdir(folder_path):
        if not pattern.fullmatch(filename):
            continue
        text = ocr_image(os.path.join(folder_path, filename))

        # splitext swaps the extension whatever its case, so 'TEXT_1.JPG'
        # also ends up as a '.txt' file.
        text_file_name = os.path.splitext(filename)[0] + '.txt'
        text_file_path = os.path.join(ocr_text_folder, text_file_name)

        # An explicit encoding keeps non-ASCII OCR output from failing under
        # locales whose default encoding cannot represent it.
        with open(text_file_path, 'w', encoding='utf-8') as file:
            file.write(text)
189
+
190
def ocr_image(image_path):
    """Return the text Tesseract recognises in the image at ``image_path``.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The string produced by ``pytesseract.image_to_string``.
    """
    # The context manager releases the underlying file handle as soon as OCR
    # is done instead of leaving it to garbage collection.
    with Image.open(image_path) as image:
        return pytesseract.image_to_string(image)