shreyasiv commited on
Commit
68ab9e2
1 Parent(s): 3e6b85c

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +82 -0
  2. model_new.pth +3 -0
  3. requirements.txt +93 -0
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import timm
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+
8
+ # Configuration and model definition
9
+ CONFIG = dict(
10
+ seed = 42,
11
+ model_name = 'tf_efficientnet_b4_ns',
12
+ train_batch_size = 16,
13
+ valid_batch_size = 32,
14
+ img_size = 256,
15
+ epochs = 5,
16
+ learning_rate = 1e-4,
17
+ scheduler = 'CosineAnnealingLR',
18
+ min_lr = 1e-6,
19
+ T_max = 100,
20
+ T_0 = 25,
21
+ warmup_epochs = 0,
22
+ weight_decay = 1e-6,
23
+ n_accumulate = 1,
24
+ n_fold = 5,
25
+ num_classes = 1,
26
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
27
+ competition = 'PetFinder',
28
+ _wandb_kernel = 'deb'
29
+ )
30
+
31
+ class PawpularityModel(nn.Module):
32
+ def __init__(self, model_name, pretrained=True):
33
+ super(PawpularityModel, self).__init__()
34
+ self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
35
+ self.fc = nn.LazyLinear(CONFIG['num_classes'])
36
+ self.dropout = nn.Dropout(p=0.3)
37
+
38
+ def forward(self, images, meta):
39
+ features = self.model(images) # Extract features
40
+ features = self.dropout(features)
41
+ features = torch.cat([features, meta], dim=1) # Concatenate metadata
42
+ output = self.fc(features) # Predict Pawpularity
43
+ return output
44
+
45
+
46
+
47
+ # Load the model
48
+ model = PawpularityModel(CONFIG['model_name'])
49
+ model.load_state_dict(torch.load('model_new.pth', map_location=CONFIG['device']))
50
+ model.to(CONFIG['device'])
51
+ model.eval()
52
+
53
+ # Define image transformation
54
+ transform = transforms.Compose([
55
+ transforms.Resize((256, 256)),
56
+ transforms.ToTensor(),
57
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
58
+ ])
59
+
60
+
61
+ st.title("Pawpularity Score Prediction 🐾")
62
+ st.write("Project by Shreya Sivakumar-20BCE1794")
63
+
64
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
65
+ if uploaded_file is not None:
66
+ image = Image.open(uploaded_file).convert('RGB')
67
+ st.image(image, caption='Uploaded Image', use_column_width=True)
68
+
69
+ # Preprocess the image and prepare dummy metadata (replace with actual metadata handling)
70
+ image = transform(image).unsqueeze(0).to(CONFIG['device'])
71
+ meta = torch.zeros((1, 12)).to(CONFIG['device'])
72
+
73
+
74
+ with torch.no_grad():
75
+ output = model(image, meta)
76
+ pawpularity_score = output.item()
77
+
78
+ st.markdown(f"<h2 style='text-align: center; color: black;'>🐾 Pawpularity Score: {pawpularity_score}</h1>", unsafe_allow_html=True)
79
+ st.markdown("""
80
+ ---
81
+ Copyright © 2024 Shreya Sivakumar. All rights reserved.
82
+ """)
model_new.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d15b049e176bee74a9480ad9208dc43d7cb7bf0ecc2721f803a53ec37fadc11d
3
+ size 70948090
requirements.txt ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ altair==5.2.0
3
+ astunparse==1.6.3
4
+ attrs==23.2.0
5
+ blinker==1.7.0
6
+ cachetools==5.3.3
7
+ certifi==2024.2.2
8
+ charset-normalizer==3.3.2
9
+ click==8.1.7
10
+ contourpy==1.2.0
11
+ cycler==0.12.1
12
+ filelock==3.13.1
13
+ flatbuffers==24.3.7
14
+ fonttools==4.50.0
15
+ fsspec==2024.3.1
16
+ gast==0.5.4
17
+ gitdb==4.0.11
18
+ GitPython==3.1.42
19
+ google-pasta==0.2.0
20
+ grpcio==1.62.1
21
+ h5py==3.10.0
22
+ huggingface-hub==0.21.4
23
+ idna==3.6
24
+ Jinja2==3.1.3
25
+ jsonschema==4.21.1
26
+ jsonschema-specifications==2023.12.1
27
+ keras==3.1.1
28
+ kiwisolver==1.4.5
29
+ libclang==18.1.1
30
+ Markdown==3.6
31
+ markdown-it-py==3.0.0
32
+ MarkupSafe==2.1.5
33
+ matplotlib==3.8.3
34
+ mdurl==0.1.2
35
+ ml-dtypes==0.3.2
36
+ mpmath==1.3.0
37
+ namex==0.0.7
38
+ networkx==3.2.1
39
+ numpy==1.26.4
40
+ nvidia-cublas-cu12==12.1.3.1
41
+ nvidia-cuda-cupti-cu12==12.1.105
42
+ nvidia-cuda-nvrtc-cu12==12.1.105
43
+ nvidia-cuda-runtime-cu12==12.1.105
44
+ nvidia-cudnn-cu12==8.9.2.26
45
+ nvidia-cufft-cu12==11.0.2.54
46
+ nvidia-curand-cu12==10.3.2.106
47
+ nvidia-cusolver-cu12==11.4.5.107
48
+ nvidia-cusparse-cu12==12.1.0.106
49
+ nvidia-nccl-cu12==2.19.3
50
+ nvidia-nvjitlink-cu12==12.4.99
51
+ nvidia-nvtx-cu12==12.1.105
52
+ opt-einsum==3.3.0
53
+ optree==0.10.0
54
+ packaging==23.2
55
+ pandas==2.2.1
56
+ pillow==10.2.0
57
+ protobuf==4.25.3
58
+ pyarrow==15.0.2
59
+ pydeck==0.8.1b0
60
+ Pygments==2.17.2
61
+ pyparsing==3.1.2
62
+ python-dateutil==2.9.0.post0
63
+ pytz==2024.1
64
+ PyYAML==6.0.1
65
+ referencing==0.34.0
66
+ requests==2.31.0
67
+ rich==13.7.1
68
+ rpds-py==0.18.0
69
+ safetensors==0.4.2
70
+ six==1.16.0
71
+ smmap==5.0.1
72
+ streamlit==1.32.2
73
+ sympy==1.12
74
+ tenacity==8.2.3
75
+ tensorboard==2.16.2
76
+ tensorboard-data-server==0.7.2
77
+ tensorflow==2.16.1
78
+ tensorflow-io-gcs-filesystem==0.36.0
79
+ termcolor==2.4.0
80
+ timm==0.9.16
81
+ toml==0.10.2
82
+ toolz==0.12.1
83
+ torch==2.2.1
84
+ torchvision==0.17.1
85
+ tornado==6.4
86
+ tqdm==4.66.2
87
+ triton==2.2.0
88
+ typing_extensions==4.10.0
89
+ tzdata==2024.1
90
+ urllib3==2.2.1
91
+ watchdog==4.0.0
92
+ Werkzeug==3.0.1
93
+ wrapt==1.16.0