Spaces:
Sleeping
Sleeping
Upload 3 files
Browse files- app.py +82 -0
- model_new.pth +3 -0
- requirements.txt +93 -0
app.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import timm
|
5 |
+
from PIL import Image
|
6 |
+
from torchvision import transforms
|
7 |
+
|
8 |
+
# Configuration and model definition
|
9 |
+
CONFIG = dict(
|
10 |
+
seed = 42,
|
11 |
+
model_name = 'tf_efficientnet_b4_ns',
|
12 |
+
train_batch_size = 16,
|
13 |
+
valid_batch_size = 32,
|
14 |
+
img_size = 256,
|
15 |
+
epochs = 5,
|
16 |
+
learning_rate = 1e-4,
|
17 |
+
scheduler = 'CosineAnnealingLR',
|
18 |
+
min_lr = 1e-6,
|
19 |
+
T_max = 100,
|
20 |
+
T_0 = 25,
|
21 |
+
warmup_epochs = 0,
|
22 |
+
weight_decay = 1e-6,
|
23 |
+
n_accumulate = 1,
|
24 |
+
n_fold = 5,
|
25 |
+
num_classes = 1,
|
26 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
|
27 |
+
competition = 'PetFinder',
|
28 |
+
_wandb_kernel = 'deb'
|
29 |
+
)
|
30 |
+
|
31 |
+
class PawpularityModel(nn.Module):
|
32 |
+
def __init__(self, model_name, pretrained=True):
|
33 |
+
super(PawpularityModel, self).__init__()
|
34 |
+
self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
|
35 |
+
self.fc = nn.LazyLinear(CONFIG['num_classes'])
|
36 |
+
self.dropout = nn.Dropout(p=0.3)
|
37 |
+
|
38 |
+
def forward(self, images, meta):
|
39 |
+
features = self.model(images) # Extract features
|
40 |
+
features = self.dropout(features)
|
41 |
+
features = torch.cat([features, meta], dim=1) # Concatenate metadata
|
42 |
+
output = self.fc(features) # Predict Pawpularity
|
43 |
+
return output
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
# Load the model
|
48 |
+
model = PawpularityModel(CONFIG['model_name'])
|
49 |
+
model.load_state_dict(torch.load('model_new.pth', map_location=CONFIG['device']))
|
50 |
+
model.to(CONFIG['device'])
|
51 |
+
model.eval()
|
52 |
+
|
53 |
+
# Define image transformation
|
54 |
+
transform = transforms.Compose([
|
55 |
+
transforms.Resize((256, 256)),
|
56 |
+
transforms.ToTensor(),
|
57 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
58 |
+
])
|
59 |
+
|
60 |
+
|
61 |
+
st.title("Pawpularity Score Prediction 🐾")
|
62 |
+
st.write("Project by Shreya Sivakumar-20BCE1794")
|
63 |
+
|
64 |
+
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
|
65 |
+
if uploaded_file is not None:
|
66 |
+
image = Image.open(uploaded_file).convert('RGB')
|
67 |
+
st.image(image, caption='Uploaded Image', use_column_width=True)
|
68 |
+
|
69 |
+
# Preprocess the image and prepare dummy metadata (replace with actual metadata handling)
|
70 |
+
image = transform(image).unsqueeze(0).to(CONFIG['device'])
|
71 |
+
meta = torch.zeros((1, 12)).to(CONFIG['device'])
|
72 |
+
|
73 |
+
|
74 |
+
with torch.no_grad():
|
75 |
+
output = model(image, meta)
|
76 |
+
pawpularity_score = output.item()
|
77 |
+
|
78 |
+
st.markdown(f"<h2 style='text-align: center; color: black;'>🐾 Pawpularity Score: {pawpularity_score}</h1>", unsafe_allow_html=True)
|
79 |
+
st.markdown("""
|
80 |
+
---
|
81 |
+
Copyright © 2024 Shreya Sivakumar. All rights reserved.
|
82 |
+
""")
|
model_new.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d15b049e176bee74a9480ad9208dc43d7cb7bf0ecc2721f803a53ec37fadc11d
|
3 |
+
size 70948090
|
requirements.txt
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==2.1.0
|
2 |
+
altair==5.2.0
|
3 |
+
astunparse==1.6.3
|
4 |
+
attrs==23.2.0
|
5 |
+
blinker==1.7.0
|
6 |
+
cachetools==5.3.3
|
7 |
+
certifi==2024.2.2
|
8 |
+
charset-normalizer==3.3.2
|
9 |
+
click==8.1.7
|
10 |
+
contourpy==1.2.0
|
11 |
+
cycler==0.12.1
|
12 |
+
filelock==3.13.1
|
13 |
+
flatbuffers==24.3.7
|
14 |
+
fonttools==4.50.0
|
15 |
+
fsspec==2024.3.1
|
16 |
+
gast==0.5.4
|
17 |
+
gitdb==4.0.11
|
18 |
+
GitPython==3.1.42
|
19 |
+
google-pasta==0.2.0
|
20 |
+
grpcio==1.62.1
|
21 |
+
h5py==3.10.0
|
22 |
+
huggingface-hub==0.21.4
|
23 |
+
idna==3.6
|
24 |
+
Jinja2==3.1.3
|
25 |
+
jsonschema==4.21.1
|
26 |
+
jsonschema-specifications==2023.12.1
|
27 |
+
keras==3.1.1
|
28 |
+
kiwisolver==1.4.5
|
29 |
+
libclang==18.1.1
|
30 |
+
Markdown==3.6
|
31 |
+
markdown-it-py==3.0.0
|
32 |
+
MarkupSafe==2.1.5
|
33 |
+
matplotlib==3.8.3
|
34 |
+
mdurl==0.1.2
|
35 |
+
ml-dtypes==0.3.2
|
36 |
+
mpmath==1.3.0
|
37 |
+
namex==0.0.7
|
38 |
+
networkx==3.2.1
|
39 |
+
numpy==1.26.4
|
40 |
+
nvidia-cublas-cu12==12.1.3.1
|
41 |
+
nvidia-cuda-cupti-cu12==12.1.105
|
42 |
+
nvidia-cuda-nvrtc-cu12==12.1.105
|
43 |
+
nvidia-cuda-runtime-cu12==12.1.105
|
44 |
+
nvidia-cudnn-cu12==8.9.2.26
|
45 |
+
nvidia-cufft-cu12==11.0.2.54
|
46 |
+
nvidia-curand-cu12==10.3.2.106
|
47 |
+
nvidia-cusolver-cu12==11.4.5.107
|
48 |
+
nvidia-cusparse-cu12==12.1.0.106
|
49 |
+
nvidia-nccl-cu12==2.19.3
|
50 |
+
nvidia-nvjitlink-cu12==12.4.99
|
51 |
+
nvidia-nvtx-cu12==12.1.105
|
52 |
+
opt-einsum==3.3.0
|
53 |
+
optree==0.10.0
|
54 |
+
packaging==23.2
|
55 |
+
pandas==2.2.1
|
56 |
+
pillow==10.2.0
|
57 |
+
protobuf==4.25.3
|
58 |
+
pyarrow==15.0.2
|
59 |
+
pydeck==0.8.1b0
|
60 |
+
Pygments==2.17.2
|
61 |
+
pyparsing==3.1.2
|
62 |
+
python-dateutil==2.9.0.post0
|
63 |
+
pytz==2024.1
|
64 |
+
PyYAML==6.0.1
|
65 |
+
referencing==0.34.0
|
66 |
+
requests==2.31.0
|
67 |
+
rich==13.7.1
|
68 |
+
rpds-py==0.18.0
|
69 |
+
safetensors==0.4.2
|
70 |
+
six==1.16.0
|
71 |
+
smmap==5.0.1
|
72 |
+
streamlit==1.32.2
|
73 |
+
sympy==1.12
|
74 |
+
tenacity==8.2.3
|
75 |
+
tensorboard==2.16.2
|
76 |
+
tensorboard-data-server==0.7.2
|
77 |
+
tensorflow==2.16.1
|
78 |
+
tensorflow-io-gcs-filesystem==0.36.0
|
79 |
+
termcolor==2.4.0
|
80 |
+
timm==0.9.16
|
81 |
+
toml==0.10.2
|
82 |
+
toolz==0.12.1
|
83 |
+
torch==2.2.1
|
84 |
+
torchvision==0.17.1
|
85 |
+
tornado==6.4
|
86 |
+
tqdm==4.66.2
|
87 |
+
triton==2.2.0
|
88 |
+
typing_extensions==4.10.0
|
89 |
+
tzdata==2024.1
|
90 |
+
urllib3==2.2.1
|
91 |
+
watchdog==4.0.0
|
92 |
+
Werkzeug==3.0.1
|
93 |
+
wrapt==1.16.0
|