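# QA / main.py
# Hosting-page metadata captured with the file (not part of the script):
#   uploader avatar caption: "Ateeb's picture"
#   commit 8f46b8c: "First version of the your-model-name model and tokenizer."
#   page links: raw, history, blame
#   file size: 2.06 kB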
from preprocess import Model, SquadDataset
from transformers import DistilBertForQuestionAnswering
from torch.utils.data import DataLoader
from transformers import AdamW
import torch
import subprocess
# --- Data preparation -------------------------------------------------------
# `Model` (from the local `preprocess` module) supplies every preprocessing
# step used below; the calls are kept in their original order.
data = Model()

# NOTE(review): the training and validation splits are both loaded from the
# same source name -- confirm that this is intended.
source_name = "livecheckcontainer"
train_contexts, train_questions, train_answers = data.ArrangeData(source_name)
val_contexts, val_questions, val_answers = data.ArrangeData(source_name)
print(train_answers)

# add_end_idx hands back the (answers, contexts) pair for each split.
train_answers, train_contexts = data.add_end_idx(train_answers, train_contexts)
val_answers, val_contexts = data.add_end_idx(val_answers, val_contexts)

# A single Tokenizer call yields the encodings for both splits.
train_encodings, val_encodings = data.Tokenizer(
    train_contexts,
    train_questions,
    val_contexts,
    val_questions,
)

# Combine each split's encodings with its answers (presumably attaching the
# start/end positions that the training loop reads -- verify in preprocess).
train_encodings = data.add_token_positions(train_encodings, train_answers)
val_encodings = data.add_token_positions(val_encodings, val_answers)

# Wrap the encodings in the project's dataset type.
train_dataset = SquadDataset(train_encodings)
val_dataset = SquadDataset(val_encodings)
# --- Model, device, data loader and optimizer -------------------------------
# Start from the pretrained "distilbert-base-uncased" checkpoint.
model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
model.train()  # put the model in training mode before the optimisation loop

# Shuffled mini-batches of 16 examples drawn from the training dataset.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# transformers.AdamW is deprecated in favour of torch.optim.AdamW, so the
# PyTorch optimizer is used here. eps and weight_decay are pinned to the
# defaults of the deprecated class (1e-6 and 0.0) because PyTorch's own
# defaults differ (1e-8 and 0.01); this keeps the hyperparameters unchanged.
optim = torch.optim.AdamW(
    model.parameters(),
    lr=5e-5,
    eps=1e-6,
    weight_decay=0.0,
)
# --- Training ---------------------------------------------------------------
# Two passes over the training loader; the epoch index is printed as a
# minimal progress indicator.
for epoch in range(2):
    print(epoch)
    for batch in train_loader:
        # Clear gradients left over from the previous step.
        optim.zero_grad()

        # Move the tensors the model consumes onto the selected device.
        model_inputs = {
            key: batch[key].to(device)
            for key in (
                'input_ids',
                'attention_mask',
                'start_positions',
                'end_positions',
            )
        }
        outputs = model(**model_inputs)

        # The first element of the model output is used as the loss.
        batch_loss = outputs[0]
        batch_loss.backward()
        optim.step()
print("Done")
# --- Save artefacts and publish ---------------------------------------------
# Leave training mode, then call save_pretrained on the model and on
# data.tokenizer, both targeting the current directory.
model.eval()
output_dir = "./"
model.save_pretrained(output_dir)
data.tokenizer.save_pretrained(output_dir)

# Stage everything, show the status, commit and push. Each git command runs
# in turn and its exit code is discarded, so a failing step does not stop the
# ones that follow.
commit_message = "First version of the your-model-name model and tokenizer."
git_steps = (
    ["add", "--all"],
    ["status"],
    ["commit", "-m", commit_message],
    ["push"],
)
for git_args in git_steps:
    subprocess.call(["git"] + git_args)