ankitkupadhyay commited on
Commit
5c9bc3a
1 Parent(s): 17b5157

Upload 2 files

Browse files

added app and requirement files

Files changed (2) hide show
  1. app.py +79 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import ViTFeatureExtractor, ViTModel, BertTokenizerFast, BertModel
4
+ from PIL import Image
5
+ import gradio as gr
6
+
7
+ class VisionLanguageModel(nn.Module):
8
+ def __init__(self):
9
+ super(VisionLanguageModel, self).__init__()
10
+ self.vision_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
11
+ self.language_model = BertModel.from_pretrained('bert-base-uncased')
12
+ self.classifier = nn.Linear(
13
+ self.vision_model.config.hidden_size + self.language_model.config.hidden_size,
14
+ 2 # Number of classes: benign or malignant
15
+ )
16
+
17
+ def forward(self, input_ids, attention_mask, pixel_values):
18
+ vision_outputs = self.vision_model(pixel_values=pixel_values)
19
+ vision_pooled_output = vision_outputs.pooler_output
20
+
21
+ language_outputs = self.language_model(
22
+ input_ids=input_ids,
23
+ attention_mask=attention_mask
24
+ )
25
+ language_pooled_output = language_outputs.pooler_output
26
+
27
+ combined_features = torch.cat(
28
+ (vision_pooled_output, language_pooled_output),
29
+ dim=1
30
+ )
31
+
32
+ logits = self.classifier(combined_features)
33
+ return logits
34
+
35
+ # Load the model checkpoint
36
+ model = VisionLanguageModel()
37
+ model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu')))
38
+ model.eval()
39
+
40
+ tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
41
+ feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
42
+
43
+ def predict(image, text_input):
44
+ # Preprocess the image
45
+ image = feature_extractor(images=image, return_tensors="pt").pixel_values
46
+
47
+ # Preprocess the text
48
+ encoding = tokenizer(
49
+ text_input,
50
+ add_special_tokens=True,
51
+ max_length=256,
52
+ padding='max_length',
53
+ truncation=True,
54
+ return_tensors='pt'
55
+ )
56
+
57
+ # Make a prediction
58
+ with torch.no_grad():
59
+ outputs = model(
60
+ input_ids=encoding['input_ids'],
61
+ attention_mask=encoding['attention_mask'],
62
+ pixel_values=image
63
+ )
64
+ _, prediction = torch.max(outputs, dim=1)
65
+ return "Malignant" if prediction.item() == 1 else "Benign"
66
+
67
+ # Define Gradio interface
68
+ iface = gr.Interface(
69
+ fn=predict,
70
+ inputs=[
71
+ gr.inputs.Image(type="pil", label="Upload Skin Lesion Image"),
72
+ gr.inputs.Textbox(label="Clinical Information (e.g., patient age, symptoms)")
73
+ ],
74
+ outputs="text",
75
+ title="Skin Lesion Classification Demo",
76
+ description="This model classifies skin lesions as benign or malignant based on an image and clinical information."
77
+ )
78
+
79
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ transformers
3
+ gradio