# detrina-grad / app.py
# Author: portalniy-dev — "Create app.py" (commit 64cafec)
import dataclasses

import gradio as gr
import torch
from datasets import load_dataset, concatenate_datasets
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
# Dataset identifiers handed to ``datasets.load_dataset``; their train and
# validation splits are merged into one corpus by load_and_prepare_datasets.
dataset_names = ['imdb', 'ag_news', 'squad', 'cnn_dailymail', 'wiki40b']
# Config name required by ``load_dataset`` for multi-config datasets.
# NOTE(review): wiki40b is very large and may need a namespaced Hub id
# (e.g. "google/wiki40b") on recent `datasets` versions — confirm.
_DATASET_CONFIGS = {
    'cnn_dailymail': '3.0.0',
    'wiki40b': 'en',
}
# Column holding the raw text, for datasets that do not call it 'text'.
_TEXT_COLUMNS = {
    'squad': 'context',
    'cnn_dailymail': 'article',
}


def _text_only(split, name):
    """Return *split* reduced to a single string column named ``'text'``.

    ``concatenate_datasets`` requires identical features, so every source is
    normalised to the one column ``preprocess_function`` reads.
    """
    column = _TEXT_COLUMNS.get(name, 'text')
    split = split.remove_columns([c for c in split.column_names if c != column])
    if column != 'text':
        split = split.rename_column(column, 'text')
    return split


# Function to load and prepare datasets
def load_and_prepare_datasets(names=None):
    """Load several Hub datasets and merge them into train/eval corpora.

    Args:
        names: Dataset identifiers to load; defaults to ``dataset_names``.

    Returns:
        ``(train_dataset, eval_dataset)``: the concatenated ``train`` splits
        and the concatenated ``validation`` splits, each with a single
        ``'text'`` column.

    Raises:
        ValueError: if none of the datasets provides a ``train`` or a
            ``validation`` split.
    """
    if names is None:
        names = dataset_names
    train_parts, eval_parts = [], []
    for name in names:
        # A config of None lets load_dataset use the dataset's default.
        ds = load_dataset(name, _DATASET_CONFIGS.get(name))
        if 'train' in ds:
            train_parts.append(_text_only(ds['train'], name))
        if 'validation' in ds:
            eval_parts.append(_text_only(ds['validation'], name))
    if not train_parts:
        raise ValueError(f"None of {list(names)} has a 'train' split")
    if not eval_parts:
        raise ValueError(f"None of {list(names)} has a 'validation' split")
    return concatenate_datasets(train_parts), concatenate_datasets(eval_parts)
# Function to preprocess data
def preprocess_function(examples):
    """Tokenize a batch of examples to fixed-length, 512-token inputs.

    Intended for ``Dataset.map(..., batched=True)``. Reads the module-level
    ``tokenizer`` that ``train_model`` assigns before mapping.

    Args:
        examples: A batch mapping whose ``'text'`` entry is a list of strings.

    Returns:
        The tokenizer output (e.g. ``input_ids`` and ``attention_mask``),
        truncated and padded to 512 tokens.
    """
    # GPT-2 style tokenizers ship without a padding token, and
    # padding='max_length' raises without one; reuse end-of-sequence.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=512)
# Function to train the model
def train_model(model_name='gpt2'):
    """Fine-tune a pretrained causal language model on the predefined datasets.

    Loads the model and tokenizer into the module-level ``model`` and
    ``tokenizer`` globals (``preprocess_function`` and ``generate_text`` read
    them), tokenizes the train/eval corpora and runs ``Trainer.train``.

    Args:
        model_name: Hub id of the causal LM to fine-tune (default ``'gpt2'``).

    Returns:
        A status message displayed in the Gradio UI.
    """
    global model, tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # GPT-2 has no padding token; both padding and the LM collator need one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    train_dataset, eval_dataset = load_and_prepare_datasets()
    # Drop the raw columns so only tokenizer outputs reach the Trainer.
    train_dataset = train_dataset.map(
        preprocess_function, batched=True, remove_columns=train_dataset.column_names
    )
    eval_dataset = eval_dataset.map(
        preprocess_function, batched=True, remove_columns=eval_dataset.column_names
    )

    # ``evaluation_strategy`` was renamed ``eval_strategy`` in newer
    # transformers releases; use whichever name this installation accepts.
    accepted = {field.name for field in dataclasses.fields(TrainingArguments)}
    strategy_key = 'eval_strategy' if 'eval_strategy' in accepted else 'evaluation_strategy'
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        save_steps=1000,
        # Left unset, eval_steps defaults to logging_steps (10), i.e. a full
        # pass over the eval set every 10 training steps.
        eval_steps=1000,
        **{strategy_key: 'steps'},
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        # Causal-LM collator (mlm=False) copies input_ids into labels with pad
        # positions set to -100, so the model returns a loss to optimise.
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    )
    trainer.train()
    return "Model trained successfully!"
# Function to generate text
def generate_text(prompt, max_new_tokens=100):
    """Continue *prompt* with the model loaded by ``train_model``.

    Reads the module-level ``model`` and ``tokenizer`` globals, which only
    exist after ``train_model`` has run.

    Args:
        prompt: Text to continue.
        max_new_tokens: Number of tokens to generate after the prompt. This is
            counted separately from the prompt, so long prompts still get a
            continuation.

    Returns:
        The decoded prompt plus continuation, or a short hint message when no
        model is loaded yet or the prompt is empty.
    """
    try:
        lm, tok = model, tokenizer
    except NameError:
        # Generate was clicked before Train: the globals do not exist yet.
        return "Model is not loaded yet. Click 'Train Model' first."
    if not prompt:
        # An empty prompt encodes to zero tokens, which generation cannot start from.
        return "Please enter a prompt."
    # Disable dropout; the model may still be in training mode after Trainer.train().
    lm.eval()
    # Put inputs on the model's device (Trainer may have moved it to a GPU).
    inputs = tok(prompt, return_tensors='pt').to(lm.device)
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
    output = lm.generate(**inputs, max_new_tokens=max_new_tokens, pad_token_id=pad_id)
    return tok.decode(output[0], skip_special_tokens=True)
# Gradio interface: a training panel on the left, text generation on the right.
with gr.Blocks() as demo:
    gr.Markdown("# LLM Training and Text Generation")
    with gr.Row():
        with gr.Column():
            train_button = gr.Button("Train Model")
            output_message = gr.Textbox(label="Training Status", interactive=False)
        with gr.Column():
            prompt_input = gr.Textbox(label="Enter prompt for text generation")
            generate_button = gr.Button("Generate Text")
            generated_output = gr.Textbox(label="Generated Text", interactive=False)

    # Button actions (event listeners must be registered inside the Blocks context).
    train_button.click(train_model, outputs=output_message)
    generate_button.click(generate_text, inputs=prompt_input, outputs=generated_output)

# Launch the app only when run as a script, so importing this module does
# not start a web server as a side effect.
if __name__ == "__main__":
    demo.launch()