jaynopponep committed
Commit 500e117 · 1 Parent(s): 2553f09

Turning model.py into MVP for today.

Files changed (2):
  1. .idea/.name +1 -0
  2. model.py +6 -73
.idea/.name ADDED
@@ -0,0 +1 @@
+ model.py
model.py CHANGED
@@ -1,88 +1,21 @@
- import pandas as pd
  import torch
- from sklearn.model_selection import train_test_split
- from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer
+ from transformers import BertTokenizer, BertForSequenceClassification
  
- # Read the dataset
- df = pd.read_csv('Training_Essay_Data.csv')  # Make sure the file name is correct
- 
- # Splitting the dataset
- train_df, eval_df = train_test_split(df, test_size=0.1)
- 
- # Tokenizer
+ # Tokenizer and Model Initialization
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
- 
- 
- # Tokenize function
- def tokenize_function(examples):
-     return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)
- 
- 
- # Tokenize the dataset
- train_encodings = tokenize_function(train_df)
- eval_encodings = tokenize_function(eval_df)
- 
- 
- # Essay dataset class
- class EssayDataset(torch.utils.data.Dataset):
-     def __init__(self, encodings, labels):
-         self.encodings = encodings
-         self.labels = labels
- 
-     def __getitem__(self, idx):
-         item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
-         item['labels'] = torch.tensor(int(self.labels[idx]))
-         return item
- 
-     def __len__(self):
-         return len(self.labels)
- 
- 
- # Dataset preparation
- train_dataset = EssayDataset(train_encodings, train_df['label'].tolist())
- eval_dataset = EssayDataset(eval_encodings, eval_df['label'].tolist())
- 
- # Model
  model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
  
- # Training arguments
- training_args = TrainingArguments(
-     output_dir='./results',
-     num_train_epochs=3,
-     per_device_train_batch_size=16,
-     per_device_eval_batch_size=64,
-     warmup_steps=500,
-     weight_decay=0.01,
-     logging_dir='./logs',
-     evaluation_strategy="epoch"
- )
- 
- # Trainer
- trainer = Trainer(
-     model=model,
-     args=training_args,
-     train_dataset=train_dataset,
-     eval_dataset=eval_dataset
- )
- 
- # Train the model
- trainer.train()
- 
- # Save the model
- model.save_pretrained("./saved_model")
- 
- # Load the model for prediction
+ # Load the model (Assuming it's already trained and saved in "./saved_model")
+ # If you don't have a trained model, comment out this line. The code will use the default BERT model
  model = BertForSequenceClassification.from_pretrained("./saved_model")
  
- 
- # Predicting
+ # Predicting Function
  def predict(text):
      inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
      outputs = model(**inputs)
      predictions = torch.argmax(outputs.logits, dim=-1)
      return "AI-generated" if predictions.item() == 1 else "Human-written"
  
- 
  # Get user input and predict
  user_input = input("Enter the text you want to classify: ")
- print("Classified as:", predict(user_input))
+ print("Classified as:", predict(user_input))
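
The added comments ask the reader to comment out the "./saved_model" load by hand when no fine-tuned checkpoint exists. A minimal sketch of making that fallback automatic, not part of this commit, assuming the same file layout:

import os
from transformers import BertTokenizer, BertForSequenceClassification

# Sketch only (not in the commit): same initialization as the MVP.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# Load the fine-tuned checkpoint only if it exists on disk; otherwise keep
# the freshly initialized bert-base-uncased model (its classification head
# is untrained, so predictions will not be meaningful yet).
if os.path.isdir("./saved_model"):
    model = BertForSequenceClassification.from_pretrained("./saved_model")

Either way, "./saved_model" is produced by the training pipeline this commit removes (the pandas/Trainer code above), which would need to be run separately before the MVP can classify text meaningfully.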