import os

import torch
from transformers import BertTokenizer, BertForSequenceClassification
# Tokenizer and model initialization.
# The tokenizer always comes from the base `bert-base-uncased` checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Directory expected to hold a fine-tuned classifier (e.g. written with
# `save_pretrained`).
MODEL_DIR = "./saved_model"

if os.path.isdir(MODEL_DIR):
    # Load only the fine-tuned weights. Loading the base model first and then
    # overwriting it would just be a wasted download/initialization.
    model = BertForSequenceClassification.from_pretrained(MODEL_DIR)
else:
    # No fine-tuned model on disk: fall back to base BERT with a 2-way
    # classification head. That head is newly initialized (transformers warns
    # about this), so predictions are not meaningful until the model is trained.
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
# Predicting Function | |
def predict(text): | |
inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt") | |
outputs = model(**inputs) | |
predictions = torch.argmax(outputs.logits, dim=-1) | |
return "AI-generated" if predictions.item() == 1 else "Human-written" | |
# Script entry point: prompt for a piece of text and print its predicted class.
# Guarded so that importing this module (e.g. to reuse `predict`) does not
# block waiting for keyboard input.
if __name__ == "__main__":
    user_input = input("Enter the text you want to classify: ")
    print("Classified as:", predict(user_input))