pere commited on
Commit
11e7ab7
1 Parent(s): 29efbc9

run regressor

Browse files
Files changed (1) hide show
  1. run_regressor_bert.py +64 -0
run_regressor_bert.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import argparse
3
+ import jsonlines
4
+ import os
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
+ from datasets import Dataset
7
+ from tqdm import tqdm
8
+
9
+ def main(args):
10
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
11
+ model = AutoModelForSequenceClassification.from_pretrained(args.model_name, torch_dtype=torch.bfloat16)
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ model.to(device)
14
+
15
+ # Load local jsonlines file
16
+ with jsonlines.open(args.input_file) as reader:
17
+ data = [line for line in reader]
18
+
19
+ # Convert list of dictionaries to dictionary of lists
20
+ data_dict = {key: [d[key] for d in data] for key in data[0]}
21
+ dataset = Dataset.from_dict(data_dict)
22
+
23
+ # Check how many lines have already been written to the output file
24
+ if os.path.exists(args.output_file):
25
+ with open(args.output_file, 'r') as f:
26
+ existing_lines = sum(1 for _ in f)
27
+ print(f"Skipping {existing_lines} already processed lines.")
28
+ else:
29
+ existing_lines = 0
30
+
31
+ # Skip already processed lines
32
+ if existing_lines > 0:
33
+ dataset = dataset.select(range(existing_lines, len(dataset)))
34
+
35
+ def compute_scores(batch):
36
+ inputs = tokenizer(batch[args.text_column], return_tensors="pt", padding="longest", truncation=True, max_length=args.max_length).to(device)
37
+ with torch.no_grad():
38
+ outputs = model(**inputs)
39
+ logits = outputs.logits.squeeze(-1).float().cpu().numpy()
40
+
41
+ prefix = args.prefix
42
+ batch[f"{prefix}_score"] = logits.tolist()
43
+ batch[f"{prefix}_int_score"] = [int(round(max(0, min(score, 5)))) for score in logits]
44
+ return batch
45
+
46
+ # Process and write each batch incrementally
47
+ with jsonlines.open(args.output_file, mode='a') as writer:
48
+ for batch in tqdm(dataset.iter(batch_size=args.batch_size), total=(len(dataset) + args.batch_size - 1) // args.batch_size):
49
+ processed_batch = compute_scores(batch)
50
+ writer.write_all([dict(zip(batch.keys(), vals)) for vals in zip(*processed_batch.values())])
51
+
52
+ if __name__ == "__main__":
53
+ parser = argparse.ArgumentParser()
54
+
55
+ parser.add_argument("--model_name", type=str, default="north/scandinavian_education_classifier_bert")
56
+ parser.add_argument("--input_file", type=str, required=True, help="Path to the input jsonlines file")
57
+ parser.add_argument("--output_file", type=str, required=True, help="Path to save the output jsonlines file")
58
+ parser.add_argument("--text_column", type=str, default="text")
59
+ parser.add_argument("--max_length", type=int, default=512, help="Maximum sequence length for tokenization")
60
+ parser.add_argument("--batch_size", type=int, default=1024, help="Batch size for processing")
61
+ parser.add_argument("--prefix", type=str, default="edu", help="Prefix for the score fields")
62
+
63
+ args = parser.parse_args()
64
+ main(args)