sk2003 commited on
Commit
cb4264c
1 Parent(s): 24fc650

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -9
app.py CHANGED
@@ -1,23 +1,67 @@
1
  import gradio as gr
2
  from diffusers import StableDiffusionPipeline
3
  import torch
 
 
 
 
4
 
5
- # Fine-tuned Stable Diffusion model
6
  model_id = "sk2003/room-styler"
7
  pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
 
 
 
 
 
 
 
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
9
  pipe.to(device)
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def generate_image(prompt):
12
  image = pipe(prompt).images[0]
13
  return image
14
 
15
- iface = gr.Interface(
16
- fn=generate_image,
17
- inputs=gr.Textbox(lines=2, placeholder="Enter prompt here..."),
18
- outputs="image",
19
- title="Room Styler",
20
- description="Generate room style images with Stable Diffusion."
21
- )
 
 
 
 
 
 
 
 
22
 
23
- iface.launch()
 
1
  import gradio as gr
2
  from diffusers import StableDiffusionPipeline
3
  import torch
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import matplotlib.pyplot as plt
7
+ from huggingface_hub import hf_hub_download
8
 
9
+ # Fine-tuned Stable Diffusion model from your Hugging Face repository
10
  model_id = "sk2003/room-styler"
11
  pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
12
+
13
+ # VGG16 model
14
+ vgg16_model_path = hf_hub_download(repo_id="sk2003/style_recognizer_vgg", filename="vgg16_model.pth")
15
+ vgg16 = torch.load(vgg16_model_path)
16
+ vgg16.eval()
17
+
18
+ # Device
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+ vgg16.to(device)
21
  pipe.to(device)
22
 
23
+ # Prediction function for the VGG16 model
24
+ def predict_and_show(image):
25
+ transform = transforms.Compose([
26
+ transforms.Resize((224, 224)),
27
+ transforms.ToTensor(),
28
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
29
+ ])
30
+ image_tensor = transform(image).unsqueeze(0).to(device)
31
+
32
+ with torch.no_grad():
33
+ outputs = vgg16(image_tensor)
34
+ _, predicted = torch.max(outputs.data, 1)
35
+
36
+ class_names = ["Class1", "Class2", "Class3"] # Replace with your actual class names
37
+ predicted_label = class_names[predicted.item()]
38
+
39
+ plt.imshow(image)
40
+ plt.title(f'Predicted: {predicted_label}')
41
+ plt.axis('off')
42
+ plt.show()
43
+
44
+ return predicted_label
45
+
46
+ # Generation function for the Stable Diffusion model
47
  def generate_image(prompt):
48
  image = pipe(prompt).images[0]
49
  return image
50
 
51
+ # Gradio interface
52
+ with gr.Blocks() as demo:
53
+ gr.Markdown("## Room Style Recognition and Generation")
54
+
55
+ with gr.Tab("Recognize Room Style"):
56
+ image_input = gr.Image(type="pil")
57
+ label_output = gr.Textbox()
58
+ btn_predict = gr.Button("Predict Style")
59
+ btn_predict.click(predict_and_show, inputs=image_input, outputs=label_output)
60
+
61
+ with gr.Tab("Generate Room Style"):
62
+ text_input = gr.Textbox(placeholder="Enter a prompt for room style...")
63
+ image_output = gr.Image()
64
+ btn_generate = gr.Button("Generate Image")
65
+ btn_generate.click(generate_image, inputs=text_input, outputs=image_output)
66
 
67
+ demo.launch()