File size: 2,934 Bytes
01f0a3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce309f9
01f0a3d
 
ce309f9
01f0a3d
ce309f9
 
 
 
 
01f0a3d
 
 
ce309f9
01f0a3d
ce309f9
01f0a3d
ce309f9
 
01f0a3d
ce309f9
01f0a3d
ce309f9
01f0a3d
 
 
ce309f9
 
 
 
 
 
 
 
 
 
01f0a3d
ce309f9
 
 
 
 
 
01f0a3d
 
 
 
 
 
 
 
ce309f9
 
01f0a3d
 
 
 
 
 
 
 
 
 
 
 
 
 
ce309f9
01f0a3d
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import streamlit as st
import numpy as np
import requests
from io import BytesIO
from kan_linear import KANLinear

class CNNKAN(nn.Module):
    """Convolutional feature extractor followed by a two-layer KAN head.

    Three stages of conv(3x3, padding 1) -> batch norm -> SELU -> 2x2 max-pool
    widen the channels 3 -> 32 -> 64 -> 128.  The flattened features pass
    through dropout, ``kan1``, dropout again and ``kan2``, which produces a
    single value per sample.  No final activation is applied here.

    ``kan1`` expects ``128 * 25 * 25`` input features, i.e. a 25x25 feature
    map after the three halvings (for example a 200x200 input image).
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 3 -> 32 channels.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=32)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        # Stage 2: 32 -> 64 channels.
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=64)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        # Stage 3: 64 -> 128 channels.
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(num_features=128)
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        # One dropout module, applied before each KAN layer in forward().
        self.dropout = nn.Dropout(p=0.5)
        self.kan1 = KANLinear(128 * 25 * 25, 256)
        self.kan2 = KANLinear(256, 1)

    def forward(self, x):
        """Run the three conv stages and the KAN head on batch ``x``."""
        stages = (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
            (self.conv3, self.bn3, self.pool3),
        )
        for conv, norm, pool in stages:
            x = pool(F.selu(norm(conv(x))))

        # Collapse everything but the batch dimension for the KAN layers.
        flat = x.view(x.size(0), -1)
        hidden = self.kan1(self.dropout(flat))
        return self.kan2(self.dropout(hidden))

def load_model(weights_path, device):
    """Create a CNNKAN on ``device``, load its weights and return it in eval mode.

    ``weights_path`` is read with ``torch.load`` (tensors mapped onto
    ``device``) and applied through ``load_state_dict``.
    """
    net = CNNKAN().to(device)
    state = torch.load(weights_path, map_location=device)
    net.load_state_dict(state)
    # eval() returns the module itself, already switched to inference mode.
    return net.eval()

def load_image_from_url(url, timeout=10):
    """Download an image and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up.  Without a
            timeout ``requests.get`` may block indefinitely on an
            unresponsive host.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            4xx/5xx response.
        OSError: If the downloaded payload cannot be decoded as an image.
    """
    response = requests.get(url, timeout=timeout)
    # Surface HTTP errors directly instead of letting an HTML error page
    # fail later as an unidentifiable image.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert('RGB')

def preprocess_image(image):
    """Resize ``image`` to 200x200, convert it to a tensor and add a batch axis.

    The returned tensor has a leading dimension of size 1 (from
    ``unsqueeze(0)``) so it can be fed to the model as a one-image batch.
    """
    steps = [
        transforms.Resize((200, 200)),
        transforms.ToTensor(),
    ]
    pipeline = transforms.Compose(steps)
    tensor = pipeline(image)
    return tensor.unsqueeze(0)

# Streamlit app: pick an image (upload or URL), run the CNN-KAN model on it
# and report the predicted class.
st.title("Image Classification with CNN-KAN")

# Sidebar inputs: a file uploader and, alternatively, a URL field.
st.sidebar.title("Upload Images")
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])
image_url = st.sidebar.text_input("Or enter image URL...")

# NOTE(review): this runs at module top level, so the weights are loaded on
# every execution of the script; Streamlit re-executes scripts on each
# interaction, so consider caching the model — confirm against the deployed
# Streamlit version.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model('weights/best_model_weights_KAN.pth', device)

img = None

# An uploaded file takes precedence over the URL field.
if uploaded_file is not None:
    # Report undecodable uploads in the sidebar, the same way URL failures
    # are reported, instead of aborting the whole script run.
    try:
        img = Image.open(uploaded_file).convert('RGB')
    except Exception as e:
        st.sidebar.error(f"Error loading uploaded image: {e}")
elif image_url:
    try:
        img = load_image_from_url(image_url)
    except Exception as e:
        st.sidebar.error(f"Error loading image from URL: {e}")

if img is not None:
    st.image(np.array(img), caption='Uploaded Image.', use_column_width=True)
    if st.button('Predict'):
        img_tensor = preprocess_image(img).to(device)

        # Inference only: no gradients needed.
        with torch.no_grad():
            output = model(img_tensor)
            # The model emits one raw value; sigmoid maps it into (0, 1).
            prob = torch.sigmoid(output).item()

        st.write(f"Prediction: {prob:.4f}")

        # Values below 0.5 are reported as dandelion, the rest as grass.
        if prob < 0.5:
            st.write("This image is classified as a dandelion flower.")
        else:
            st.write("This image is classified as grass.")