ankitkupadhyay commited on
Commit
a1ee699
·
verified ·
1 Parent(s): 9807690

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
 
2
  import torch.nn as nn
3
- from transformers import ViTFeatureExtractor, ViTModel, BertTokenizerFast, BertModel
4
  from PIL import Image
5
  import gradio as gr
6
 
@@ -32,13 +33,13 @@ class VisionLanguageModel(nn.Module):
32
  logits = self.classifier(combined_features)
33
  return logits
34
 
35
- # Load the model checkpoint
36
  model = VisionLanguageModel()
37
- model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu')))
38
  model.eval()
39
 
40
  tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
41
- feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
42
 
43
  def predict(image, text_input):
44
  # Preprocess the image
@@ -64,12 +65,12 @@ def predict(image, text_input):
64
  _, prediction = torch.max(outputs, dim=1)
65
  return "Malignant" if prediction.item() == 1 else "Benign"
66
 
67
- # Define Gradio interface
68
  iface = gr.Interface(
69
  fn=predict,
70
  inputs=[
71
- gr.inputs.Image(type="pil", label="Upload Skin Lesion Image"),
72
- gr.inputs.Textbox(label="Clinical Information (e.g., patient age, symptoms)")
73
  ],
74
  outputs="text",
75
  title="Skin Lesion Classification Demo",
@@ -77,3 +78,4 @@ iface = gr.Interface(
77
  )
78
 
79
  iface.launch()
 
 
1
  import torch
2
+ import torch
3
  import torch.nn as nn
4
+ from transformers import ViTImageProcessor, ViTModel, BertTokenizerFast, BertModel
5
  from PIL import Image
6
  import gradio as gr
7
 
 
33
  logits = self.classifier(combined_features)
34
  return logits
35
 
36
+ # Load the model checkpoint with safer loading
37
  model = VisionLanguageModel()
38
+ model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu'), weights_only=True))
39
  model.eval()
40
 
41
  tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
42
+ feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
43
 
44
  def predict(image, text_input):
45
  # Preprocess the image
 
65
  _, prediction = torch.max(outputs, dim=1)
66
  return "Malignant" if prediction.item() == 1 else "Benign"
67
 
68
+ # Define Gradio interface with updated component syntax
69
  iface = gr.Interface(
70
  fn=predict,
71
  inputs=[
72
+ gr.Image(type="pil", label="Upload Skin Lesion Image"),
73
+ gr.Textbox(label="Clinical Information (e.g., patient age, symptoms)")
74
  ],
75
  outputs="text",
76
  title="Skin Lesion Classification Demo",
 
78
  )
79
 
80
  iface.launch()
81
+