Spaces:
Sleeping
Sleeping
File size: 2,602 Bytes
485f4da 63dbd6c 485f4da d1eee78 cb4264c d1eee78 485f4da 0747848 d1eee78 0747848 73142ed 0747848 73142ed 0747848 d1eee78 0747848 10cb574 0747848 73142ed d1eee78 0747848 73142ed 0747848 cb4264c 0747848 d1eee78 24fc650 485f4da 0747848 73142ed cb4264c 0747848 cb4264c 73142ed cb4264c 73142ed cb4264c 63dbd6c 485f4da cb4264c 0747848 cb4264c 0747848 cb4264c 73142ed cb4264c 73142ed cb4264c 485f4da cb4264c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import gradio as gr
from diffusers import StableDiffusionPipeline
import torch
from torchvision import models, transforms
from PIL import Image
from huggingface_hub import hf_hub_download
import torch.nn as nn
# ---- Model setup (runs once at import time) ----
# Select GPU when available; all tensors/models below are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fetch the fine-tuned VGG16 classifier weights from the Hugging Face Hub
# (cached locally after the first download).
vgg16_model_path = hf_hub_download(repo_id="sk2003/style_recognizer_vgg", filename="vgg16_model.pth")
# Start from ImageNet-pretrained VGG16 as the feature extractor.
# NOTE(review): `pretrained=True` is deprecated in recent torchvision in
# favor of `weights=models.VGG16_Weights.DEFAULT` — confirm torchvision version.
vgg16 = models.vgg16(pretrained=True)
for param in vgg16.parameters():
    param.requires_grad = False # Freeze parameters
# Replace the final fully connected layer so the output matches the
# 8 room-style classes used during fine-tuning (see `classes` in predict()).
num_classes = 8
vgg16.classifier[6] = nn.Linear(vgg16.classifier[6].in_features, num_classes)
vgg16 = vgg16.to(device)
# Load the fine-tuned weights on top of the pretrained backbone.
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects — fine for a trusted repo, but worth tightening.
checkpoint = torch.load(vgg16_model_path, map_location=device)
vgg16.load_state_dict(checkpoint['model_state_dict'])
vgg16.eval() # Set the model to evaluation mode (disables dropout/batchnorm updates)
# Load the fine-tuned Stable Diffusion pipeline for text-to-image generation.
# float16 halves memory; presumably intended to run on GPU — on CPU-only
# hosts float16 inference may be unsupported/slow (TODO confirm).
model_id = "sk2003/room-styler"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe.to(device)
# Prediction function for the VGG16 model
def predict(image):
    """Classify the interior-design style of a room photo.

    Args:
        image: PIL.Image supplied by the Gradio image input.

    Returns:
        str: one of the 8 style class names the VGG16 head was fine-tuned on.
    """
    # Standard ImageNet preprocessing: 224x224 input, per-channel
    # mean/std normalization matching the pretrained VGG16 backbone.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    # Force 3-channel RGB: uploaded PNGs may carry an alpha channel (RGBA),
    # which would make the tensor 4-channel and crash the conv stem.
    image_tensor = preprocess(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = vgg16(image_tensor)
    # argmax over class logits; replaces the legacy `outputs.data` +
    # torch.max(..., 1) idiom (.data bypasses autograd bookkeeping).
    predicted = torch.argmax(outputs, dim=1)
    # Class order must match the label encoding used during fine-tuning.
    classes = ["Classic", "Modern", "Vintage", "Glamour", "Scandinavian", "Rustic", "ArtDeco", "Industrial"]
    return classes[predicted.item()]
# Generation function for the Stable Diffusion model
def generate_image(prompt):
    """Generate a room image from a text prompt via the fine-tuned Stable Diffusion pipeline.

    Args:
        prompt: free-text description of the desired room style.

    Returns:
        PIL.Image: the first image produced by the pipeline.
    """
    result = pipe(prompt)
    return result.images[0]
# Gradio interface
with gr.Blocks() as demo:
gr.Markdown("## Room Style Recognition and Generation") # Title
# 1st tab
with gr.Tab("Recognize Room Style"):
image_input = gr.Image(type="pil")
label_output = gr.Textbox()
btn_predict = gr.Button("Predict Style")
btn_predict.click(predict, inputs=image_input, outputs=label_output)
# 2nd tab
with gr.Tab("Generate Room Style"):
text_input = gr.Textbox(placeholder="Enter a prompt for room style...")
image_output = gr.Image()
btn_generate = gr.Button("Generate Image")
btn_generate.click(generate_image, inputs=text_input, outputs=image_output)
demo.launch()
|