makaleChatbotu / model.py
yonkasoft's picture
Upload 4 files
b8522d2 verified
raw
history blame
No virus
6.74 kB
from datasets import load_dataset
import pandas as pd
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertForQuestionAnswering, BertConfig
from pymongo import MongoClient
class Database:
@staticmethod
def get_mongodb():
# MongoDB bağlantı bilgilerini döndürecek şekilde tanımlanmalıdır.
return 'mongodb://localhost:27017/', 'yeniDatabase', 'train'
@staticmethod
def get_input_texts():
# MongoDB bağlantı bilgilerini alma
mongo_url, db_name, collection_name = Database.get_mongodb()
# MongoDB'ye bağlanma
client = MongoClient(mongo_url)
db = client[db_name]
collection = db[collection_name]
# Sorguyu tanımlama
query = {"Prompt": {"$exists": True}}
# Sorguyu çalıştırma ve dökümanları çekme
cursor = collection.find(query, {"Prompt": 1, "_id": 0})
# Cursor'ı döküman listesine dönüştürme
input_texts_from_db = list(cursor)
# Input text'leri döndürme
return input_texts_from_db
@staticmethod
def get_output_texts():
# MongoDB bağlantı bilgilerini alma
mongo_url, db_name, collection_name = Database.get_mongodb()
# MongoDB'ye bağlanma
client = MongoClient(mongo_url)
db = client[db_name]
collection = db[collection_name]
# Sorguyu tanımlama
query = {"Response": {"$exists": True}}
# Sorguyu çalıştırma ve dökümanları çekme
cursor = collection.find(query, {"Response": 1, "_id": 0})
# Cursor'ı döküman listesine dönüştürme
output_texts_from_db = list(cursor)
# Input text'leri döndürme
return output_texts_from_db
@staticmethod
def get_average_prompt_token_length():
# MongoDB bağlantı bilgilerini alma
mongo_url, db_name, collection_name = Database.get_mongodb()
# MongoDB'ye bağlanma
client = MongoClient(mongo_url)
db = client[db_name]
collection = db[collection_name]
# Tüm dökümanları çekme ve 'prompt_token_length' alanını alma
docs = collection.find({}, {'Prompt_token_length': 1})
# 'prompt_token_length' değerlerini toplama ve sayma
total_length = 0
count = 0
for doc in docs:
if 'Prompt_token_length' in doc:
total_length += doc['Prompt_token_length']
count += 1
# Ortalama hesaplama
average_length = total_length / count if count > 0 else 0
return int(average_length)
# Tokenizer ve Modeli yükleme
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Encode işlemi
def tokenize_and_encode(train_df,doc):
input_texts_from_db = Database.get_input_texts()
output_texts_from_db= Database.get_output_texts()
input_texts = [doc["Prompt"] for doc in input_texts_from_db]
output_texts= [doc["Response"] for doc in output_texts_from_db]
encoded = tokenizer.batch_encode_plus(
#doc['Prompt'].tolist(),
#text_pair= doc['Response'].tolist(),
input_texts,
output_texts,
padding=True,
truncation=True,
max_length=100,
return_attention_mask=True,
return_tensors='pt'
)
return encoded
encoded_data=tokenize_and_encode()
class QA:
#buradaki verilerin değeri değiştirilmeli
def __init__(self, model_path: str):
self.max_seq_length = 384
self.doc_stride = 128
self.do_lower_case = False
self.max_query_length = 30
self.n_best_size = 3
self.max_answer_length = 30
self.version_2_with_negative = False
self.model, self.tokenizer = self.load_model(model_path)
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model.to(self.device)
self.model.eval()
def load_model(self, model_path: str, do_lower_case=False):
config = BertConfig.from_pretrained(model_path)
tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=do_lower_case)
model = BertForQuestionAnswering.from_pretrained(model_path, from_tf=False, config=config)
return model, tokenizer
def extract_features_from_dataset(self, train_df):
def get_max_length(examples):
return {
'max_seq_length': max(len(e) for e in examples),
'max_query_length': max(len(q) for q in examples)
}
# Örnek bir kullanım
features = get_max_length(train_df)
return features
# Ortalama prompt token uzunluğunu al ve yazdır
average_length = Database.get_average_prompt_token_length()
print(f"Ortalama prompt token uzunluğu: {average_length}")
# QA sınıfını oluştur
qa = QA(model_path='bert-base-uncased')
#tensor veri setini koda entegre etme
"""# Tensor veri kümesi oluşturma
input_ids = encoded_data['input_ids']
attention_mask = encoded_data['attention_mask']
token_type_ids = encoded_data['token_type_ids']
labels = torch.tensor(data['Response'].tolist()) # Cevapları etiket olarak kullanın
# TensorDataset oluşturma
dataset = TensorDataset(input_ids, attention_mask, token_type_ids, labels)
# DataLoader oluşturma
batch_size = 16
dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset),
batch_size=batch_size
)"""
#modelin için epoch sayısının tanımlaması
"""# Eğitim için optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)
# Eğitim döngüsü
model.train()
for epoch in range(3): # Örnek olarak 3 epoch
for batch in dataloader:
input_ids, attention_mask, token_type_ids, labels = [t.to(device) for t in batch]
optimizer.zero_grad()
outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, start_positions=labels, end_positions=labels)
loss = outputs.loss
loss.backward()
optimizer.step()
print(f"Epoch {epoch+1} loss: {loss.item()}")"""
#sonuçların sınıflandırılması
"""# Modeli değerlendirme aşamasına getirme
model.eval()
# Örnek tahmin
with torch.no_grad():
for batch in dataloader:
input_ids, attention_mask, token_type_ids, _ = [t.to(device) for t in batch]
outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
# Çıktıları kullanarak başlık, alt başlık ve anahtar kelimeler belirleyebilirsiniz
"""