Update app.py
Browse files
app.py
CHANGED
|
@@ -8,11 +8,16 @@ import pandas as pd
|
|
| 8 |
from datasets import load_dataset
|
| 9 |
from torch.utils.data import DataLoader, Dataset
|
| 10 |
from sklearn.preprocessing import LabelEncoder
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
# Load dataset
|
| 13 |
dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
|
| 14 |
-
|
| 15 |
-
|
|
|
|
| 16 |
|
| 17 |
# Preprocess text data
|
| 18 |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
|
@@ -64,11 +69,11 @@ class TextModel(nn.Module):
|
|
| 64 |
|
| 65 |
# Combined model
|
| 66 |
class CombinedModel(nn.Module):
|
| 67 |
-
def __init__(self):
|
| 68 |
super(CombinedModel, self).__init__()
|
| 69 |
self.image_model = ImageModel()
|
| 70 |
self.text_model = TextModel()
|
| 71 |
-
self.fc = nn.Linear(1024,
|
| 72 |
|
| 73 |
def forward(self, image, text):
|
| 74 |
image_features = self.image_model(image)
|
|
@@ -76,8 +81,42 @@ class CombinedModel(nn.Module):
|
|
| 76 |
combined = torch.cat((image_features, text_features), dim=1)
|
| 77 |
return self.fc(combined)
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
# Instantiate model
|
| 80 |
-
|
|
|
|
|
|
|
| 81 |
|
| 82 |
# Define predict function
|
| 83 |
def predict(image):
|
|
@@ -93,7 +132,7 @@ def predict(image):
|
|
| 93 |
)
|
| 94 |
output = model(image, text_input)
|
| 95 |
_, indices = torch.topk(output, 5)
|
| 96 |
-
recommended_models = [dataset[
|
| 97 |
return recommended_models
|
| 98 |
|
| 99 |
# Set up Gradio interface
|
|
@@ -105,5 +144,6 @@ interface = gr.Interface(
|
|
| 105 |
description="Upload an AI-generated image to receive model recommendations."
|
| 106 |
)
|
| 107 |
|
| 108 |
-
|
| 109 |
-
|
|
|
|
|
|
| 8 |
from datasets import load_dataset
|
| 9 |
from torch.utils.data import DataLoader, Dataset
|
| 10 |
from sklearn.preprocessing import LabelEncoder
|
| 11 |
+
from sklearn.metrics import confusion_matrix, classification_report
|
| 12 |
+
import seaborn as sns
|
| 13 |
+
import matplotlib.pyplot as plt
|
| 14 |
+
import numpy as np
|
| 15 |
|
| 16 |
+
# Load dataset
|
| 17 |
dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
|
| 18 |
+
|
| 19 |
+
# Filter out entries with None or null Model values
|
| 20 |
+
filtered_dataset = dataset.filter(lambda example: example['Model'] is not None)
|
| 21 |
|
| 22 |
# Preprocess text data
|
| 23 |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
|
|
|
| 69 |
|
| 70 |
# Combined model
|
| 71 |
class CombinedModel(nn.Module):
|
| 72 |
+
def __init__(self, num_classes):
|
| 73 |
super(CombinedModel, self).__init__()
|
| 74 |
self.image_model = ImageModel()
|
| 75 |
self.text_model = TextModel()
|
| 76 |
+
self.fc = nn.Linear(1024, num_classes)
|
| 77 |
|
| 78 |
def forward(self, image, text):
|
| 79 |
image_features = self.image_model(image)
|
|
|
|
| 81 |
combined = torch.cat((image_features, text_features), dim=1)
|
| 82 |
return self.fc(combined)
|
| 83 |
|
| 84 |
+
def _save_confusion_matrix_plot(matrix):
    """Render *matrix* as an annotated heatmap and write it to 'confusion_matrix.png'."""
    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, annot=True, fmt='d')
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.savefig('confusion_matrix.png')
    # Close the figure so it does not stay registered with pyplot.
    plt.close()


def evaluate_model(model, test_loader, device):
    """Run *model* over *test_loader* and report its classification quality.

    Each batch from ``test_loader`` is unpacked as ``(images, texts, labels)``;
    ``texts`` is iterated with ``.items()`` and every value is moved to
    *device*, as are ``images`` and ``labels``. The predicted class of a sample
    is the index of the largest model output along dimension 1.

    Side effects: switches the model to eval mode, saves a confusion-matrix
    heatmap to ``confusion_matrix.png`` (relative path) and prints the
    scikit-learn classification report. Nothing is returned.
    """
    model.eval()
    predictions, targets = [], []

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for batch_images, batch_texts, batch_labels in test_loader:
            batch_images = batch_images.to(device)
            batch_texts = {
                name: tensor.to(device) for name, tensor in batch_texts.items()
            }
            batch_labels = batch_labels.to(device)

            logits = model(batch_images, batch_texts)
            # torch.max over a dim returns (values, indices); keep the indices.
            top_class = torch.max(logits.data, 1)[1]

            predictions.extend(top_class.cpu().numpy())
            targets.extend(batch_labels.cpu().numpy())

    _save_confusion_matrix_plot(confusion_matrix(targets, predictions))

    print(classification_report(targets, predictions))
|
| 115 |
+
|
| 116 |
# Instantiate model
|
| 117 |
+
# NOTE: this rebinds `dataset` (previously the raw Hugging Face split) to the
# wrapped dataset; `predict` reads `dataset.label_encoder` from this object.
dataset = CustomDataset(filtered_dataset)
# The classifier head gets one output per distinct label in the wrapped data.
num_classes = np.unique(dataset.labels).size
model = CombinedModel(num_classes)
|
| 120 |
|
| 121 |
# Define predict function
|
| 122 |
def predict(image):
|
|
|
|
| 132 |
)
|
| 133 |
output = model(image, text_input)
|
| 134 |
_, indices = torch.topk(output, 5)
|
| 135 |
+
recommended_models = [dataset.label_encoder.inverse_transform([i])[0] for i in indices[0]]
|
| 136 |
return recommended_models
|
| 137 |
|
| 138 |
# Set up Gradio interface
|
|
|
|
| 144 |
description="Upload an AI-generated image to receive model recommendations."
|
| 145 |
)
|
| 146 |
|
| 147 |
+
if __name__ == "__main__":
|
| 148 |
+
# Launch the Gradio interface only when this file is run as a script, not when it is imported.
|
| 149 |
+
interface.launch()
|