File size: 2,602 Bytes
485f4da
63dbd6c
485f4da
d1eee78
cb4264c
 
d1eee78
485f4da
0747848
d1eee78
 
0747848
73142ed
 
0747848
73142ed
 
0747848
d1eee78
0747848
10cb574
0747848
73142ed
d1eee78
0747848
73142ed
 
0747848
cb4264c
0747848
d1eee78
 
24fc650
485f4da
0747848
73142ed
cb4264c
 
 
 
 
 
 
 
0747848
cb4264c
 
73142ed
 
cb4264c
73142ed
cb4264c
 
63dbd6c
 
485f4da
 
cb4264c
 
0747848
cb4264c
0747848
cb4264c
 
 
 
73142ed
cb4264c
73142ed
cb4264c
 
 
 
 
485f4da
cb4264c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import gradio as gr
from diffusers import StableDiffusionPipeline
import torch
from torchvision import models, transforms
from PIL import Image
from huggingface_hub import hf_hub_download
import torch.nn as nn

# Select GPU when available; both models are moved to this device below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fetch the fine-tuned VGG16 checkpoint from the Hugging Face Hub
# (cached locally after the first download).
vgg16_model_path = hf_hub_download(repo_id="sk2003/style_recognizer_vgg", filename="vgg16_model.pth")

# Build the VGG16 backbone with ImageNet weights.
# NOTE(review): `pretrained=True` is deprecated in recent torchvision in
# favor of `weights=models.VGG16_Weights.DEFAULT` — confirm the installed
# torchvision version before changing.
vgg16 = models.vgg16(pretrained=True)
for param in vgg16.parameters():
    param.requires_grad = False  # Freeze parameters

# Replace the final classifier layer so the head matches the 8 room styles
# (must agree with the `classes` list used in `predict`).
num_classes = 8
vgg16.classifier[6] = nn.Linear(vgg16.classifier[6].in_features, num_classes)
vgg16 = vgg16.to(device)

# Restore the fine-tuned weights. The checkpoint is a dict holding the
# state dict under 'model_state_dict'.
# NOTE(review): torch.load unpickles the file — only safe because the
# checkpoint comes from a trusted repo; consider `weights_only=True`.
checkpoint = torch.load(vgg16_model_path, map_location=device)
vgg16.load_state_dict(checkpoint['model_state_dict'])
vgg16.eval()  # Set the model to evaluation mode

# Load the fine-tuned Stable Diffusion pipeline for room-style generation.
# NOTE(review): float16 assumes a CUDA device — on CPU this pipeline may
# fail or be extremely slow; confirm deployment target has a GPU.
model_id = "sk2003/room-styler"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe.to(device)

# Prediction function for the VGG16 model
def predict(image):
    """Classify the interior style of a room image.

    Args:
        image: PIL.Image from the Gradio input (any mode — PNG uploads
            may be RGBA or grayscale).

    Returns:
        str: one of the eight style labels the classifier was fine-tuned on.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # ImageNet mean/std — matches the VGG16 backbone's pre-training.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    # Force 3 channels: Normalize above expects exactly 3, so an RGBA or
    # grayscale upload would otherwise raise a runtime error.
    image_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = vgg16(image_tensor)
        predicted = torch.argmax(outputs, dim=1)  # index of the top logit

    # Order must match the label encoding used during fine-tuning —
    # TODO confirm against the training script.
    classes = ["Classic", "Modern", "Vintage", "Glamour", "Scandinavian", "Rustic", "ArtDeco", "Industrial"]
    pred = classes[predicted.item()]

    return pred

# Generation function for the Stable Diffusion model
def generate_image(prompt):
    """Run the fine-tuned Stable Diffusion pipeline on *prompt* and
    return the first generated PIL image."""
    result = pipe(prompt)
    return result.images[0]

# Gradio interface: two tabs, one per model. Building the Blocks context
# only declares the UI; nothing runs until demo.launch() below.
with gr.Blocks() as demo:
    gr.Markdown("## Room Style Recognition and Generation")  # Title

    # 1st tab: upload a room photo, get its style label from the VGG16 classifier.
    with gr.Tab("Recognize Room Style"):
        image_input = gr.Image(type="pil")  # type="pil" hands predict() a PIL.Image
        label_output = gr.Textbox()
        btn_predict = gr.Button("Predict Style")
        btn_predict.click(predict, inputs=image_input, outputs=label_output)

    # 2nd tab: free-text prompt in, Stable Diffusion image out.
    with gr.Tab("Generate Room Style"):
        text_input = gr.Textbox(placeholder="Enter a prompt for room style...")
        image_output = gr.Image()
        btn_generate = gr.Button("Generate Image")
        btn_generate.click(generate_image, inputs=text_input, outputs=image_output)

# Start the web server (blocks until shut down).
demo.launch()