import os

import torch
from azure.cosmos import CosmosClient
from transformers import DistilBertTokenizerFast


class Model:
    """Loads SQuAD-style QA records from Azure Cosmos DB and prepares them
    for DistilBERT question-answering fine-tuning."""

    def __init__(self) -> None:
        self.endPoint = "https://productdevelopmentstorage.documents.azure.com:443/"
        # The account key is read from the environment rather than hard-coded;
        # COSMOS_KEY is a placeholder variable name.
        self.primaryKey = os.environ["COSMOS_KEY"]
        self.client = CosmosClient(self.endPoint, self.primaryKey)
        self.tokenizer = None

    def GetData(self, container_name):
        # Read every item from the requested Cosmos container. max_item_count
        # bounds the page size, not the total number of items returned.
        database = self.client.get_database_client("squadstorage")
        container = database.get_container_client(container_name)
        return list(container.read_all_items(max_item_count=10))

    def ArrangeData(self, container_name):
        # Split the stored records into parallel lists of contexts, questions,
        # and answers.
        squad_items = self.GetData(container_name)

        contexts = []
        questions = []
        answers = []

        for item in squad_items:
            contexts.append(item["context"])
            questions.append(item["question"])
            answers.append(item["answers"])

        return contexts, questions, answers

    def add_end_idx(self, answers, contexts):
        # SQuAD start indices are occasionally off by one or two characters;
        # realign them and record each answer's end index.
        for answer, context in zip(answers, contexts):
            gold_text = answer['text'][0]
            start_idx = answer['answer_start'][0]
            end_idx = start_idx + len(gold_text)

            if context[start_idx:end_idx] == gold_text:
                answer['answer_end'] = end_idx
            elif context[start_idx - 1:end_idx - 1] == gold_text:
                # Off by one: update the stored start index in place so it
                # stays a list, which add_token_positions subscripts.
                answer['answer_start'][0] = start_idx - 1
                answer['answer_end'] = end_idx - 1
            elif context[start_idx - 2:end_idx - 2] == gold_text:
                answer['answer_start'][0] = start_idx - 2
                answer['answer_end'] = end_idx - 2
            else:
                # Fall back to the computed end so 'answer_end' always exists.
                answer['answer_end'] = end_idx

        return answers, contexts

    def Tokenizer(self, train_contexts, train_questions, val_contexts, val_questions):
        # Tokenize each context/question pair together; truncation and padding
        # give the fixed-length inputs DistilBERT expects.
        self.tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

        train_encodings = self.tokenizer(train_contexts, train_questions, truncation=True, padding=True)
        val_encodings = self.tokenizer(val_contexts, val_questions, truncation=True, padding=True)

        return train_encodings, val_encodings

    def add_token_positions(self, encodings, answers):
        # Map character-level answer spans onto token positions for the
        # start/end classification heads.
        start_positions = []
        end_positions = []
        for i in range(len(answers)):
            start_positions.append(encodings.char_to_token(i, answers[i]['answer_start'][0]))
            end_positions.append(encodings.char_to_token(i, answers[i]['answer_end'] - 1))

            # char_to_token returns None when the answer was truncated away;
            # fall back to the maximum sequence length in that case.
            if start_positions[-1] is None:
                start_positions[-1] = self.tokenizer.model_max_length
            if end_positions[-1] is None:
                end_positions[-1] = self.tokenizer.model_max_length

        encodings.update({'start_positions': start_positions, 'end_positions': end_positions})
        return encodings


class SquadDataset(torch.utils.data.Dataset):
    """Wraps the tokenizer output so a PyTorch DataLoader can index and
    batch it."""

    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        # Return one example as a dict of tensors, one per encoding field.
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)
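

# A minimal end-to-end sketch of how the pieces above fit together. The
# container names "train" and "validation" are assumptions (the real Cosmos
# container names are not shown here), and COSMOS_KEY must be set in the
# environment before running.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    model = Model()

    # Fetch and align the raw records (hypothetical container names).
    train_contexts, train_questions, train_answers = model.ArrangeData("train")
    val_contexts, val_questions, val_answers = model.ArrangeData("validation")
    model.add_end_idx(train_answers, train_contexts)
    model.add_end_idx(val_answers, val_contexts)

    # Tokenize and attach token-level start/end labels.
    train_encodings, val_encodings = model.Tokenizer(
        train_contexts, train_questions, val_contexts, val_questions
    )
    model.add_token_positions(train_encodings, train_answers)
    model.add_token_positions(val_encodings, val_answers)

    # Wrap the encodings for batched training.
    train_loader = DataLoader(SquadDataset(train_encodings), batch_size=8, shuffle=True)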