Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -6,30 +6,33 @@ from PIL import Image
|
|
6 |
from huggingface_hub import hf_hub_download
|
7 |
import torch.nn as nn
|
8 |
|
|
|
9 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
10 |
|
11 |
-
#
|
12 |
vgg16_model_path = hf_hub_download(repo_id="sk2003/style_recognizer_vgg", filename="vgg16_model.pth")
|
13 |
|
|
|
14 |
vgg16 = models.vgg16(pretrained=True)
|
15 |
for param in vgg16.parameters():
|
16 |
-
param.requires_grad = False
|
17 |
|
|
|
18 |
num_classes = 8
|
19 |
-
vgg16.
|
20 |
vgg16 = vgg16.to(device)
|
21 |
|
22 |
-
#
|
23 |
checkpoint = torch.load(vgg16_model_path, map_location=device)
|
24 |
vgg16.load_state_dict(checkpoint['model_state_dict'])
|
25 |
-
vgg16.eval()
|
26 |
|
27 |
-
#
|
28 |
model_id = "sk2003/room-styler"
|
29 |
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
|
30 |
pipe.to(device)
|
31 |
|
32 |
-
# Prediction function for the
|
33 |
def predict(image):
|
34 |
transform = transforms.Compose([
|
35 |
transforms.Resize((224, 224)),
|
@@ -39,7 +42,7 @@ def predict(image):
|
|
39 |
image_tensor = transform(image).unsqueeze(0).to(device)
|
40 |
|
41 |
with torch.no_grad():
|
42 |
-
outputs =
|
43 |
_, predicted = torch.max(outputs.data, 1)
|
44 |
|
45 |
classes = ["Classic", "Modern", "Vintage", "Glamour", "Scandinavian", "Rustic", "ArtDeco", "Industrial"]
|
@@ -54,9 +57,9 @@ def generate_image(prompt):
|
|
54 |
|
55 |
# Gradio interface
|
56 |
with gr.Blocks() as demo:
|
57 |
-
gr.Markdown("## Room Style Recognition and Generation")
|
58 |
|
59 |
-
|
60 |
with gr.Tab("Recognize Room Style"):
|
61 |
image_input = gr.Image(type="pil")
|
62 |
label_output = gr.Textbox()
|
@@ -71,4 +74,3 @@ with gr.Blocks() as demo:
|
|
71 |
btn_generate.click(generate_image, inputs=text_input, outputs=image_output)
|
72 |
|
73 |
demo.launch()
|
74 |
-
|
|
|
6 |
from huggingface_hub import hf_hub_download
|
7 |
import torch.nn as nn
|
8 |
|
9 |
+
# Set the device
|
10 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
11 |
|
12 |
+
# Download the fine-tuned VGG16 model
|
13 |
vgg16_model_path = hf_hub_download(repo_id="sk2003/style_recognizer_vgg", filename="vgg16_model.pth")
|
14 |
|
15 |
+
# Load the VGG16 model with pre-trained weights
|
16 |
vgg16 = models.vgg16(pretrained=True)
|
17 |
for param in vgg16.parameters():
|
18 |
+
param.requires_grad = False # Freeze parameters
|
19 |
|
20 |
+
# Update the last fully connected layer to match the number of classes
|
21 |
num_classes = 8
|
22 |
+
vgg16.classifier[6] = nn.Linear(vgg16.classifier[6].in_features, num_classes)
|
23 |
vgg16 = vgg16.to(device)
|
24 |
|
25 |
+
# Load the fine-tuned model state
|
26 |
checkpoint = torch.load(vgg16_model_path, map_location=device)
|
27 |
vgg16.load_state_dict(checkpoint['model_state_dict'])
|
28 |
+
vgg16.eval() # Set the model to evaluation mode
|
29 |
|
30 |
+
# Load the fine-tuned Stable Diffusion model
|
31 |
model_id = "sk2003/room-styler"
|
32 |
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
|
33 |
pipe.to(device)
|
34 |
|
35 |
+
# Prediction function for the VGG16 model
|
36 |
def predict(image):
|
37 |
transform = transforms.Compose([
|
38 |
transforms.Resize((224, 224)),
|
|
|
42 |
image_tensor = transform(image).unsqueeze(0).to(device)
|
43 |
|
44 |
with torch.no_grad():
|
45 |
+
outputs = vgg16(image_tensor)
|
46 |
_, predicted = torch.max(outputs.data, 1)
|
47 |
|
48 |
classes = ["Classic", "Modern", "Vintage", "Glamour", "Scandinavian", "Rustic", "ArtDeco", "Industrial"]
|
|
|
57 |
|
58 |
# Gradio interface
|
59 |
with gr.Blocks() as demo:
|
60 |
+
gr.Markdown("## Room Style Recognition and Generation") # Title
|
61 |
|
62 |
+
# 1st tab
|
63 |
with gr.Tab("Recognize Room Style"):
|
64 |
image_input = gr.Image(type="pil")
|
65 |
label_output = gr.Textbox()
|
|
|
74 |
btn_generate.click(generate_image, inputs=text_input, outputs=image_output)
|
75 |
|
76 |
demo.launch()
|
|