sk2003 commited on
Commit
d1eee78
1 Parent(s): 4cb93ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -17
app.py CHANGED
@@ -1,23 +1,32 @@
1
  import gradio as gr
2
  from diffusers import StableDiffusionPipeline
3
  import torch
4
- from torchvision import transforms
5
  from PIL import Image
6
- import matplotlib.pyplot as plt
7
  from huggingface_hub import hf_hub_download
 
 
8
 
9
- # Fine-tuned Stable Diffusion model from your Hugging Face repository
10
- model_id = "sk2003/room-styler"
11
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
12
-
13
- # VGG16 model
14
  vgg16_model_path = hf_hub_download(repo_id="sk2003/style_recognizer_vgg", filename="vgg16_model.pth")
15
- vgg16 = torch.load(vgg16_model_path)
 
 
 
 
 
 
 
 
 
 
 
16
  vgg16.eval()
17
 
18
- # Device
19
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
- vgg16.to(device)
21
  pipe.to(device)
22
 
23
  # Prediction function for the VGG16 model
@@ -33,14 +42,9 @@ def predict_and_show(image):
33
  outputs = vgg16(image_tensor)
34
  _, predicted = torch.max(outputs.data, 1)
35
 
36
- class_names = ["Classic", "Modern", "Vintage", "Glamour", "Scandinavian", "Rustic", "ArtDeco", "Industrial"]
37
  predicted_label = class_names[predicted.item()]
38
 
39
- plt.imshow(image)
40
- plt.title(f'Predicted: {predicted_label}')
41
- plt.axis('off')
42
- plt.show()
43
-
44
  return predicted_label
45
 
46
  # Generation function for the Stable Diffusion model
 
1
  import gradio as gr
2
  from diffusers import StableDiffusionPipeline
3
  import torch
4
+ from torchvision import models, transforms
5
  from PIL import Image
 
6
  from huggingface_hub import hf_hub_download
7
+ import torch.nn as nn
8
+ import torch.optim as optim
9
 
10
+ # LoadING the VGG16 model
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
12
  vgg16_model_path = hf_hub_download(repo_id="sk2003/style_recognizer_vgg", filename="vgg16_model.pth")
13
+
14
+ vgg16 = models.vgg16(pretrained=True)
15
+ for param in vgg16.parameters():
16
+ param.requires_grad = False
17
+
18
+ num_classes = 8
19
+ vgg16.classifier[6] = nn.Linear(vgg16.classifier[6].in_features, num_classes)
20
+ vgg16 = vgg16.to(device)
21
+
22
+ # Loading the saved state dict
23
+ checkpoint = torch.load(vgg16_model_path, map_location=device)
24
+ vgg16.load_state_dict(checkpoint['model_state_dict'])
25
  vgg16.eval()
26
 
27
+ # Fine-tuned Stable Diffusion model from your Hugging Face repository
28
+ model_id = "sk2003/room-styler"
29
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
30
  pipe.to(device)
31
 
32
  # Prediction function for the VGG16 model
 
42
  outputs = vgg16(image_tensor)
43
  _, predicted = torch.max(outputs.data, 1)
44
 
45
+ class_names = ["Classic", "Modern", "Vintage", "Glamour", "Scandinavian", "Rustic", "ArtDeco", "Industrial"]
46
  predicted_label = class_names[predicted.item()]
47
 
 
 
 
 
 
48
  return predicted_label
49
 
50
  # Generation function for the Stable Diffusion model