import streamlit as st import torch import torch.nn as nn import timm from PIL import Image from torchvision import transforms # Configuration and model definition CONFIG = dict( seed = 42, model_name = 'tf_efficientnet_b4_ns', train_batch_size = 16, valid_batch_size = 32, img_size = 256, epochs = 5, learning_rate = 1e-4, scheduler = 'CosineAnnealingLR', min_lr = 1e-6, T_max = 100, T_0 = 25, warmup_epochs = 0, weight_decay = 1e-6, n_accumulate = 1, n_fold = 5, num_classes = 1, device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), competition = 'PetFinder', _wandb_kernel = 'deb' ) class PawpularityModel(nn.Module): def __init__(self, model_name, pretrained=True): super(PawpularityModel, self).__init__() self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=0) self.fc = nn.LazyLinear(CONFIG['num_classes']) self.dropout = nn.Dropout(p=0.3) def forward(self, images, meta): features = self.model(images) # Extract features features = self.dropout(features) features = torch.cat([features, meta], dim=1) # Concatenate metadata output = self.fc(features) # Predict Pawpularity return output # Load the model model = PawpularityModel(CONFIG['model_name']) model.load_state_dict(torch.load('model_new.pth', map_location=CONFIG['device'])) model.to(CONFIG['device']) model.eval() # Define image transformation transform = transforms.Compose([ transforms.Resize((256, 256)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) st.title("Pawpularity Score Prediction 🐾") st.write("Project by Shreya Sivakumar-20BCE1794") uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"]) if uploaded_file is not None: image = Image.open(uploaded_file).convert('RGB') st.image(image, caption='Uploaded Image', use_column_width=True) # Preprocess the image and prepare dummy metadata (replace with actual metadata handling) image = transform(image).unsqueeze(0).to(CONFIG['device']) meta = torch.zeros((1, 12)).to(CONFIG['device']) with torch.no_grad(): output = model(image, meta) pawpularity_score = output.item() st.markdown(f"

🐾 Pawpularity Score: {pawpularity_score}

", unsafe_allow_html=True) st.markdown(""" --- Copyright © 2024 Shreya Sivakumar. All rights reserved. """)