Update app.py
Browse files
app.py
CHANGED
@@ -8,11 +8,16 @@ import pandas as pd
|
|
8 |
from datasets import load_dataset
|
9 |
from torch.utils.data import DataLoader, Dataset
|
10 |
from sklearn.preprocessing import LabelEncoder
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
# Load dataset
|
13 |
dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
|
14 |
-
|
15 |
-
|
|
|
16 |
|
17 |
# Preprocess text data
|
18 |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
@@ -64,11 +69,11 @@ class TextModel(nn.Module):
|
|
64 |
|
65 |
# Combined model
|
66 |
class CombinedModel(nn.Module):
|
67 |
-
def __init__(self):
|
68 |
super(CombinedModel, self).__init__()
|
69 |
self.image_model = ImageModel()
|
70 |
self.text_model = TextModel()
|
71 |
-
self.fc = nn.Linear(1024,
|
72 |
|
73 |
def forward(self, image, text):
|
74 |
image_features = self.image_model(image)
|
@@ -76,8 +81,42 @@ class CombinedModel(nn.Module):
|
|
76 |
combined = torch.cat((image_features, text_features), dim=1)
|
77 |
return self.fc(combined)
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
# Instantiate model
|
80 |
-
|
|
|
|
|
81 |
|
82 |
# Define predict function
|
83 |
def predict(image):
|
@@ -93,7 +132,7 @@ def predict(image):
|
|
93 |
)
|
94 |
output = model(image, text_input)
|
95 |
_, indices = torch.topk(output, 5)
|
96 |
-
recommended_models = [dataset[
|
97 |
return recommended_models
|
98 |
|
99 |
# Set up Gradio interface
|
@@ -105,5 +144,6 @@ interface = gr.Interface(
|
|
105 |
description="Upload an AI-generated image to receive model recommendations."
|
106 |
)
|
107 |
|
108 |
-
|
109 |
-
|
|
|
|
8 |
from datasets import load_dataset
|
9 |
from torch.utils.data import DataLoader, Dataset
|
10 |
from sklearn.preprocessing import LabelEncoder
|
11 |
+
from sklearn.metrics import confusion_matrix, classification_report
|
12 |
+
import seaborn as sns
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
import numpy as np
|
15 |
|
16 |
+
# Load dataset
|
17 |
dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
|
18 |
+
|
19 |
+
# Filter out entries with None or null Model values
|
20 |
+
filtered_dataset = dataset.filter(lambda example: example['Model'] is not None)
|
21 |
|
22 |
# Preprocess text data
|
23 |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
|
|
69 |
|
70 |
# Combined model
|
71 |
class CombinedModel(nn.Module):
|
72 |
+
def __init__(self, num_classes):
|
73 |
super(CombinedModel, self).__init__()
|
74 |
self.image_model = ImageModel()
|
75 |
self.text_model = TextModel()
|
76 |
+
self.fc = nn.Linear(1024, num_classes)
|
77 |
|
78 |
def forward(self, image, text):
|
79 |
image_features = self.image_model(image)
|
|
|
81 |
combined = torch.cat((image_features, text_features), dim=1)
|
82 |
return self.fc(combined)
|
83 |
|
84 |
+
def evaluate_model(model, test_loader, device):
|
85 |
+
model.eval()
|
86 |
+
all_preds = []
|
87 |
+
all_labels = []
|
88 |
+
|
89 |
+
with torch.no_grad():
|
90 |
+
for images, texts, labels in test_loader:
|
91 |
+
images = images.to(device)
|
92 |
+
texts = {k: v.to(device) for k, v in texts.items()}
|
93 |
+
labels = labels.to(device)
|
94 |
+
|
95 |
+
outputs = model(images, texts)
|
96 |
+
_, predicted = torch.max(outputs.data, 1)
|
97 |
+
|
98 |
+
all_preds.extend(predicted.cpu().numpy())
|
99 |
+
all_labels.extend(labels.cpu().numpy())
|
100 |
+
|
101 |
+
# Generate confusion matrix
|
102 |
+
cm = confusion_matrix(all_labels, all_preds)
|
103 |
+
|
104 |
+
# Plot confusion matrix
|
105 |
+
plt.figure(figsize=(10, 8))
|
106 |
+
sns.heatmap(cm, annot=True, fmt='d')
|
107 |
+
plt.title('Confusion Matrix')
|
108 |
+
plt.ylabel('True Label')
|
109 |
+
plt.xlabel('Predicted Label')
|
110 |
+
plt.savefig('confusion_matrix.png')
|
111 |
+
plt.close()
|
112 |
+
|
113 |
+
# Print classification report
|
114 |
+
print(classification_report(all_labels, all_preds))
|
115 |
+
|
116 |
# Instantiate model
|
117 |
+
dataset = CustomDataset(filtered_dataset)
|
118 |
+
num_classes = len(np.unique(dataset.labels))
|
119 |
+
model = CombinedModel(num_classes)
|
120 |
|
121 |
# Define predict function
|
122 |
def predict(image):
|
|
|
132 |
)
|
133 |
output = model(image, text_input)
|
134 |
_, indices = torch.topk(output, 5)
|
135 |
+
recommended_models = [dataset.label_encoder.inverse_transform([i])[0] for i in indices[0]]
|
136 |
return recommended_models
|
137 |
|
138 |
# Set up Gradio interface
|
|
|
144 |
description="Upload an AI-generated image to receive model recommendations."
|
145 |
)
|
146 |
|
147 |
+
if __name__ == "__main__":
|
148 |
+
# Launch the app
|
149 |
+
interface.launch()
|