run regressor
run_regressor_bert.py  +64 -0
run_regressor_bert.py
ADDED
@@ -0,0 +1,64 @@
import torch
import argparse
import jsonlines
import os
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import Dataset
from tqdm import tqdm

def main(args):
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name, torch_dtype=torch.bfloat16)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Load local jsonlines file
    with jsonlines.open(args.input_file) as reader:
        data = [line for line in reader]

    # Convert list of dictionaries to dictionary of lists
    data_dict = {key: [d[key] for d in data] for key in data[0]}
    dataset = Dataset.from_dict(data_dict)

    # Check how many lines have already been written to the output file
    if os.path.exists(args.output_file):
        with open(args.output_file, 'r') as f:
            existing_lines = sum(1 for _ in f)
        print(f"Skipping {existing_lines} already processed lines.")
    else:
        existing_lines = 0

    # Skip already processed lines
    if existing_lines > 0:
        dataset = dataset.select(range(existing_lines, len(dataset)))

    def compute_scores(batch):
        inputs = tokenizer(batch[args.text_column], return_tensors="pt", padding="longest", truncation=True, max_length=args.max_length).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits.squeeze(-1).float().cpu().numpy()

        prefix = args.prefix
        batch[f"{prefix}_score"] = logits.tolist()
        batch[f"{prefix}_int_score"] = [int(round(max(0, min(score, 5)))) for score in logits]
        return batch

    # Process and write each batch incrementally
    with jsonlines.open(args.output_file, mode='a') as writer:
        for batch in tqdm(dataset.iter(batch_size=args.batch_size), total=(len(dataset) + args.batch_size - 1) // args.batch_size):
            processed_batch = compute_scores(batch)
            writer.write_all([dict(zip(batch.keys(), vals)) for vals in zip(*processed_batch.values())])

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_name", type=str, default="north/scandinavian_education_classifier_bert")
    parser.add_argument("--input_file", type=str, required=True, help="Path to the input jsonlines file")
    parser.add_argument("--output_file", type=str, required=True, help="Path to save the output jsonlines file")
    parser.add_argument("--text_column", type=str, default="text")
    parser.add_argument("--max_length", type=int, default=512, help="Maximum sequence length for tokenization")
    parser.add_argument("--batch_size", type=int, default=1024, help="Batch size for processing")
    parser.add_argument("--prefix", type=str, default="edu", help="Prefix for the score fields")

    args = parser.parse_args()
    main(args)
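For reference, a typical invocation might look like the following; the input and output paths are illustrative, and the model name shown is simply the script's default:

    python run_regressor_bert.py \
        --input_file data/documents.jsonl \
        --output_file data/documents_scored.jsonl \
        --model_name north/scandinavian_education_classifier_bert \
        --batch_size 1024

Each output line is the original JSON record plus two extra fields named after --prefix: edu_score (the raw regression logit) and edu_int_score (the logit rounded and clamped to the 0-5 range). Because the script counts existing lines in the output file and appends to it, rerunning the same command resumes where a previous run left off.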