"""Gradio demo that recommends Stable Diffusion models for an uploaded image.

The script loads the first 10,000 rows of the
``thefcraft/civitai-stable-diffusion-337k`` dataset, drops rows without a
``Model`` value, and builds a two-branch classifier (ResNet-18 for the image,
BERT for the prompt) whose output classes are the distinct ``Model`` values.

NOTE(review): nothing in this file trains the classifier or loads trained
weights, so the final linear layers are randomly initialised and the
recommendations are not meaningful until a checkpoint is loaded.
"""

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from datasets import load_dataset
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset
from torchvision import models
from transformers import BertModel, BertTokenizer

# Width of the embedding produced by each branch (image and text).
FEATURE_DIM = 512
# Number of recommendations returned by ``predict`` (clamped to the class count).
TOP_K = 5

# Single preprocessing pipeline shared by the dataset and by ``predict`` so
# inference sees images prepared exactly like the dataset items.
IMAGE_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Load dataset
raw_dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')

# Filter out entries with None or null Model values
filtered_dataset = raw_dataset.filter(lambda example: example['Model'] is not None)

# Preprocess text data
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


def _ensure_rgb(image):
    """Return *image* in RGB mode if it is a PIL image in another mode.

    ResNet's first convolution expects three channels, so RGBA or greyscale
    inputs would otherwise fail. Objects without a ``mode`` attribute are
    passed through untouched.
    """
    mode = getattr(image, "mode", None)
    if mode is not None and mode != "RGB":
        return image.convert("RGB")
    return image


class CustomDataset(Dataset):
    """Torch dataset yielding ``(image, text, label)`` triples.

    Args:
        dataset: Indexable collection of rows with ``'image'``, ``'prompt'``
            and ``'Model'`` fields; ``dataset['Model']`` must return the full
            label column.
            # NOTE(review): assumes 'image' holds PIL images - confirm
            # against the dataset schema.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.transform = IMAGE_TRANSFORM
        self.label_encoder = LabelEncoder()
        # Integer class id per row, in row order.
        self.labels = self.label_encoder.fit_transform(dataset['Model'])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Fetch the row once; indexing the dataset twice repeats the lookup.
        sample = self.dataset[idx]
        image = self.transform(_ensure_rgb(sample['image']))
        encoded = tokenizer(
            sample['prompt'] or "",  # a missing prompt would make the tokenizer raise
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # ``return_tensors='pt'`` adds a leading batch dimension of size 1.
        # Drop it so batching yields (batch, seq_len) rather than
        # (batch, 1, seq_len), which BERT cannot consume.
        text = {key: value.squeeze(0) for key, value in encoded.items()}
        label = self.labels[idx]
        return image, text, label


# Define CNN for image processing
class ImageModel(nn.Module):
    """ResNet-18 whose final layer is replaced by a ``FEATURE_DIM`` projection."""

    def __init__(self):
        super().__init__()
        # NOTE(review): ``pretrained=True`` is deprecated in newer torchvision
        # in favour of ``weights=``; kept for compatibility with older versions.
        self.model = models.resnet18(pretrained=True)
        self.model.fc = nn.Linear(self.model.fc.in_features, FEATURE_DIM)

    def forward(self, x):
        return self.model(x)


# Define MLP for text processing
class TextModel(nn.Module):
    """BERT encoder followed by a linear projection to ``FEATURE_DIM``."""

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(768, FEATURE_DIM)

    def forward(self, x):
        # ``x`` is a mapping of tokenizer outputs passed to BERT as keywords.
        output = self.bert(**x)
        return self.fc(output.pooler_output)


# Combined model
class CombinedModel(nn.Module):
    """Concatenates image and text features and classifies them.

    Args:
        num_classes: Size of the output layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.image_model = ImageModel()
        self.text_model = TextModel()
        # Input is the image and text embeddings concatenated along dim 1.
        self.fc = nn.Linear(2 * FEATURE_DIM, num_classes)

    def forward(self, image, text):
        image_features = self.image_model(image)
        text_features = self.text_model(text)
        combined = torch.cat((image_features, text_features), dim=1)
        return self.fc(combined)


def evaluate_model(model, test_loader, device):
    """Evaluate *model* on *test_loader* and report the results.

    Saves a confusion-matrix heatmap to ``confusion_matrix.png`` and prints a
    classification report.

    Args:
        model: Module called as ``model(images, texts)`` returning class scores.
        test_loader: Iterable of ``(images, texts, labels)`` batches where
            ``texts`` is a mapping of tensors.
        device: Device the inputs are moved to before the forward pass.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, texts, labels in test_loader:
            images = images.to(device)
            texts = {k: v.to(device) for k, v in texts.items()}

            outputs = model(images, texts)
            predicted = outputs.argmax(dim=1)

            all_preds.extend(predicted.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    # Generate confusion matrix
    cm = confusion_matrix(all_labels, all_preds)

    # Plot confusion matrix
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d')
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.savefig('confusion_matrix.png')
    plt.close()

    # Print classification report
    print(classification_report(all_labels, all_preds))


# Instantiate model
dataset = CustomDataset(filtered_dataset)
num_classes = len(np.unique(dataset.labels))
model = CombinedModel(num_classes)


# Define predict function
def predict(image, prompt="Sample prompt"):
    """Return the highest-scoring model names for *image*.

    Args:
        image: PIL image supplied by the Gradio ``Image`` input.
        prompt: Text fed to the text branch. Defaults to the placeholder
            prompt used when only an image is available.

    Returns:
        A list of up to ``TOP_K`` model names, best first.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    model.eval()
    # Run on whatever device the model currently lives on.
    device = next(model.parameters()).device

    with torch.no_grad():
        pixels = IMAGE_TRANSFORM(_ensure_rgb(image)).unsqueeze(0).to(device)
        encoded = tokenizer(
            prompt,
            return_tensors='pt',
            padding=True,
            truncation=True
        )
        text_input = {key: value.to(device) for key, value in encoded.items()}
        output = model(pixels, text_input)
        # topk raises if k exceeds the number of classes.
        k = min(TOP_K, output.size(1))
        _, indices = torch.topk(output, k)

    top_ids = indices[0].cpu().numpy()
    recommended_models = dataset.label_encoder.inverse_transform(top_ids).tolist()
    return recommended_models


# Set up Gradio interface
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Recommended Models"),
    title="AI Image Model Recommender",
    description="Upload an AI-generated image to receive model recommendations."
)

if __name__ == "__main__":
    # Launch the app
    interface.launch()