# train_for_fun / app.py
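"""Gradio app that fine-tunes bert-base-uncased for sequence classification on
an uploaded CSV. Training runs in short GPU bursts (see the @spaces.GPU note
below) and checkpoints to ./checkpoint/model.pt so each burst fits within the
Space's GPU time limit."""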
import os
import time

import gradio as gr
import pandas as pd
import spaces
import torch
from torch.optim import AdamW  # transformers.AdamW is deprecated/removed; use the torch optimizer
from torch.utils.data import DataLoader, Subset, TensorDataset
from transformers import BertForSequenceClassification, BertTokenizer
# Model, tokenizer, and optimizer live at module level so their state survives
# across Gradio calls within one Space session.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model.to(device)
optimizer = AdamW(model.parameters(), lr=1e-5)
global_data = None
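# load_data expects a CSV with a 'text' column (raw strings) and a 'label'
# column (integer class ids), e.g.:
#     text,label
#     "great movie",1
#     "terrible plot",0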
def load_data(file):
    """Tokenize the uploaded CSV and cache it as a module-level TensorDataset."""
    global global_data
    df = pd.read_csv(file)
    inputs = tokenizer(df['text'].tolist(), padding=True, truncation=True, return_tensors="pt")  # encode the text
    labels = torch.tensor(df['label'].tolist()).long()
    global_data = TensorDataset(inputs['input_ids'], inputs['attention_mask'], labels)
    print(global_data)
def get_dataloader(start, end, batch_size=8):
    subset = Subset(global_data, range(start, end))
    return DataLoader(subset, batch_size=batch_size)
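# On a Hugging Face ZeroGPU Space, @spaces.GPU requests a GPU for the duration
# of the decorated call (the default window is 60 s), which is why train_batch
# checkpoints and bails out after ~55 s of work.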
@spaces.GPU
def train_batch(dataloader):
    model.train()
    start_time = time.time()
    for batch in dataloader:
        input_ids, attention_mask, labels = (t.to(device) for t in batch)
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        elapsed_time = time.time() - start_time
        if elapsed_time > 55:  # stop before the 60 s GPU window expires and save a checkpoint
            os.makedirs("./checkpoint", exist_ok=True)
            torch.save(model.state_dict(), "./checkpoint/model.pt")
            return False, "Checkpoint saved. Training paused."
    return True, "Batch training completed."
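# Drives the whole run: tokenize the upload once, then feed 10-batch slices to
# train_batch so each @spaces.GPU call stays inside the time limit. Note that
# a paused run only persists the model weights, not the optimizer state or the
# current position in the data.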
def train_step(file=None):
    if file:
        load_data(file)
    print(global_data)
    start_idx = 0
    batch_size = 8
    total_samples = len(global_data)
    counting = 0
    while start_idx < total_samples:
        print("Step:", counting)
        counting += 1
        end_idx = min(start_idx + (batch_size * 10), total_samples)  # 10 batches per loop
        dataloader = get_dataloader(start_idx, end_idx, batch_size)
        success, message = train_batch(dataloader)
        if not success:
            return message
        start_idx = end_idx
        time.sleep(5)  # pause 5 seconds between training sessions
    os.makedirs("./checkpoint", exist_ok=True)
    torch.save(model.state_dict(), "./checkpoint/model.pt")
    return "Training completed and model saved."
if __name__ == "__main__":
iface = gr.Interface(
fn=train_step,
inputs=gr.File(label="Upload CSV"),
outputs="text"
)
iface.launch()
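
# A minimal sketch for smoke-testing the app locally, assuming nothing beyond
# what load_data reads: a CSV with 'text' and 'label' columns. The file name
# sample.csv and the example rows are made up for illustration; generate the
# file in a separate shell, then upload it through the Gradio UI.
#
#     import pandas as pd
#     pd.DataFrame({
#         "text": ["great movie", "terrible plot", "fine but forgettable"],
#         "label": [1, 0, 0],
#     }).to_csv("sample.csv", index=False)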