petergpt committed on
Commit 4d7e87d · verified · 1 Parent(s): 36a76ae

Update app.py

Files changed (1)
  1. app.py +52 -41
app.py CHANGED
@@ -4,12 +4,16 @@ from transformers import AutoModelForImageSegmentation
 from PIL import Image
 from torchvision import transforms
 import gradio as gr
+import gc

-# Load the model
-birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet_lite', trust_remote_code=True)
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-birefnet.to(device)
-birefnet.eval()
+def load_model():
+    model = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet_lite', trust_remote_code=True)
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    model.to(device)
+    model.eval()
+    return model, device
+
+birefnet, device = load_model()

 # Preprocessing
 image_size = (1024, 1024)
@@ -19,19 +23,20 @@ transform_image = transforms.Compose([
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
 ])

-def run_inference(images):
-    # Convert all images into a batch tensor
+def run_inference(images, model, device):
     inputs = []
     original_sizes = []
     for img in images:
         original_sizes.append(img.size)
         inputs.append(transform_image(img))
     input_tensor = torch.stack(inputs).to(device)
-
-    # Run inference
-    with torch.no_grad():
-        preds = birefnet(input_tensor)[-1].sigmoid().cpu()
-
+    try:
+        with torch.no_grad():
+            preds = model(input_tensor)[-1].sigmoid().cpu()
+    except torch.OutOfMemoryError:
+        del input_tensor
+        torch.cuda.empty_cache()
+        raise
     # Post-process
     results = []
     for i, img in enumerate(images):
@@ -41,43 +46,49 @@ def run_inference(images):
         result = Image.new("RGBA", original_sizes[i], (0, 0, 0, 0))
         result.paste(img, mask=mask)
         results.append(result)
+    # Cleanup
+    del input_tensor, preds
+    gc.collect()
+    torch.cuda.empty_cache()
     return results

+def binary_search_max(images):
+    # After OOM, try to find max feasible batch
+    low, high = 1, len(images)
+    best = None
+    best_count = 0
+    while low <= high:
+        mid = (low + high) // 2
+        batch = images[:mid]
+        try:
+            # Re-load model to avoid leftover memory fragmentation
+            global birefnet, device
+            birefnet, device = load_model()
+            res = run_inference(batch, birefnet, device)
+            best = res
+            best_count = mid
+            low = mid + 1
+        except torch.OutOfMemoryError:
+            high = mid - 1
+    return best, best_count
+
 def extract_objects(filepaths):
     images = [Image.open(p).convert("RGB") for p in filepaths]
     start_time = time.time()

-    # Attempt to process all at once
+    # First attempt: all images
     try:
-        results = run_inference(images)
+        results = run_inference(images, birefnet, device)
         end_time = time.time()
         total_time = end_time - start_time
         summary = f"Total request time: {total_time:.2f}s\nProcessed {len(images)} images successfully."
         return results, summary
     except torch.OutOfMemoryError:
-        # Only if we fail, do we attempt to find a feasible batch size
-        torch.cuda.empty_cache()
-
-        fail_time = time.time()
-        initial_attempt_time = fail_time - start_time
-
-        # Binary search to find max feasible batch size
-        low, high = 1, len(images)
-        best = None
-        best_count = 0
+        # OOM occurred, try to find feasible batch size now
+        oom_time = time.time()
+        initial_attempt_time = oom_time - start_time

-        while low <= high:
-            mid = (low + high) // 2
-            batch = images[:mid]
-            try:
-                res = run_inference(batch)
-                best = res
-                best_count = mid
-                low = mid + 1  # try bigger
-            except torch.OutOfMemoryError:
-                torch.cuda.empty_cache()
-                high = mid - 1  # try smaller
-
+        best, best_count = binary_search_max(images)
         end_time = time.time()
         total_time = end_time - start_time

@@ -86,14 +97,14 @@ def extract_objects(filepaths):
             summary = (
                 f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
                 f"Could not process even a single image.\n"
-                f"Total time with fallback attempts: {total_time:.2f}s."
+                f"Total time including fallback attempts: {total_time:.2f}s."
             )
             return [], summary
         else:
             summary = (
-                f"Initial attempt OOM after {initial_attempt_time:.2f}s. "
-                f"After fallback tests, found that {best_count} images can be processed.\n"
-                f"Total time including fallback: {total_time:.2f}s.\n"
+                f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
+                f"Found that {best_count} images can be processed without OOM.\n"
+                f"Total time including fallback attempts: {total_time:.2f}s.\n"
                 f"Next time, try using up to {best_count} images."
             )
             return best, summary
@@ -103,7 +114,7 @@ iface = gr.Interface(
     inputs=gr.Files(label="Upload Multiple Images", type="filepath", file_count="multiple"),
     outputs=[gr.Gallery(label="Processed Images"), gr.Textbox(label="Timing Info")],
     title="BiRefNet Bulk Background Removal with On-Demand Fallback",
-    description="Upload as many images as you want. If OOM occurs, a quick fallback will find the max feasible number of images without adding overhead unless needed."
+    description="Upload as many images as you want. If OOM occurs, a fallback will find the max feasible number. Extra cleanup steps and reinitialization for more consistent results."
 )

 iface.launch()
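
A quick way to sanity-check the new fallback path without a GPU is to exercise the binary-search logic against a stand-in for run_inference. The sketch below is illustrative only: fake_process, find_max_batch, and MAX_FITTING are hypothetical stand-ins for the real inference call and the actual memory limit, and it assumes a PyTorch build recent enough to expose torch.OutOfMemoryError, which app.py itself relies on.

import torch

MAX_FITTING = 5  # hypothetical limit: pretend only 5 images fit in memory at once

def fake_process(batch):
    # Stand-in for run_inference: raises the same exception app.py catches.
    if len(batch) > MAX_FITTING:
        raise torch.OutOfMemoryError("simulated OOM")
    return [f"processed:{item}" for item in batch]

def find_max_batch(items, process):
    # Same halving logic as binary_search_max in app.py, minus the model re-load.
    low, high = 1, len(items)
    best, best_count = None, 0
    while low <= high:
        mid = (low + high) // 2
        try:
            best = process(items[:mid])
            best_count = mid
            low = mid + 1   # this size fits, try a bigger one
        except torch.OutOfMemoryError:
            high = mid - 1  # too big, shrink the window
    return best, best_count

results, count = find_max_batch(list(range(12)), fake_process)
print(count)         # 5
print(len(results))  # 5

Because the search only re-runs inference on prefixes of the upload, the fallback adds at most O(log n) extra attempts on top of the failed first pass, which is what the "fallback attempts" wording in the summary refers to.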