AI-Manith committed
Commit
789d928
1 Parent(s): 2ec1207

Update app.py

Files changed (1)
  app.py +25 -42
app.py CHANGED
@@ -6,50 +6,38 @@ from facenet_pytorch import MTCNN
 from torchvision.transforms.functional import to_pil_image
 
 # Function to load the ViT model and MTCNN
-def load_model_and_mtcnn(model_path):
+def load_model(model_path):
     model = torch.load(model_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     model.to(device)
-    mtcnn = MTCNN(keep_all=True, device=device)
-    return model, device, mtcnn
+    return model, device
 
-def detect_and_process_skin(image_bytes):
-    """Detects faces in an image, crops the skin region, and returns it as an image object."""
-    # Load image from bytes
-    img = Image.open(io.BytesIO(image_bytes))
-    img_np = np.array(img)
-    img_rgb = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
+# Initialize MTCNN for face detection
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+mtcnn = MTCNN(keep_all=True, device=device)
 
-    # Detect faces in the image
-    detections = mtcnn.detect_faces(img_rgb)
+# Function to preprocess the image using MTCNN for face detection
+def preprocess_image(image, device):
+    # Convert PIL image to OpenCV format
+    open_cv_image = np.array(image)
+    # Convert RGB to BGR for OpenCV
+    open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_RGB2BGR)
 
-    # Check if any faces were detected
-    if detections:
-        x, y, width, height = detections[0]['box']
+    # Convert OpenCV image back to PIL Image for MTCNN
+    pil_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
 
-        # Crop the face region
-        face_img_np = img_np[y:y+height, x:x+width]
-
-        # Convert to PIL Image for return
-        pil_img = Image.fromarray(face_img_np)
-        return pil_img
+    # Use MTCNN to detect faces
+    boxes, _ = mtcnn.detect(pil_image)
+    if boxes is not None:
+        # Crop the first detected face (for simplicity)
+        box = boxes[0].astype(int)
+        cropped_face = open_cv_image[box[1]:box[3], box[0]:box[2]]
+        # Convert cropped face back to PIL for further processing
+        processed_image = Image.fromarray(cv2.cvtColor(cropped_face, cv2.COLOR_BGR2RGB))
     else:
-        # Return original image if no face was detected
-        return img
+        processed_image = image  # Use the original image if no face is detected
 
-# Function to preprocess the image and return both the tensor and the final PIL image for display
-def preprocess_image(image, mtcnn, device):
-    processed_image = image  # Initialize with the original image
-    try:
-        # Directly call mtcnn with the image to get cropped faces
-        cropped_faces = mtcnn(image)
-        if cropped_faces is not None and len(cropped_faces) > 0:
-            # Convert the first detected face tensor back to PIL Image for further processing
-            processed_image = to_pil_image(cropped_faces[0].cpu(), mode='BGR;16')
-    except Exception as e:
-        st.write(f"Exception in face detection: {e}")
-        processed_image = image
-
+    # Transform image for model
     transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
@@ -80,14 +68,9 @@ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png
 if uploaded_file is not None:
     image = Image.open(uploaded_file).convert("RGB")
     st.image(image, caption='Uploaded Image', use_column_width=True)
-    image1 = image.getvalue()
-    image_ten = detect_and_process_skin(image1)
     image_tensor, final_image = preprocess_image(image, mtcnn, device)
-    predicted_class, probabilities = predict(image_ten, model, device)
+    predicted_class, probabilities = predict(image_tensor, model, device)
 
     st.write(f"Predicted class: {predicted_class.item()}")
     # Display the final processed image
-    # st.image(final_image, caption='Processed Image', use_column_width=True)
-    img_bytes = io.BytesIO()
-    detect_and_process_skin(image1.getvalue()).save(img_bytes, format='JPEG')
-    st.image(img_bytes.getvalue(), width=250, caption="Processed Image")
+    st.image(final_image, caption='Processed Image', use_column_width=True)
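
For orientation, a minimal sketch of how the refactored helpers from this commit fit together in the Streamlit app. The predict() helper, the (tensor, PIL image) return value of preprocess_image(), and the 'model.pth' path are not shown in this diff and are assumed here; the sketch also assumes the call site uses the new two-argument preprocess_image(image, device) signature introduced by the commit.

import streamlit as st
import torch
from PIL import Image

# load_model(), the module-level mtcnn, and preprocess_image() come from the
# updated app.py; predict() and 'model.pth' are placeholders not shown in this diff.
model, device = load_model('model.pth')

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption='Uploaded Image', use_column_width=True)
    # Crop to the first detected face (or fall back to the whole image),
    # then resize/normalize to the 224x224 input the ViT model expects.
    image_tensor, final_image = preprocess_image(image, device)
    predicted_class, probabilities = predict(image_tensor, model, device)
    st.write(f"Predicted class: {predicted_class.item()}")
    st.image(final_image, caption='Processed Image', use_column_width=True)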