# Hugging Face Space demo: fine-tune GPT-2 on several public datasets and
# generate text from the trained model through a Gradio UI.
import gradio as gr
import torch
from datasets import concatenate_datasets, load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
# Hub datasets combined into one corpus for fine-tuning.
dataset_names = [
    'imdb',
    'ag_news',
    'squad',
    'cnn_dailymail',
    'wiki40b',
]
# Some hub datasets require an explicit config name, and each stores its main
# text under a different column, so we normalize everything to a 'text' column
# before concatenating (concatenate_datasets requires matching schemas).
_DATASET_CONFIGS = {'cnn_dailymail': '3.0.0', 'wiki40b': 'en'}
_TEXT_COLUMNS = {'squad': 'context', 'cnn_dailymail': 'article'}


def _normalize_to_text(ds, name):
    """Return *ds* reduced to a single 'text' column for dataset *name*."""
    column = _TEXT_COLUMNS.get(name, 'text')
    if column != 'text':
        ds = ds.rename_column(column, 'text')
    return ds.remove_columns([c for c in ds.column_names if c != 'text'])


def load_and_prepare_datasets():
    """Download every dataset in ``dataset_names`` and build combined splits.

    Returns:
        tuple: ``(train_dataset, eval_dataset)`` — concatenated 'train' and
        'validation' splits, each holding only a 'text' column. Datasets that
        lack a given split (e.g. imdb has no 'validation') are skipped for it.
    """
    loaded = [
        (name, load_dataset(name, _DATASET_CONFIGS.get(name)))
        for name in dataset_names
    ]
    train_dataset = concatenate_datasets(
        [_normalize_to_text(ds['train'], name) for name, ds in loaded if 'train' in ds]
    )
    eval_dataset = concatenate_datasets(
        [_normalize_to_text(ds['validation'], name) for name, ds in loaded if 'validation' in ds]
    )
    return train_dataset, eval_dataset
# Tokenize a batch of examples for causal-LM fine-tuning.
def preprocess_function(examples):
    """Tokenize the 'text' column into fixed-length model inputs.

    Relies on the module-level ``tokenizer`` initialized by ``train_model``.
    For causal-LM training the targets are the inputs themselves, so we copy
    ``input_ids`` into ``labels`` — without a 'labels' key the Trainer has
    nothing to compute loss from (the model shifts labels internally).
    """
    tokenized = tokenizer(
        examples['text'],
        truncation=True,
        padding='max_length',
        max_length=512,
    )
    tokenized['labels'] = tokenized['input_ids'].copy()
    return tokenized
# Fine-tune the model on the combined corpus; wired to the "Train Model" button.
def train_model():
    """Load GPT-2, fine-tune it on the combined datasets, and report status.

    Side effects:
        Sets the module-level ``model`` and ``tokenizer`` used by
        ``generate_text``; writes checkpoints to ./results and logs to ./logs.

    Returns:
        str: a success message for the Gradio status textbox.
    """
    global model, tokenizer
    model_name = 'gpt2'  # You can choose another model if desired
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # GPT-2 ships without a pad token; padding='max_length' in the
    # preprocessing step raises unless one is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Load, then tokenize the datasets. Drop the raw string columns after
    # mapping — the Trainer cannot collate non-tensor columns.
    train_dataset, eval_dataset = load_and_prepare_datasets()
    train_dataset = train_dataset.map(
        preprocess_function, batched=True, remove_columns=train_dataset.column_names
    )
    eval_dataset = eval_dataset.map(
        preprocess_function, batched=True, remove_columns=eval_dataset.column_names
    )
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        save_steps=1000,
        evaluation_strategy="steps",
    )
    # mlm=False -> plain causal-LM objective; the collator also masks padding
    # positions out of the loss (sets them to -100).
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )
    trainer.train()
    return "Model trained successfully!"
# Generate a continuation for a prompt; wired to the "Generate Text" button.
def generate_text(prompt):
    """Generate up to 100 tokens of text continuing *prompt*.

    Falls back to the pretrained GPT-2 checkpoint when the user clicks
    "Generate Text" before training — otherwise the module-level ``model`` /
    ``tokenizer`` globals are undefined and this handler crashed with
    NameError.

    Returns:
        str: the decoded generation (prompt included).
    """
    global model, tokenizer
    if 'model' not in globals() or 'tokenizer' not in globals():
        model = AutoModelForCausalLM.from_pretrained('gpt2')
        tokenizer = AutoTokenizer.from_pretrained('gpt2')
    encoded = tokenizer(prompt, return_tensors='pt')
    # Pass the attention mask and a pad token id explicitly — GPT-2 has no
    # pad token, and generate() warns (and may misbehave) without them.
    output = model.generate(
        encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        max_length=100,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# --- Gradio interface ---------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# LLM Training and Text Generation")
    with gr.Row():
        # Left column: start training and show its status.
        with gr.Column():
            train_button = gr.Button("Train Model")
            output_message = gr.Textbox(label="Training Status", interactive=False)
        # Right column: prompt-driven generation.
        with gr.Column():
            prompt_input = gr.Textbox(label="Enter prompt for text generation")
            generate_button = gr.Button("Generate Text")
            generated_output = gr.Textbox(label="Generated Text", interactive=False)

    # Wire the buttons to their backing functions.
    train_button.click(train_model, outputs=output_message)
    generate_button.click(generate_text, inputs=prompt_input, outputs=generated_output)

# Launch the app
demo.launch()