bgaspra commited on
Commit
4ab8df2
·
verified ·
1 Parent(s): 07d348b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -9
app.py CHANGED
@@ -8,11 +8,16 @@ import pandas as pd
8
  from datasets import load_dataset
9
  from torch.utils.data import DataLoader, Dataset
10
  from sklearn.preprocessing import LabelEncoder
 
 
 
 
11
 
12
- # Load dataset and filter out null/none values
13
  dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
14
- # Filter out entries where Model is None or empty
15
- dataset = dataset.filter(lambda example: example['Model'] is not None and example['Model'].strip() != '')
 
16
 
17
  # Preprocess text data
18
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
@@ -64,11 +69,11 @@ class TextModel(nn.Module):
64
 
65
  # Combined model
66
  class CombinedModel(nn.Module):
67
- def __init__(self):
68
  super(CombinedModel, self).__init__()
69
  self.image_model = ImageModel()
70
  self.text_model = TextModel()
71
- self.fc = nn.Linear(1024, len(dataset['Model']))
72
 
73
  def forward(self, image, text):
74
  image_features = self.image_model(image)
@@ -76,8 +81,42 @@ class CombinedModel(nn.Module):
76
  combined = torch.cat((image_features, text_features), dim=1)
77
  return self.fc(combined)
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  # Instantiate model
80
- model = CombinedModel()
 
 
81
 
82
  # Define predict function
83
  def predict(image):
@@ -93,7 +132,7 @@ def predict(image):
93
  )
94
  output = model(image, text_input)
95
  _, indices = torch.topk(output, 5)
96
- recommended_models = [dataset['Model'][i] for i in indices[0]]
97
  return recommended_models
98
 
99
  # Set up Gradio interface
@@ -105,5 +144,6 @@ interface = gr.Interface(
105
  description="Upload an AI-generated image to receive model recommendations."
106
  )
107
 
108
- # Launch the app
109
- interface.launch()
 
 
8
  from datasets import load_dataset
9
  from torch.utils.data import DataLoader, Dataset
10
  from sklearn.preprocessing import LabelEncoder
11
+ from sklearn.metrics import confusion_matrix, classification_report
12
+ import seaborn as sns
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
15
 
16
+ # Load dataset
17
  dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
18
+
19
+ # Filter out entries with None or null Model values
20
+ filtered_dataset = dataset.filter(lambda example: example['Model'] is not None)
21
 
22
  # Preprocess text data
23
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
 
69
 
70
  # Combined model
71
  class CombinedModel(nn.Module):
72
+ def __init__(self, num_classes):
73
  super(CombinedModel, self).__init__()
74
  self.image_model = ImageModel()
75
  self.text_model = TextModel()
76
+ self.fc = nn.Linear(1024, num_classes)
77
 
78
  def forward(self, image, text):
79
  image_features = self.image_model(image)
 
81
  combined = torch.cat((image_features, text_features), dim=1)
82
  return self.fc(combined)
83
 
84
+ def evaluate_model(model, test_loader, device):
85
+ model.eval()
86
+ all_preds = []
87
+ all_labels = []
88
+
89
+ with torch.no_grad():
90
+ for images, texts, labels in test_loader:
91
+ images = images.to(device)
92
+ texts = {k: v.to(device) for k, v in texts.items()}
93
+ labels = labels.to(device)
94
+
95
+ outputs = model(images, texts)
96
+ _, predicted = torch.max(outputs.data, 1)
97
+
98
+ all_preds.extend(predicted.cpu().numpy())
99
+ all_labels.extend(labels.cpu().numpy())
100
+
101
+ # Generate confusion matrix
102
+ cm = confusion_matrix(all_labels, all_preds)
103
+
104
+ # Plot confusion matrix
105
+ plt.figure(figsize=(10, 8))
106
+ sns.heatmap(cm, annot=True, fmt='d')
107
+ plt.title('Confusion Matrix')
108
+ plt.ylabel('True Label')
109
+ plt.xlabel('Predicted Label')
110
+ plt.savefig('confusion_matrix.png')
111
+ plt.close()
112
+
113
+ # Print classification report
114
+ print(classification_report(all_labels, all_preds))
115
+
116
  # Instantiate model
117
+ dataset = CustomDataset(filtered_dataset)
118
+ num_classes = len(np.unique(dataset.labels))
119
+ model = CombinedModel(num_classes)
120
 
121
  # Define predict function
122
  def predict(image):
 
132
  )
133
  output = model(image, text_input)
134
  _, indices = torch.topk(output, 5)
135
+ recommended_models = [dataset.label_encoder.inverse_transform([i])[0] for i in indices[0]]
136
  return recommended_models
137
 
138
  # Set up Gradio interface
 
144
  description="Upload an AI-generated image to receive model recommendations."
145
  )
146
 
147
+ if __name__ == "__main__":
148
+ # Launch the app
149
+ interface.launch()