import time
import torch
from transformers import BertForSequenceClassification, AdamW
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertTokenizer
import gradio as gr
import pandas as pd
import os
import spaces

# Train on CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Pretrained BERT classifier and its matching tokenizer.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model.to(device)
optimizer = AdamW(model.parameters(), lr=1e-5)

# Tokenized training set (a TensorDataset); stays None until load_data() runs.
global_data = None


def load_data(file):
    """Read a CSV file, tokenize it, and store the result in ``global_data``.

    The CSV must have a ``text`` column and a label column. The label column
    is looked up under the historical (misspelled) name ``lable`` first and
    under ``label`` otherwise, so existing files keep working and correctly
    spelled headers are accepted too.

    Args:
        file: Anything ``pandas.read_csv`` accepts (path or file-like object).

    Raises:
        KeyError: If the ``text`` column or both label spellings are missing.
    """
    global global_data
    df = pd.read_csv(file)

    # Encode the text; padding=True pads every row to the longest sequence.
    inputs = tokenizer(
        df['text'].tolist(),
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    # Prefer the original column name so previously working files behave
    # exactly as before; fall back to the correct spelling.
    label_column = 'lable' if 'lable' in df.columns else 'label'
    labels = torch.tensor(df[label_column].tolist()).long()

    global_data = TensorDataset(
        inputs['input_ids'], inputs['attention_mask'], labels
    )
    print(global_data)


def get_dataloader(start, end, batch_size=8):
    """Return a DataLoader over samples ``[start, end)`` of ``global_data``.

    Args:
        start: Index of the first sample (inclusive).
        end: Index one past the last sample (exclusive).
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` iterating the selected subset in order (no shuffling).
    """
    global global_data
    subset = torch.utils.data.Subset(global_data, range(start, end))
    return DataLoader(subset, batch_size=batch_size)


@spaces.GPU
def train_batch(dataloader):
    """Run one optimizer step per batch in ``dataloader``.

    Training stops early once more than 55 seconds have elapsed; in that case
    the model weights are written to ``./checkpoint/model.pt`` first.

    Args:
        dataloader: Iterable yielding ``(input_ids, attention_mask, labels)``.

    Returns:
        ``(True, message)`` when every batch was processed, or
        ``(False, message)`` when the time limit was hit and a checkpoint
        was saved.
    """
    model.train()
    start_time = time.time()

    for step, batch in enumerate(dataloader):
        input_ids, attention_mask, labels = batch
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        elapsed_time = time.time() - start_time
        # Stop before 60 seconds so there is time left to save a checkpoint
        # (presumably the @spaces.GPU time budget -- TODO confirm).
        if elapsed_time > 55:
            checkpoint_path = "./checkpoint/model.pt"
            # torch.save does not create missing parent directories.
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)
            return False, "Checkpoint saved. Training paused."

    return True, "Batch training completed."
def train_step(file=None):
    """Train the model on the loaded dataset in chunks of ten batches.

    If ``file`` is given, it is loaded first via ``load_data``. The dataset is
    then walked in chunks of ``batch_size * 10`` samples; each chunk is handed
    to ``train_batch``. Training stops as soon as a chunk reports that it was
    paused; otherwise the final weights are saved to ``./checkpoint/model.pt``.

    Args:
        file: Optional CSV upload to load before training. When omitted, the
            previously loaded dataset (if any) is reused.

    Returns:
        A status message for display in the UI.
    """
    if file:
        load_data(file)
    print(global_data)

    # Without an upload and without earlier data there is nothing to train
    # on; report that instead of failing on len(None).
    if global_data is None:
        return "No training data loaded. Please upload a CSV file first."

    batch_size = 8
    chunk_size = batch_size * 10  # 10 batches per loop
    total_samples = len(global_data)

    for counting, start_idx in enumerate(range(0, total_samples, chunk_size)):
        print("Step:", counting)
        end_idx = min(start_idx + chunk_size, total_samples)
        dataloader = get_dataloader(start_idx, end_idx, batch_size)

        success, message = train_batch(dataloader)
        if not success:
            return message

        # Rest 5 seconds between training sessions; no need to wait after
        # the final chunk.
        if end_idx < total_samples:
            time.sleep(5)

    checkpoint_path = "./checkpoint/model.pt"
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    torch.save(model.state_dict(), checkpoint_path)
    return "Training completed and model saved."


if __name__ == "__main__":
    iface = gr.Interface(
        fn=train_step,
        inputs=gr.File(label="Upload CSV"),
        outputs="text"
    )
    iface.launch()