sk2003 commited on
Commit
0747848
1 Parent(s): 73142ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -11
app.py CHANGED
@@ -6,30 +6,33 @@ from PIL import Image
6
  from huggingface_hub import hf_hub_download
7
  import torch.nn as nn
8
 
 
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
- # Finetuned Resnet-50 model is downloaded
12
  vgg16_model_path = hf_hub_download(repo_id="sk2003/style_recognizer_vgg", filename="vgg16_model.pth")
13
 
 
14
  vgg16 = models.vgg16(pretrained=True)
15
  for param in vgg16.parameters():
16
- param.requires_grad = False # freezing parameters
17
 
 
18
  num_classes = 8
19
- vgg16.fc = nn.Linear(vgg16.fc.in_features, num_classes)
20
  vgg16 = vgg16.to(device)
21
 
22
- # Loading the model
23
  checkpoint = torch.load(vgg16_model_path, map_location=device)
24
  vgg16.load_state_dict(checkpoint['model_state_dict'])
25
- vgg16.eval() # setting to evaluation mode to disable batch-norm and dropout layers
26
 
27
- # Fine-tuned Stable Diffusion model
28
  model_id = "sk2003/room-styler"
29
  pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
30
  pipe.to(device)
31
 
32
- # Prediction function for the ResNet50 model
33
  def predict(image):
34
  transform = transforms.Compose([
35
  transforms.Resize((224, 224)),
@@ -39,7 +42,7 @@ def predict(image):
39
  image_tensor = transform(image).unsqueeze(0).to(device)
40
 
41
  with torch.no_grad():
42
- outputs = resnet50(image_tensor)
43
  _, predicted = torch.max(outputs.data, 1)
44
 
45
  classes = ["Classic", "Modern", "Vintage", "Glamour", "Scandinavian", "Rustic", "ArtDeco", "Industrial"]
@@ -54,9 +57,9 @@ def generate_image(prompt):
54
 
55
  # Gradio interface
56
  with gr.Blocks() as demo:
57
- gr.Markdown("## Room Style Recognition and Generation") # title
58
 
59
- # 1st tab
60
  with gr.Tab("Recognize Room Style"):
61
  image_input = gr.Image(type="pil")
62
  label_output = gr.Textbox()
@@ -71,4 +74,3 @@ with gr.Blocks() as demo:
71
  btn_generate.click(generate_image, inputs=text_input, outputs=image_output)
72
 
73
  demo.launch()
74
-
 
6
  from huggingface_hub import hf_hub_download
7
  import torch.nn as nn
8
 
9
+ # Set the device
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
+ # Download the fine-tuned VGG16 model
13
  vgg16_model_path = hf_hub_download(repo_id="sk2003/style_recognizer_vgg", filename="vgg16_model.pth")
14
 
15
+ # Load the VGG16 model with pre-trained weights
16
  vgg16 = models.vgg16(pretrained=True)
17
  for param in vgg16.parameters():
18
+ param.requires_grad = False # Freeze parameters
19
 
20
+ # Update the last fully connected layer to match the number of classes
21
  num_classes = 8
22
+ vgg16.classifier[6] = nn.Linear(vgg16.classifier[6].in_features, num_classes)
23
  vgg16 = vgg16.to(device)
24
 
25
+ # Load the fine-tuned model state
26
  checkpoint = torch.load(vgg16_model_path, map_location=device)
27
  vgg16.load_state_dict(checkpoint['model_state_dict'])
28
+ vgg16.eval() # Set the model to evaluation mode
29
 
30
+ # Load the fine-tuned Stable Diffusion model
31
  model_id = "sk2003/room-styler"
32
  pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
33
  pipe.to(device)
34
 
35
+ # Prediction function for the VGG16 model
36
  def predict(image):
37
  transform = transforms.Compose([
38
  transforms.Resize((224, 224)),
 
42
  image_tensor = transform(image).unsqueeze(0).to(device)
43
 
44
  with torch.no_grad():
45
+ outputs = vgg16(image_tensor)
46
  _, predicted = torch.max(outputs.data, 1)
47
 
48
  classes = ["Classic", "Modern", "Vintage", "Glamour", "Scandinavian", "Rustic", "ArtDeco", "Industrial"]
 
57
 
58
  # Gradio interface
59
  with gr.Blocks() as demo:
60
+ gr.Markdown("## Room Style Recognition and Generation") # Title
61
 
62
+ # 1st tab
63
  with gr.Tab("Recognize Room Style"):
64
  image_input = gr.Image(type="pil")
65
  label_output = gr.Textbox()
 
74
  btn_generate.click(generate_image, inputs=text_input, outputs=image_output)
75
 
76
  demo.launch()