lukmanaj commited on
Commit
542fff6
·
verified ·
1 Parent(s): e4d1a99

Upload 2 files

Browse files
Files changed (2) hide show
  1. requirements (1).txt +4 -0
  2. streamlit-app.py +67 -0
requirements (1).txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ streamlit
4
+ Pillow>=8.0.0
streamlit-app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torchvision
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import io
7
+
8
+ # Define the function to load the model
9
+ def load_model(model_path, device):
10
+ weights = torchvision.models.DenseNet201_Weights.DEFAULT # best available weight
11
+ model = torchvision.models.densenet201(weights=weights).to(device)
12
+ model.classifier = torch.nn.Sequential(
13
+ torch.nn.Dropout(p=0.2, inplace=True),
14
+ torch.nn.Linear(in_features=1920, out_features=4, bias=True)
15
+ ).to(device)
16
+ model.load_state_dict(torch.load(model_path, map_location=device))
17
+ model.to(device)
18
+ model.eval()
19
+ return model
20
+
21
+ # Define the function for preprocessing the image
22
+ def preprocess_image(image):
23
+ transform = transforms.Compose([
24
+ transforms.Resize(64),
25
+ transforms.ToTensor(),
26
+
27
+ ])
28
+ return transform(image)
29
+
30
+ # Define the function for getting predictions
31
+ def get_prediction(model, image, device):
32
+ class_names = ['buffalo', 'elephant', 'rhino', 'zebra']
33
+ image = image.unsqueeze(0).to(device) # Add batch dimension and move to device
34
+ with torch.no_grad():
35
+ pred_logits = model(image)
36
+ pred_prob = torch.softmax(pred_logits, dim=1)
37
+ pred_label = torch.argmax(pred_prob, dim=1)
38
+ return class_names[pred_label.item()], pred_prob.max().item()
39
+
40
+ # Streamlit app starts here
41
+ st.title("Wild Animal Prediction App")
42
+
43
+ uploaded_file = st.file_uploader("Upload an image of one of the following: Bufallo, Elephant, Rhino, or Zebra", type=["jpg", "jpeg", "png"])
44
+ if uploaded_file is not None:
45
+ # Convert the file-like object to bytes, then open it with PIL
46
+ image_bytes = uploaded_file.getvalue()
47
+ image = Image.open(io.BytesIO(image_bytes))
48
+
49
+ # Display the uploaded image
50
+ st.image(image, caption='Uploaded Image.', use_column_width=True)
51
+
52
+ # Predict button
53
+ if st.button('Predict'):
54
+ # Set device
55
+ device = "cuda" if torch.cuda.is_available() else "cpu"
56
+
57
+ # Load the model
58
+ model_path = 'model/densenetafri.pth' # Fixed model path
59
+ model = load_model(model_path, device)
60
+
61
+ # Preprocess the image and predict
62
+ preprocessed_image = preprocess_image(image)
63
+ prediction, probability = get_prediction(model, preprocessed_image, device)
64
+
65
+ # Display the prediction
66
+ st.write(f"Prediction: {prediction}, Probability: {probability:.3f}")
67
+