# test-gpt / app.py
# Source: Hugging Face Space by SoulXS2 — "Update app.py", commit 4a0e670 (verified), 1.68 kB
import gradio as gr
import torch
from docx import Document
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    TextDataset,
    Trainer,
    TrainingArguments,
)
# --- Load base model/tokenizer and fine-tune GPT-2 on a local text dataset ---
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# Load and tokenize the conversational dataset once (the original loaded it twice).
# TextDataset requires `block_size` (max tokens per training example) — without it
# the constructor raises a TypeError.
dataset = TextDataset(
    tokenizer=tokenizer,
    file_path="your_dataset.txt",
    block_size=128,
)

# Causal-LM collator: mlm=False means next-token prediction, not masked LM.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Training configuration: 3 epochs, batch size 4 per device, checkpoint every
# 10k steps keeping at most 2 checkpoints on disk.
training_args = TrainingArguments(
    output_dir="./output",
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    save_steps=10_000,
    save_total_limit=2,
)

# NOTE(review): eval_dataset reuses the training set so that trainer.evaluate()
# below does not raise ValueError ("no eval_dataset"). Supply a held-out split
# here for a meaningful evaluation.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
    eval_dataset=dataset,
)

# Fine-tune, then report loss/perplexity metrics on the eval set.
trainer.train()
trainer.evaluate()
# Define the chatbot function
def chatbot(input_text):
    """Generate a reply to *input_text* using the fine-tuned GPT-2 model.

    Encodes the prompt, generates up to 50 tokens, and returns the decoded
    text (the prompt plus the model's continuation).
    """
    # Encode the prompt into a batch of token ids.
    encoded = tokenizer.encode(input_text, return_tensors="pt")
    # GPT-2 defines no pad token, so reuse EOS for padding during generation.
    generated = model.generate(
        encoded,
        max_length=50,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Convert the first (only) sequence back to text, dropping special tokens.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# Wire the chatbot function into a simple text-in / text-out Gradio UI
# and serve it.
chatbot_interface = gr.Interface(
    fn=chatbot,
    inputs="textbox",
    outputs="textbox",
    title="Chatbot",
)
chatbot_interface.launch()