CindyBSydney commited on
Commit
61a1204
1 Parent(s): 742f6ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -40
app.py CHANGED
@@ -3,14 +3,16 @@ import torchvision.transforms as transforms
3
  import torchvision.models as models
4
  import torch.nn as nn
5
  from joblib import load
6
- from gradio import File
7
  from PIL import Image
8
- import gradio as gr
9
  import matplotlib.pyplot as plt
10
  import io
 
 
11
 
12
- # Transformation and device setup
13
  device = torch.device("cpu")
 
 
14
  data_transforms = transforms.Compose([
15
  transforms.Resize(224),
16
  transforms.CenterCrop(224),
@@ -19,57 +21,77 @@ data_transforms = transforms.Compose([
19
  ])
20
 
21
  # Load the Isolation Forest model
22
- clf = load('Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib')
23
-
24
- # Load feature extractor
25
- feature_extractor_path = 'Models/feature_extractor.pth'
26
- feature_extractor = models.resnet50(weights=None)
27
- feature_extractor.fc = nn.Sequential()
28
- feature_extractor.load_state_dict(torch.load(feature_extractor_path, map_location=device))
29
- feature_extractor.to(device)
30
- feature_extractor.eval()
31
-
32
- # Load gastric classification model
33
- GASTRIC_MODEL_PATH = 'Gastric_Models/the_resnet_50_model.pth'
34
- model_ft = torch.load(GASTRIC_MODEL_PATH, map_location=device)
35
- model_ft.to(device)
36
- model_ft.eval()
37
 
38
- # Anomaly detection function
39
- def is_anomaly(clf, feature_extractor, input_image):
 
 
 
 
40
  feature_extractor.to(device)
41
- with torch.no_grad():
42
- image_features = feature_extractor(input_image)
43
-
44
- is_outlier = clf.predict(image_features.cpu().numpy().reshape(1, -1))
45
- return is_outlier[0] == -1
46
 
47
- # Anomaly detection and classification function
48
- def classify_image(uploaded_image):
49
- image = Image.open(uploaded_image).convert('RGB')
50
- input_image = data_transforms(image).unsqueeze(0).to(device)
51
-
52
- # Anomaly detection
53
- if is_anomaly(clf, feature_extractor, input_image):
54
- return "Anomaly detected. Image will not be classified.", None
55
 
56
- # Classification
 
57
  with torch.no_grad():
58
- outputs = model_ft(input_image)
59
  probabilities = torch.nn.functional.softmax(outputs, dim=1)
60
  _, predicted = torch.max(outputs, 1)
61
 
62
- predicted_class_index = predicted.item()
63
  class_names = ['abnormal', 'normal']
 
64
  predicted_class_name = class_names[predicted_class_index]
65
  predicted_probability = probabilities[0][predicted_class_index].item() * 100
66
 
67
- return f"Class: {predicted_class_name}, Probability: {predicted_probability:.2f}%", None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
 
 
69
  iface = gr.Interface(
70
- fn=classify_image,
71
  inputs=File(type="filepath"),
72
- outputs=gr.Image(),
73
  title="GastroHub AI Gastric Image Classifier",
74
  description="Upload an image to classify it as normal or abnormal.",
75
  article="Above is a sample image to test the results of the model. Click it to see the results.",
@@ -79,5 +101,4 @@ iface = gr.Interface(
79
  allow_flagging="never",
80
  )
81
 
82
- # Run the Gradio app
83
  iface.launch()
 
3
  import torchvision.models as models
4
  import torch.nn as nn
5
  from joblib import load
 
6
  from PIL import Image
 
7
  import matplotlib.pyplot as plt
8
  import io
9
+ import numpy as np
10
+ import gradio as gr
11
 
12
+ # Device configuration
13
  device = torch.device("cpu")
14
+
15
+ # Transformation for the input images
16
  data_transforms = transforms.Compose([
17
  transforms.Resize(224),
18
  transforms.CenterCrop(224),
 
21
  ])
22
 
23
  # Load the Isolation Forest model
24
+ def load_isolation_forest():
25
+ path = 'Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib'
26
+ return load(path)
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ # Load the feature extractor
29
+ def load_feature_extractor():
30
+ feature_extractor_path = 'Models/feature_extractor.pth'
31
+ feature_extractor = models.resnet50(weights=None)
32
+ feature_extractor.fc = nn.Sequential()
33
+ feature_extractor.load_state_dict(torch.load(feature_extractor_path, map_location=device))
34
  feature_extractor.to(device)
35
+ feature_extractor.eval()
36
+ return feature_extractor
 
 
 
37
 
38
+ # Anomaly detection function
39
+ def is_anomaly(clf, feature_extractor, image):
40
+ with torch.no_grad():
41
+ image_features = feature_extractor(image)
42
+ return clf.predict(image_features.cpu().numpy().reshape(1, -1))[0] == -1
 
 
 
43
 
44
+ # Classification function
45
+ def classify_image(model, image):
46
  with torch.no_grad():
47
+ outputs = model(image)
48
  probabilities = torch.nn.functional.softmax(outputs, dim=1)
49
  _, predicted = torch.max(outputs, 1)
50
 
 
51
  class_names = ['abnormal', 'normal']
52
+ predicted_class_index = predicted.item()
53
  predicted_class_name = class_names[predicted_class_index]
54
  predicted_probability = probabilities[0][predicted_class_index].item() * 100
55
 
56
+ return predicted_class_name, predicted_probability
57
+
58
+ # Load the classification model
59
+ def load_classification_model():
60
+ model_path = 'Gastric_Models/the_resnet_50_model.pth'
61
+ model = torch.load(model_path, map_location=device)
62
+ model.to(device)
63
+ model.eval()
64
+ return model
65
+
66
+ # Function to process the image and get results
67
+ def process_image(image_path):
68
+ # Convert to PIL and apply transforms
69
+ image = Image.open(io.BytesIO(image_path.read())).convert('RGB')
70
+ input_image = data_transforms(image).unsqueeze(0).to(device)
71
+
72
+ # Load models
73
+ clf = load_isolation_forest()
74
+ feature_extractor = load_feature_extractor()
75
+ classification_model = load_classification_model()
76
+
77
+ # Check for anomaly
78
+ if is_anomaly(clf, feature_extractor, input_image):
79
+ return "Anomaly detected. Image will not be classified.", None, None
80
+
81
+ # Classify image
82
+ predicted_class, probability = classify_image(classification_model, input_image)
83
+ result = f"The predicted class is: {predicted_class} with a probability of {probability:.2f}%"
84
+
85
+ # Further processing for heatmap or additional features can be added here
86
+
87
+ return result, None, None # Returning placeholders for additional outputs if needed
88
 
89
+ # Gradio interface
90
+
91
  iface = gr.Interface(
92
+ fn=process_image,
93
  inputs=File(type="filepath"),
94
+ outputs=[gr.Textbox(label="Result"), gr.Image(label="Heatmap"), gr.Image(label="Additional Output")],
95
  title="GastroHub AI Gastric Image Classifier",
96
  description="Upload an image to classify it as normal or abnormal.",
97
  article="Above is a sample image to test the results of the model. Click it to see the results.",
 
101
  allow_flagging="never",
102
  )
103
 
 
104
  iface.launch()