# Hugging Face Space: room-style recognition (fine-tuned VGG16) and
# room-style generation (fine-tuned Stable Diffusion), served via Gradio.
import gradio as gr
from diffusers import StableDiffusionPipeline
import torch
from torchvision import models, transforms
from PIL import Image
from huggingface_hub import hf_hub_download
import torch.nn as nn

# Use the GPU when available; all models and tensors are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download the fine-tuned VGG16 classifier checkpoint from the Hub.
vgg16_model_path = hf_hub_download(repo_id="sk2003/style_recognizer_vgg", filename="vgg16_model.pth")

# Build a VGG16 backbone with ImageNet weights; only the head is retrained.
vgg16 = models.vgg16(pretrained=True)
for param in vgg16.parameters():
    param.requires_grad = False  # Freeze the backbone — inference only.

# Replace the last fully connected layer to output the 8 room-style classes.
num_classes = 8
vgg16.classifier[6] = nn.Linear(vgg16.classifier[6].in_features, num_classes)
vgg16 = vgg16.to(device)

# Load the fine-tuned weights; the checkpoint stores them under 'model_state_dict'.
# NOTE(review): torch.load deserializes pickled data — acceptable here because the
# checkpoint comes from a fixed, trusted Hub repo.
checkpoint = torch.load(vgg16_model_path, map_location=device)
vgg16.load_state_dict(checkpoint['model_state_dict'])
vgg16.eval()  # Evaluation mode: disables dropout, uses running batch-norm stats.

# Load the fine-tuned Stable Diffusion pipeline in half precision.
model_id = "sk2003/room-styler"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe.to(device)
# Prediction function for the VGG16 model
def predict(image):
    """Classify a room image into one of eight interior-design styles.

    Args:
        image: PIL image supplied by the Gradio image input.

    Returns:
        The predicted style name as a string.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # VGG16 expects 224x224 inputs.
        transforms.ToTensor(),
        # ImageNet normalization stats, matching the pretrained backbone.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image_tensor = transform(image).unsqueeze(0).to(device)  # Add batch dimension.
    with torch.no_grad():  # Inference only — no gradients needed.
        outputs = vgg16(image_tensor)
        _, predicted = torch.max(outputs, 1)  # Index of the highest logit.
    classes = ["Classic", "Modern", "Vintage", "Glamour", "Scandinavian", "Rustic", "ArtDeco", "Industrial"]
    pred = classes[predicted.item()]
    return pred
# Generation function for the Stable Diffusion model
def generate_image(prompt):
    """Generate a room image from a text prompt.

    Args:
        prompt: Free-text description of the desired room style.

    Returns:
        The first PIL image produced by the Stable Diffusion pipeline.
    """
    image = pipe(prompt).images[0]
    return image
# Gradio interface: two tabs — style recognition and style generation.
with gr.Blocks() as demo:
    gr.Markdown("## Room Style Recognition and Generation")  # Title

    # 1st tab: upload a room photo and predict its interior style.
    with gr.Tab("Recognize Room Style"):
        image_input = gr.Image(type="pil")  # Delivered to predict() as a PIL image.
        label_output = gr.Textbox()
        btn_predict = gr.Button("Predict Style")
        btn_predict.click(predict, inputs=image_input, outputs=label_output)

    # 2nd tab: enter a text prompt and generate a styled room image.
    with gr.Tab("Generate Room Style"):
        text_input = gr.Textbox(placeholder="Enter a prompt for room style...")
        image_output = gr.Image()
        btn_generate = gr.Button("Generate Image")
        btn_generate.click(generate_image, inputs=text_input, outputs=image_output)

demo.launch()