top001 commited on
Commit
8795ec2
·
verified ·
1 Parent(s): 0f1c238

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -25
app.py CHANGED
@@ -21,16 +21,12 @@ def is_valid_image(file_content: bytes) -> Optional[str]:
21
  def process_image_bytes(image_bytes: bytes) -> np.ndarray:
22
  try:
23
  image = Image.open(io.BytesIO(image_bytes))
24
-
25
- if image.mode != 'RGB':
26
  image = image.convert('RGB')
27
-
28
  img_array = np.array(image)
29
- img_array = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
30
-
31
  return img_array
32
  except Exception as e:
33
- raise HTTPException(status_code=400, detail=f"Error processing image: {str(e)}")
34
 
35
  def get_mask(img, s=1024):
36
  img = (img / 255).astype(np.float32)
@@ -74,27 +70,25 @@ with gradio_app:
74
 
75
  @app.post("/remove-bg")
76
  async def remove_background(file: UploadFile = File(...)):
 
 
 
 
 
 
 
 
 
77
  try:
78
- contents = await file.read()
79
-
80
- image_format = is_valid_image(contents)
81
- if not image_format or image_format not in SUPPORTED_FORMATS:
82
- raise HTTPException(
83
- status_code=400,
84
- detail=f"Unsupported image format. Supported formats: {', '.join(SUPPORTED_FORMATS)}"
85
- )
86
-
87
  img = process_image_bytes(contents)
88
  mask = get_mask(img)
 
 
 
89
 
90
- result = (mask * img).astype(np.uint8)
91
- result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
92
- alpha_channel = (mask * 255).astype(np.uint8)
93
- rgba = np.dstack((result, alpha_channel))
94
-
95
- pil_image = Image.fromarray(rgba, 'RGBA')
96
  img_byte_arr = io.BytesIO()
97
- pil_image.save(img_byte_arr, format='PNG', optimize=True)
98
  img_byte_arr = img_byte_arr.getvalue()
99
 
100
  return Response(
@@ -105,10 +99,8 @@ async def remove_background(file: UploadFile = File(...)):
105
  }
106
  )
107
 
108
- except HTTPException as he:
109
- raise he
110
  except Exception as e:
111
- raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
112
 
113
  if __name__ == "__main__":
114
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
 
21
  def process_image_bytes(image_bytes: bytes) -> np.ndarray:
22
  try:
23
  image = Image.open(io.BytesIO(image_bytes))
24
+ if image.mode == 'RGBA':
 
25
  image = image.convert('RGB')
 
26
  img_array = np.array(image)
 
 
27
  return img_array
28
  except Exception as e:
29
+ raise HTTPException(status_code=400, detail=f"Error: {str(e)}")
30
 
31
  def get_mask(img, s=1024):
32
  img = (img / 255).astype(np.float32)
 
70
 
71
  @app.post("/remove-bg")
72
  async def remove_background(file: UploadFile = File(...)):
73
+ contents = await file.read()
74
+
75
+ image_format = is_valid_image(contents)
76
+ if not image_format or image_format not in SUPPORTED_FORMATS:
77
+ raise HTTPException(
78
+ status_code=400,
79
+ detail=f"Invalid format: {', '.join(SUPPORTED_FORMATS)}"
80
+ )
81
+
82
  try:
 
 
 
 
 
 
 
 
 
83
  img = process_image_bytes(contents)
84
  mask = get_mask(img)
85
+ img = (mask * img + 255 * (1 - mask)).astype(np.uint8)
86
+ mask = (mask * 255).astype(np.uint8)
87
+ img = np.concatenate([img, mask], axis=2, dtype=np.uint8)
88
 
89
+ pil_image = Image.fromarray(img, 'RGBA')
 
 
 
 
 
90
  img_byte_arr = io.BytesIO()
91
+ pil_image.save(img_byte_arr, format='PNG')
92
  img_byte_arr = img_byte_arr.getvalue()
93
 
94
  return Response(
 
99
  }
100
  )
101
 
 
 
102
  except Exception as e:
103
+ raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
104
 
105
  if __name__ == "__main__":
106
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']