# InteriorVision / app.py
import gradio as gr
from diffusers import StableDiffusionPipeline
import torch
from torchvision import models, transforms
from PIL import Image
from huggingface_hub import hf_hub_download
import torch.nn as nn
# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Download the fine-tuned VGG16 model
vgg16_model_path = hf_hub_download(repo_id="sk2003/style_recognizer_vgg", filename="vgg16_model.pth")
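# hf_hub_download stores the file in the local Hugging Face cache and returns its path,
# so the checkpoint is only downloaded once per environment.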
# Load the VGG16 model with pre-trained weights
vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)  # replaces the deprecated pretrained=True
for param in vgg16.parameters():
    param.requires_grad = False  # Freeze all pre-trained parameters
# Update the last fully connected layer to match the number of classes
num_classes = 8
vgg16.classifier[6] = nn.Linear(vgg16.classifier[6].in_features, num_classes)
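# The freshly constructed Linear head defaults to requires_grad=True, but that only matters for
# training; here the fine-tuned weights are loaded below and the model is used for inference only.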
vgg16 = vgg16.to(device)
# Load the fine-tuned model state
checkpoint = torch.load(vgg16_model_path, map_location=device)
vgg16.load_state_dict(checkpoint['model_state_dict'])
vgg16.eval() # Set the model to evaluation mode
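# eval() puts the dropout layers in VGG16's classifier into inference mode; gradient tracking is
# additionally disabled per call via torch.no_grad() inside predict().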
# Load the fine-tuned Stable Diffusion model
model_id = "sk2003/room-styler"
# Use float16 only on GPU; half-precision inference is not supported for this pipeline on CPU
dtype = torch.float16 if device.type == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe = pipe.to(device)
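# Optional memory saving (a hedged sketch, not in the original app): attention slicing trades a
# little speed for a lower peak GPU memory footprint during generation.
# pipe.enable_attention_slicing()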
# Prediction function for the VGG16 model
def predict(image):
    # Standard ImageNet-style preprocessing; this must match what was used during fine-tuning
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    image_tensor = transform(image).unsqueeze(0).to(device)  # add a batch dimension
    with torch.no_grad():
        outputs = vgg16(image_tensor)
        _, predicted = torch.max(outputs.data, 1)
    # Class names in the order used when the classifier head was fine-tuned
    classes = ["Classic", "Modern", "Vintage", "Glamour", "Scandinavian", "Rustic", "ArtDeco", "Industrial"]
    pred = classes[predicted.item()]
    return pred
# Generation function for the Stable Diffusion model
def generate_image(prompt):
    image = pipe(prompt).images[0]
    return image
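# The call above relies on the pipeline defaults for sampling. If faster previews are wanted, the
# standard pipeline arguments can be passed explicitly (illustrative values, not from the original):
# image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]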
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Room Style Recognition and Generation")  # Title
    # 1st tab
    with gr.Tab("Recognize Room Style"):
        image_input = gr.Image(type="pil")
        label_output = gr.Textbox()
        btn_predict = gr.Button("Predict Style")
        btn_predict.click(predict, inputs=image_input, outputs=label_output)
    # 2nd tab
    with gr.Tab("Generate Room Style"):
        text_input = gr.Textbox(placeholder="Enter a prompt for room style...")
        image_output = gr.Image()
        btn_generate = gr.Button("Generate Image")
        btn_generate.click(generate_image, inputs=text_input, outputs=image_output)
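# Generation can take minutes on CPU; enabling Gradio's request queue avoids HTTP timeouts for
# long-running calls (optional sketch, not in the original app):
# demo.queue()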
demo.launch()