# Source: Arc / arc (uploaded by sedefiizm)
# Commit: "Create arc" (ca38051, verified) — original file size 3.9 kB
import os
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Concatenate, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import matplotlib.pyplot as plt
# Dataset preparation
def load_images_and_texts(image_dir, text_data, img_size=(64, 64), extension='.png'):
    """Load images and their matching text embeddings listed in a DataFrame.

    For every row of ``text_data`` the image ``<image_dir>/<File_Name><extension>``
    is looked up. Rows whose image file does not exist are skipped, so the two
    returned arrays stay aligned row-for-row.

    Args:
        image_dir: directory that contains the image files.
        text_data: DataFrame with a ``'File_Name'`` column (file name without
            extension) and a ``'BERT_Embeddings'`` column.
        img_size: ``(height, width)`` every image is resized to when loaded.
        extension: file extension appended to each ``File_Name``.

    Returns:
        ``(images, texts)`` as NumPy arrays. Image pixel values are divided by
        255.0; texts are the ``BERT_Embeddings`` values stacked with
        ``np.array``. Both arrays are empty when no image file is found.
    """
    images, texts = [], []
    for _, row in text_data.iterrows():
        # Format via f-string so numeric file-name IDs do not raise TypeError
        # (``int + str``) when building the path.
        img_path = os.path.join(image_dir, f"{row['File_Name']}{extension}")
        if not os.path.exists(img_path):
            # Missing images are skipped on purpose (best-effort loading).
            continue
        img = load_img(img_path, target_size=img_size)
        images.append(img_to_array(img) / 255.0)
        texts.append(row['BERT_Embeddings'])
    return np.array(images), np.array(texts)
# CNN model
def build_cnn_model(image_shape, text_dim):
    """Build a two-input binary classifier over an image and a text embedding.

    Image branch: two (Conv2D 3x3 + ReLU, 'same' padding) -> MaxPooling2D 2x2
    stages with 32 and 64 filters, then Flatten. Text branch: one Dense(256)
    ReLU layer. The two branches are concatenated, passed through Dense(128)
    ReLU, and end in a single sigmoid unit.

    Args:
        image_shape: shape of one image input, e.g. ``(64, 64, 3)``.
        text_dim: length of the text-embedding vector.

    Returns:
        An uncompiled Keras ``Model`` named "CNN_Model" taking
        ``[image, text]`` inputs.
    """
    # The text input is created first, as before, so auto-generated layer
    # names come out in the same order.
    text_in = Input(shape=(text_dim,))
    image_in = Input(shape=image_shape)

    # Image branch: conv + pool stages with increasing filter counts.
    features = image_in
    for n_filters in (32, 64):
        features = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(features)
        features = MaxPooling2D((2, 2))(features)
    image_features = Flatten()(features)

    # Text branch.
    text_features = Dense(256, activation='relu')(text_in)

    # Fusion head.
    merged = Concatenate()([image_features, text_features])
    hidden = Dense(128, activation='relu')(merged)
    output = Dense(1, activation='sigmoid')(hidden)  # binary classification
    return Model([image_in, text_in], output, name="CNN_Model")
# Parameters
# NOTE: `train` performs one mini-batch step per "epoch", so `epochs` is the
# number of training steps, not the number of full passes over the dataset.
epochs = 1000
batch_size = 32
image_shape = (64, 64, 3)
text_dim = 768  # BERT embedding size

# Load the text-description embeddings.
# NOTE(review): pd.read_pickle can execute arbitrary code while unpickling;
# only load pickle files from a trusted source.
pkl_path = '/content/drive/Othercomputers/Dizüstü Bilgisayarım/Desktop/word_embeddings_dataframe.pkl'
data = pd.read_pickle(pkl_path)

# Load the images together with their text embeddings.
image_dir = '/content/drive/Othercomputers/Dizüstü Bilgisayarım/Desktop/human_annotated_images'
# Derive the resize target from image_shape so the two cannot drift apart.
images, texts = load_images_and_texts(image_dir, data, img_size=image_shape[:2])
if images.size == 0:
    raise ValueError(
        f"No images matching the 'File_Name' column were found in {image_dir!r}"
    )

# Flatten each embedding to one vector per sample, e.g. (N, 1, 768) -> (N, 768).
# Unlike np.squeeze(axis=1), this also accepts embeddings already stored as
# (N, 768).
texts = texts.reshape(len(texts), -1)
if texts.shape[1] != text_dim:
    raise ValueError(
        f"Expected text embeddings of size {text_dim}, got {texts.shape[1]}"
    )

# Build the CNN model.
cnn_model = build_cnn_model(image_shape, text_dim)
# Compile. beta_1=0.5 (lower than the Keras default of 0.9) is passed by
# keyword so the setting is explicit.
cnn_model.compile(
    optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
# Training loop
def train(epochs, batch_size):
    """Train the module-level ``cnn_model`` on randomly sampled mini-batches.

    Each "epoch" is a single ``train_on_batch`` step on ``batch_size`` samples
    drawn with replacement from the module-level ``images`` / ``texts`` arrays.
    Progress is printed every 10 steps. The model is saved to
    ``cnn_model_epoch_<step>.h5`` every 100 steps and after the final step.

    Args:
        epochs: number of training steps to run.
        batch_size: number of samples per step.

    Raises:
        ValueError: if no training images are loaded.
    """
    num_samples = images.shape[0]
    if num_samples == 0:
        raise ValueError("No training images are loaded; cannot train.")

    # NOTE(review): every label is 1, so the model only ever sees a single class
    # and can minimise the loss by always predicting 1. Negative examples
    # (e.g. mismatched image/text pairs) look necessary for a meaningful
    # classifier; confirm the intended training objective.
    labels = np.ones((batch_size, 1))  # loop-invariant, built once

    for epoch in range(epochs):
        # Sample a random batch of real image/text pairs.
        idx = np.random.randint(0, num_samples, batch_size)
        loss, accuracy = cnn_model.train_on_batch([images[idx], texts[idx]], labels)

        if epoch % 10 == 0:
            print(f"Epoch {epoch}/{epochs} | Loss: {loss} | Accuracy: {accuracy}")
        if epoch % 100 == 0:
            cnn_model.save(f'cnn_model_epoch_{epoch}.h5')

    # Also checkpoint the final state, which the periodic save above misses
    # unless the last step index happens to be a multiple of 100.
    last = epochs - 1
    if epochs > 0 and last % 100 != 0:
        cnn_model.save(f'cnn_model_epoch_{last}.h5')
# Train the model.
train(epochs=epochs, batch_size=batch_size)
# Save sample predictions
def generate_and_save_samples(cnn_model, num_samples=5):
    """Save random training images, each titled with the model's prediction.

    Despite the name, nothing is generated. ``num_samples`` image/text pairs
    are sampled (with replacement) from the module-level ``images`` / ``texts``
    arrays and scored with ``cnn_model.predict``. Each image is then written to
    ``sample_image_<i>.png`` in the current working directory.

    Args:
        cnn_model: Keras model taking ``[image, text]`` inputs.
        num_samples: number of samples to draw and save; nothing is done
            when it is not positive.
    """
    if num_samples <= 0:
        return
    idx = np.random.randint(0, images.shape[0], num_samples)
    sample_images = images[idx]
    predictions = cnn_model.predict([sample_images, texts[idx]])
    for i, img in enumerate(sample_images):
        # Use a fresh figure per sample and close it after saving. Drawing
        # every sample into one shared implicit figure stacks the images on
        # the same axes and never releases the figure.
        fig, ax = plt.subplots()
        ax.imshow(img)
        ax.axis('off')
        ax.set_title(f"Prediction: {predictions[i]}")
        fig.savefig(f"sample_image_{i}.png")
        plt.close(fig)
# Save prediction samples for a few training images.
generate_and_save_samples(cnn_model, num_samples=5)