|
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from datasets import load_dataset
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset
from torchvision import models
from transformers import BertTokenizer, BertModel
|
|
|
|
|
# Load only the first 10,000 rows of the training split.
dataset = load_dataset(
    'thefcraft/civitai-stable-diffusion-337k',
    split='train[:10000]',
)

# Keep only rows that carry a 'Model' value.
filtered_dataset = dataset.filter(lambda row: row['Model'] is not None)

# Module-level tokenizer shared by the dataset class and the predict function.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
|
|
class CustomDataset(Dataset):
    """Torch dataset yielding (image tensor, tokenized prompt, encoded label).

    Args:
        dataset: Indexable collection of rows with 'image', 'prompt' and
            'Model' fields. ``dataset['Model']`` must return the full label
            column, and ``dataset[idx]`` a single row.

    Attributes:
        label_encoder: ``LabelEncoder`` fitted on the 'Model' column.
        labels: Integer class ids produced by the encoder, one per row.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        self.label_encoder = LabelEncoder()
        self.labels = self.label_encoder.fit_transform(dataset['Model'])

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, text, label)`` for row ``idx``.

        ``text`` is a dict of 1-D tensors, so the default DataLoader collate
        stacks them into ``(batch, seq_len)`` tensors.
        """
        # Look the row up once instead of once per field.
        row = self.dataset[idx]

        image = row['image']
        # Force three channels so every sample matches the ResNet input and
        # samples can be stacked into one batch.
        # NOTE(review): assumes rows hold PIL images -- confirm.
        if hasattr(image, 'convert'):
            image = image.convert('RGB')
        image = self.transform(image)

        encoded = tokenizer(
            row['prompt'],
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # return_tensors='pt' adds a leading dimension of size 1; without
        # removing it a batch would be (batch, 1, seq_len), which BERT rejects.
        text = {key: value.squeeze(0) for key, value in encoded.items()}

        label = self.labels[idx]
        return image, text, label
|
|
|
|
|
class ImageModel(nn.Module):
    """Pretrained ResNet-18 whose final layer is replaced by a 512-unit linear layer."""

    def __init__(self):
        super().__init__()
        # NOTE: newer torchvision releases prefer the `weights=` argument
        # over `pretrained=`.
        backbone = models.resnet18(pretrained=True)
        # Replace the final fully connected layer with a 512-wide projection.
        backbone.fc = nn.Linear(backbone.fc.in_features, 512)
        self.model = backbone

    def forward(self, x):
        """Return the backbone output for the image batch ``x``."""
        return self.model(x)
|
|
|
|
|
class TextModel(nn.Module):
    """BERT encoder followed by a linear projection from 768 to 512 features."""

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(768, 512)

    def forward(self, x):
        """Project BERT's pooled output for the tokenized batch ``x``.

        ``x`` is a mapping that is unpacked as keyword arguments to BERT.
        """
        pooled = self.bert(**x).pooler_output
        return self.fc(pooled)
|
|
|
|
|
class CombinedModel(nn.Module):
    """Classifier over concatenated image and text features.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.image_model = ImageModel()
        self.text_model = TextModel()
        # 1024 = 512 image features + 512 text features.
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, image, text):
        """Return class logits for a batch of images and tokenized prompts."""
        branches = (self.image_model(image), self.text_model(text))
        fused = torch.cat(branches, dim=1)
        return self.fc(fused)
|
|
|
def evaluate_model(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and report the results.

    Runs inference without gradients, saves a confusion-matrix heatmap to
    'confusion_matrix.png' and prints a classification report.

    Args:
        model: Module called as ``model(images, texts)`` that returns logits.
        test_loader: Iterable of ``(images, texts, labels)`` batches, where
            ``texts`` is a mapping of tensors.
        device: Device the batches are moved to before the forward pass.
    """
    model.eval()
    predictions, targets = [], []

    with torch.no_grad():
        for batch_images, batch_texts, batch_labels in test_loader:
            batch_images = batch_images.to(device)
            batch_texts = {
                name: tensor.to(device) for name, tensor in batch_texts.items()
            }
            batch_labels = batch_labels.to(device)

            logits = model(batch_images, batch_texts)
            # Index of the largest logit along the class dimension.
            batch_preds = logits.data.max(1)[1]

            predictions.extend(batch_preds.cpu().numpy())
            targets.extend(batch_labels.cpu().numpy())

    matrix = confusion_matrix(targets, predictions)

    # Draw the confusion matrix and write it to disk.
    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, annot=True, fmt='d')
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.savefig('confusion_matrix.png')
    plt.close()

    print(classification_report(targets, predictions))
|
|
|
|
|
# Rebinds `dataset` from the raw rows to the torch dataset wrapper.
dataset = CustomDataset(filtered_dataset)

# One output logit per distinct encoded label.
num_classes = np.unique(dataset.labels).size

model = CombinedModel(num_classes)
|
|
|
|
|
def predict(image):
    """Return up to five model names ranked highest for an uploaded image.

    Args:
        image: PIL image from the Gradio image input, or None when nothing
            was uploaded.

    Returns:
        List of model-name strings, highest score first. The list is empty
        when ``image`` is None.
    """
    # Gradio passes None when the user submits without uploading an image.
    if image is None:
        return []

    model.eval()
    # Run on whatever device the model currently lives on.
    device = next(model.parameters()).device

    with torch.no_grad():
        # Reuse the dataset's transform so inference preprocessing matches
        # CustomDataset. RGB conversion guards against RGBA or grayscale
        # uploads, which the 3-channel ResNet cannot consume.
        pixels = dataset.transform(image.convert('RGB')).unsqueeze(0).to(device)

        # The UI collects no prompt, so a fixed placeholder text is fed to
        # the text branch.
        encoded = tokenizer(
            "Sample prompt",
            return_tensors='pt',
            padding=True,
            truncation=True
        )
        text_input = {key: value.to(device) for key, value in encoded.items()}

        output = model(pixels, text_input)
        # topk raises when k exceeds the number of classes.
        k = min(5, output.shape[1])
        _, indices = torch.topk(output, k)

    # Decode all ids in one call, from a plain integer array.
    names = dataset.label_encoder.inverse_transform(indices[0].cpu().numpy())
    return [str(name) for name in names]
|
|
|
|
|
# Web UI: a single image input, with predict's result shown in a text box.
interface = gr.Interface(
    title="AI Image Model Recommender",
    description="Upload an AI-generated image to receive model recommendations.",
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Recommended Models"),
)


if __name__ == "__main__":
    # Start the UI only when this file is run as a script.
    interface.launch()