|
from preprocess import Model, SquadDataset |
|
from transformers import DistilBertForQuestionAnswering |
|
from torch.utils.data import DataLoader |
|
from transformers import AdamW |
|
import torch |
|
import subprocess |
|
|
|
# Data preparation: every step below is delegated to the project's
# `preprocess.Model` helper; this script only wires the steps together.
data = Model()

# NOTE(review): the training and validation splits are both read from the
# same "livecheckcontainer" source — confirm this is intended.
(train_contexts,
 train_questions,
 train_answers) = data.ArrangeData("livecheckcontainer")
(val_contexts,
 val_questions,
 val_answers) = data.ArrangeData("livecheckcontainer")
print(train_answers)

# Post-process the answers/contexts of each split with add_end_idx.
train_answers, train_contexts = data.add_end_idx(
    train_answers, train_contexts
)
val_answers, val_contexts = data.add_end_idx(
    val_answers, val_contexts
)

# Tokenize both splits in a single call.
train_encodings, val_encodings = data.Tokenizer(
    train_contexts,
    train_questions,
    val_contexts,
    val_questions,
)

# Attach answer token positions to the encodings of each split.
train_encodings = data.add_token_positions(train_encodings, train_answers)
val_encodings = data.add_token_positions(val_encodings, val_answers)

# Wrap the encodings in the project's dataset class.
train_dataset = SquadDataset(train_encodings)
val_dataset = SquadDataset(val_encodings)
|
|
|
|
|
|
|
# Model setup: load the pretrained DistilBERT question-answering model,
# put it on the GPU when one is available (CPU otherwise), and switch it
# to training mode.
model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.train()

# Shuffled mini-batches of 16 examples drawn from the training dataset.
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
)

# NOTE(review): the AdamW exported by `transformers` has been marked as
# deprecated upstream in favour of `torch.optim.AdamW` — confirm against
# the installed transformers version.
optim = AdamW(model.parameters(), lr=5e-5)
|
|
|
# Fine-tuning loop: two passes over `train_loader`, printing the epoch index
# at the start of each pass.
# NOTE(review): the original indentation was lost; `print("Done")` is assumed
# to run once, after both loops have finished — confirm.
batch_fields = ("input_ids", "attention_mask", "start_positions", "end_positions")

for epoch in range(2):
    print(epoch)
    for batch in train_loader:
        optim.zero_grad()

        # Move the tensors the model consumes onto the training device.
        model_inputs = {
            field: batch[field].to(device) for field in batch_fields
        }
        outputs = model(**model_inputs)

        # The first element of the model output is used as the loss.
        loss = outputs[0]
        loss.backward()
        optim.step()

print("Done")
|
# Switch the model to evaluation mode, then write both the model and the
# tokenizer held by `data` into the current working directory.
model.eval()

save_dir = "./"
model.save_pretrained(save_dir)
data.tokenizer.save_pretrained(save_dir)
|
|
|
|
|
# Publish the files in the working tree through git: stage everything, show
# the status, commit, and push.
#
# Each entry is (command, required). The exit code of every command is
# checked; when a required step fails, the remaining steps are skipped so
# that, for example, a push is not attempted after a failed commit.
# `git status` is purely informational, so its exit code never stops the
# chain.
git_steps = (
    (["git", "add", "--all"], True),
    (["git", "status"], False),
    (
        [
            "git",
            "commit",
            "-m",
            "First version of the your-model-name model and tokenizer.",
        ],
        True,
    ),
    (["git", "push"], True),
)

for git_command, is_required in git_steps:
    exit_code = subprocess.call(git_command)
    if exit_code != 0 and is_required:
        # Report the failure instead of raising: the model files are already
        # saved on disk at this point, so only the publishing step is lost.
        print(
            f"Command {' '.join(git_command)!r} failed with exit code "
            f"{exit_code}; skipping the remaining git steps."
        )
        break
|
|
|
|