# test-gpt / app.py
# SoulXS2's picture — Update app.py
# 2b9a0fe verified
from docx import Document
from transformers import GPT2LMHeadModel, GPT2Tokenizer, DataCollatorForLanguageModeling, Trainer, TrainingArguments
import torch
import gradio as gr
# Read data from DOCX file
def read_docx(file_path):
    """Return the full text of a Word document, one paragraph per line.

    Args:
        file_path: Path to a .docx file readable by python-docx.

    Returns:
        str: Every paragraph's text followed by a newline ("" for an
        empty document) — identical output to concatenating each
        paragraph with "\n" in a loop.
    """
    doc = Document(file_path)
    # join() builds the result in one pass instead of the quadratic
    # repeated string concatenation the original loop performed.
    return "".join(paragraph.text + "\n" for paragraph in doc.paragraphs)
# Load and tokenize your conversational dataset
# NOTE(review): the variable says DOCX but the path is "test.txt";
# python-docx's Document() cannot open a plain-text file, so this will
# fail unless test.txt is actually a renamed .docx — confirm the
# intended training file.
docx_file_path = "test.txt"
text_data = read_docx(docx_file_path)
# Tokenization
# Downloads the pretrained GPT-2 vocabulary/merges on first run.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Encodes the entire document as one id tensor (return_tensors="pt"
# yields a 2-D tensor with a leading batch dimension of 1).
tokenized_text = tokenizer.encode(text_data, return_tensors="pt")
# Define dataset
class TextDataset(torch.utils.data.Dataset):
    """Causal-LM dataset that chunks one long token sequence into blocks.

    The previous implementation indexed the raw ``(1, seq_len)`` tensor, so
    ``len()`` was 1 and the single "example" was the entire document — longer
    than GPT-2's 1024-position context, which breaks training. Splitting the
    ids into fixed-size windows yields many correctly sized examples while
    keeping the original ``TextDataset(tokenized_text)`` call signature.
    """

    def __init__(self, tokenized_text, block_size=512):
        # Keep the original attribute for backward compatibility.
        self.tokenized_text = tokenized_text
        # Flatten the (1, seq_len) batch tensor to a 1-D id sequence.
        ids = tokenized_text.reshape(-1)
        # Non-overlapping windows of at most `block_size` tokens; the last
        # window may be shorter (the collator pads per batch).
        self.examples = [
            ids[start:start + block_size]
            for start in range(0, len(ids), block_size)
        ]

    def __len__(self):
        # Number of training blocks, not rows of the raw tensor.
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]
# Materialize the dataset the Trainer below will iterate over.
dataset = TextDataset(tokenized_text)
# Define data collator
# mlm=False selects causal (next-token) language modeling — the correct
# objective for GPT-2; mlm=True would be BERT-style masked LM.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# Define training arguments
training_args = TrainingArguments(
    output_dir="./output",  # checkpoints and trainer state are written here
    overwrite_output_dir=True,  # replace any previous run's output
    num_train_epochs=3,
    per_device_train_batch_size=4,
    save_steps=10_000,  # checkpoint every 10,000 optimizer steps
    save_total_limit=2,  # keep only the 2 most recent checkpoints
)
# Initialize the model to fine-tune.
# FIX: `model` was referenced below (and in chatbot()) but never defined
# anywhere in the script, which raised NameError at import time.
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

# Fine-tune the model
trainer.train()

# Evaluate the fine-tuned model.
# FIX: evaluate() raises ValueError when the Trainer has no eval_dataset;
# reuse the training set so the call succeeds (the reported loss is
# training-set loss, not held-out loss).
trainer.evaluate(eval_dataset=dataset)
# Define the chatbot function
def chatbot(input_text):
    """Generate a model reply for one user message.

    Encodes the message, asks the fine-tuned GPT-2 model for a completion
    of at most 50 total tokens, and decodes the result back to plain text.
    """
    # Encode the prompt as a (1, seq_len) id tensor.
    prompt_ids = tokenizer.encode(input_text, return_tensors="pt")
    # Greedy generation; eos is used as the pad token since GPT-2 has none.
    generated = model.generate(
        prompt_ids,
        max_length=50,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Drop special tokens and return the decoded text.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# Create the Gradio interface: one text box in, one text box out.
chatbot_interface = gr.Interface(
    fn=chatbot,
    inputs="textbox",
    outputs="textbox",
    title="Chatbot",
)
# Launch the Gradio interface (blocks and serves the web app).
chatbot_interface.launch()