sedefiizm commited on
Commit
ca38051
1 Parent(s): 034cf61

Create arc

Browse files
Files changed (1) hide show
  1. arc +109 -0
arc ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Concatenate, Input
6
+ from tensorflow.keras.models import Model
7
+ from tensorflow.keras.optimizers import Adam
8
+ from tensorflow.keras.preprocessing.image import load_img, img_to_array
9
+ import matplotlib.pyplot as plt
10
+
11
+ # Veri seti hazırlığı
12
+ def load_images_and_texts(image_dir, text_data, img_size=(64, 64)):
13
+ """Görselleri ve metin açıklamalarını yükler."""
14
+ images, texts = [], []
15
+ for idx, row in text_data.iterrows():
16
+ img_path = os.path.join(image_dir, row['File_Name'] + '.png')
17
+ if os.path.exists(img_path):
18
+ img = load_img(img_path, target_size=img_size)
19
+ img_array = img_to_array(img) / 255.0
20
+ images.append(img_array)
21
+ texts.append(row['BERT_Embeddings'])
22
+ return np.array(images), np.array(texts)
23
+
24
+ # CNN Modeli
25
+ def build_cnn_model(image_shape, text_dim):
26
+ """CNN modeli: Görsel ve metin açıklamalarını birleştirerek sınıflandırma yapar."""
27
+ text_input = Input(shape=(text_dim,))
28
+ img_input = Input(shape=image_shape)
29
+
30
+ # Görsel kısmı
31
+ x_img = Conv2D(32, (3, 3), activation='relu', padding='same')(img_input)
32
+ x_img = MaxPooling2D((2, 2))(x_img)
33
+ x_img = Conv2D(64, (3, 3), activation='relu', padding='same')(x_img)
34
+ x_img = MaxPooling2D((2, 2))(x_img)
35
+ x_img = Flatten()(x_img)
36
+
37
+ # Metin kısmı
38
+ x_text = Dense(256, activation='relu')(text_input)
39
+
40
+ # Görsel ve metin birleşimi
41
+ x = Concatenate()([x_img, x_text])
42
+ x = Dense(128, activation='relu')(x)
43
+ x = Dense(1, activation='sigmoid')(x) # Binary classification
44
+
45
+ model = Model([img_input, text_input], x, name="CNN_Model")
46
+ return model
47
+
48
+ # Parametreler
49
+ epochs = 1000 # 1000 epoch
50
+ batch_size = 32
51
+ image_shape = (64, 64, 3)
52
+ text_dim = 768 # BERT embedding boyutu
53
+
54
+ # Metin açıklamalarını yükleme
55
+ pkl_path = '/content/drive/Othercomputers/Dizüstü Bilgisayarım/Desktop/word_embeddings_dataframe.pkl'
56
+ data = pd.read_pickle(pkl_path)
57
+
58
+ # Görseller ve metin açıklamalarını yükleme
59
+ image_dir = '/content/drive/Othercomputers/Dizüstü Bilgisayarım/Desktop/human_annotated_images'
60
+ images, texts = load_images_and_texts(image_dir, data)
61
+
62
+ # Metin açıklamaları boyutunu düzeltme
63
+ texts = np.squeeze(texts, axis=1) # (N, 1, 768) -> (N, 768)
64
+
65
+ # CNN Modeli oluşturma
66
+ cnn_model = build_cnn_model(image_shape, text_dim)
67
+
68
+ # Modeli derleme
69
+ cnn_model.compile(optimizer=Adam(0.0002, 0.5), loss='binary_crossentropy', metrics=['accuracy'])
70
+
71
+ # Eğitim döngüsü
72
+ def train(epochs, batch_size):
73
+ for epoch in range(epochs):
74
+ # Gerçek görsellerden örnekleme
75
+ idx = np.random.randint(0, images.shape[0], batch_size)
76
+ real_images = images[idx]
77
+ real_texts = texts[idx]
78
+ labels = np.ones((batch_size, 1)) # Gerçek görseller için etiketler
79
+
80
+ # Eğitim
81
+ loss, accuracy = cnn_model.train_on_batch([real_images, real_texts], labels)
82
+
83
+ # İlerlemeyi yazdırma
84
+ if epoch % 10 == 0:
85
+ print(f"Epoch {epoch}/{epochs} | Loss: {loss} | Accuracy: {accuracy}")
86
+
87
+ # Modeli her 100 epoch'ta kaydetme
88
+ if epoch % 100 == 0:
89
+ cnn_model.save(f'cnn_model_epoch_{epoch}.h5')
90
+
91
+ # Modeli eğit
92
+ train(epochs, batch_size)
93
+
94
+ # Üretilen örnekleri kaydetme
95
+ def generate_and_save_samples(cnn_model, num_samples=5):
96
+ idx = np.random.randint(0, images.shape[0], num_samples)
97
+ sample_images = images[idx]
98
+ sample_texts = texts[idx]
99
+
100
+ predictions = cnn_model.predict([sample_images, sample_texts])
101
+
102
+ for i, img in enumerate(sample_images):
103
+ plt.imshow(img)
104
+ plt.axis('off')
105
+ plt.title(f"Prediction: {predictions[i]}")
106
+ plt.savefig(f"sample_image_{i}.png")
107
+
108
+ # Üretilen görselleri kaydetme
109
+ generate_and_save_samples(cnn_model)