File size: 2,057 Bytes
60f8cd4 8f46b8c 60f8cd4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
from preprocess import Model, SquadDataset
from transformers import DistilBertForQuestionAnswering
from torch.utils.data import DataLoader
from transformers import AdamW
import torch
import subprocess
# --- Data preparation --------------------------------------------------------
# Source names passed to Model.ArrangeData for each split.
# NOTE(review): both splits currently read the SAME source, so the validation
# set is identical to the training set and cannot measure generalisation.
# Point VAL_CONTAINER at a held-out source once one exists.
TRAIN_CONTAINER = "livecheckcontainer"
VAL_CONTAINER = "livecheckcontainer"

data = Model()

# ArrangeData returns parallel (contexts, questions, answers) collections.
train_contexts, train_questions, train_answers = data.ArrangeData(TRAIN_CONTAINER)
val_contexts, val_questions, val_answers = data.ArrangeData(VAL_CONTAINER)
print(train_answers)  # debug dump of the raw training answers

# add_end_idx returns updated (answers, contexts); presumably it fills in each
# answer's end offset — TODO confirm against preprocess.Model.
train_answers, train_contexts = data.add_end_idx(train_answers, train_contexts)
val_answers, val_contexts = data.add_end_idx(val_answers, val_contexts)

# Tokenize both splits in one call, then map answers onto token positions.
train_encodings, val_encodings = data.Tokenizer(
    train_contexts, train_questions, val_contexts, val_questions
)
train_encodings = data.add_token_positions(train_encodings, train_answers)
val_encodings = data.add_token_positions(val_encodings, val_answers)

train_dataset = SquadDataset(train_encodings)
val_dataset = SquadDataset(val_encodings)
# --- Model, data loader and optimizer -----------------------------------------
BATCH_SIZE = 16
LEARNING_RATE = 5e-5

model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.train()  # put the model in training mode before the optimisation loop

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# transformers.AdamW is deprecated in favour of torch.optim.AdamW.
# eps=1e-6 and weight_decay=0.0 reproduce the old transformers defaults;
# torch's own defaults (eps=1e-8, weight_decay=1e-2) would change training.
optim = torch.optim.AdamW(
    model.parameters(), lr=LEARNING_RATE, eps=1e-6, weight_decay=0.0
)
# --- Training loop -------------------------------------------------------------
# NOTE(review): val_dataset is built above but never evaluated here.
NUM_EPOCHS = 2

for epoch in range(NUM_EPOCHS):
    print(epoch)
    for batch in train_loader:
        optim.zero_grad()
        # Move every tensor the forward pass needs onto the training device.
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        start_positions = batch["start_positions"].to(device)
        end_positions = batch["end_positions"].to(device)
        outputs = model(
            input_ids,
            attention_mask=attention_mask,
            start_positions=start_positions,
            end_positions=end_positions,
        )
        # When start/end positions are supplied, the first output is the loss.
        loss = outputs[0]
        loss.backward()
        optim.step()

print("Done")
# --- Save artifacts and publish them via git ---------------------------------
OUTPUT_DIR = "./"

model.eval()
model.save_pretrained(OUTPUT_DIR)
data.tokenizer.save_pretrained(OUTPUT_DIR)

# NOTE(review): `git add --all` stages every change in the repository, not just
# the saved model/tokenizer files — confirm this is intended.
# add/push use check=True so a failed upload raises instead of passing silently.
subprocess.run(["git", "add", "--all"], check=True)
subprocess.run(["git", "status"], check=False)  # informational only
# `git commit` exits non-zero when there is nothing new to commit; tolerate that
# so previously committed work can still be pushed below.
# NOTE(review): the commit message looks like an unfilled template placeholder.
subprocess.run(
    ["git", "commit", "-m", "First version of the your-model-name model and tokenizer."],
    check=False,
)
subprocess.run(["git", "push"], check=True)
|