import gradio as gr
import torch
from datasets import load_dataset, concatenate_datasets
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Predefined datasets (kept for backward compatibility with the original API).
dataset_names = ['imdb', 'ag_news', 'squad', 'cnn_dailymail', 'wiki40b']

# Each dataset paired with the config it requires and the column holding its
# raw text. The original passed bare names, but cnn_dailymail and wiki40b
# refuse to load without a config, and only imdb/ag_news expose a 'text'
# column — squad stores prose in 'context' and cnn_dailymail in 'article'.
DATASET_SPECS = [
    # (name, config, text_column)
    ('imdb', None, 'text'),
    ('ag_news', None, 'text'),
    ('squad', None, 'context'),
    ('cnn_dailymail', '3.0.0', 'article'),
    ('wiki40b', 'en', 'text'),
]

# Globals populated by train_model(); None until a model has been trained.
model = None
tokenizer = None


def _normalize_split(split, text_column):
    """Reduce *split* to a single 'text' column.

    concatenate_datasets requires identical schemas across its inputs; the
    original concatenated raw splits with mismatched columns, which raises.
    """
    if text_column != 'text':
        split = split.rename_column(text_column, 'text')
    extra = [c for c in split.column_names if c != 'text']
    return split.remove_columns(extra)


# Function to load and prepare datasets
def load_and_prepare_datasets():
    """Load every configured dataset and return (train, eval) datasets.

    Falls back to the 'test' split for evaluation when a dataset has no
    'validation' split (imdb and ag_news only ship train/test).
    """
    train_parts, eval_parts = [], []
    for name, config, text_column in DATASET_SPECS:
        ds = load_dataset(name, config) if config else load_dataset(name)
        if 'train' in ds:
            train_parts.append(_normalize_split(ds['train'], text_column))
        eval_split = ds.get('validation', ds.get('test'))
        if eval_split is not None:
            eval_parts.append(_normalize_split(eval_split, text_column))
    train_dataset = concatenate_datasets(train_parts)
    eval_dataset = concatenate_datasets(eval_parts)
    return train_dataset, eval_dataset


# Function to preprocess data
def preprocess_function(examples):
    """Tokenize a batch; relies on the module-level tokenizer set by train_model."""
    return tokenizer(
        examples['text'], truncation=True, padding='max_length', max_length=512
    )


# Function to train the model
def train_model():
    """Fine-tune GPT-2 on the concatenated datasets.

    Sets the module-level `model` and `tokenizer` as a side effect and
    returns a status string for the UI.
    """
    global model, tokenizer
    # Load model and tokenizer
    model_name = 'gpt2'  # You can choose another model if desired
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # GPT-2 ships without a pad token; padding='max_length' raises otherwise.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = tokenizer.eos_token_id

    # Load and prepare datasets
    train_dataset, eval_dataset = load_and_prepare_datasets()

    # Preprocess the datasets (drop the raw text so only tensors remain)
    train_dataset = train_dataset.map(
        preprocess_function, batched=True, remove_columns=['text']
    )
    eval_dataset = eval_dataset.map(
        preprocess_function, batched=True, remove_columns=['text']
    )

    # Causal-LM training needs `labels`; this collator copies input_ids into
    # labels (mlm=False) — the original passed no collator, so the Trainer
    # could not compute a loss.
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # Set training arguments
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        save_steps=1000,
        evaluation_strategy="steps",
    )

    # Train the model
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )
    trainer.train()

    return "Model trained successfully!"
# Function to generate text
def generate_text(prompt):
    """Generate up to 100 tokens continuing *prompt* with the trained model.

    Returns a friendly status string instead of crashing: the original
    dereferenced the `model`/`tokenizer` globals that only exist after
    train_model() has run, so clicking "Generate Text" before training
    raised an uncaught NameError in the UI.
    """
    mdl = globals().get('model')
    tok = globals().get('tokenizer')
    if mdl is None or tok is None:
        return "Model is not trained yet — click 'Train Model' first."
    if not prompt:
        return "Please enter a prompt."
    input_ids = tok.encode(prompt, return_tensors='pt')
    output = mdl.generate(input_ids, max_length=100)
    generated_text = tok.decode(output[0], skip_special_tokens=True)
    return generated_text


# Gradio interface: a training button/status box on the left, a prompt box
# with a generate button and output box on the right.
with gr.Blocks() as demo:
    gr.Markdown("# LLM Training and Text Generation")
    with gr.Row():
        with gr.Column():
            train_button = gr.Button("Train Model")
            output_message = gr.Textbox(label="Training Status", interactive=False)
        with gr.Column():
            prompt_input = gr.Textbox(label="Enter prompt for text generation")
            generate_button = gr.Button("Generate Text")
            generated_output = gr.Textbox(label="Generated Text", interactive=False)

    # Button actions
    train_button.click(train_model, outputs=output_message)
    generate_button.click(generate_text, inputs=prompt_input, outputs=generated_output)

# Launch the app
demo.launch()