tomaarsen HF staff commited on
Commit
4fe3f1d
·
1 Parent(s): 35f487a

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +61 -0
train.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from transformers import TrainingArguments
3
+ from span_marker import SpanMarkerModel, Trainer
4
+
5
+
6
+ def main() -> None:
7
+ # Load the dataset, ensure "tokens" and "ner_tags" columns, and get a list of labels
8
+ dataset = load_dataset("acronym_identification").rename_column("labels", "ner_tags")
9
+ labels = dataset["train"].features["ner_tags"].feature.names
10
+
11
+ # Initialize a SpanMarker model using a pretrained BERT-style encoder
12
+ model_name = "bert-base-cased"
13
+ model = SpanMarkerModel.from_pretrained(
14
+ model_name,
15
+ labels=labels,
16
+ # SpanMarker hyperparameters:
17
+ model_max_length=256,
18
+ marker_max_length=128,
19
+ entity_max_length=8,
20
+ )
21
+
22
+ # Prepare the 🤗 transformers training arguments
23
+ args = TrainingArguments(
24
+ output_dir=f"models/span_marker_bert_base_acronyms",
25
+ run_name=f"bb_acronyms",
26
+ # Training Hyperparameters:
27
+ learning_rate=5e-5,
28
+ per_device_train_batch_size=32,
29
+ per_device_eval_batch_size=32,
30
+ num_train_epochs=2,
31
+ weight_decay=0.01,
32
+ warmup_ratio=0.1,
33
+ bf16=True, # Replace `bf16` with `fp16` if your hardware can't use bf16.
34
+ # Other Training parameters
35
+ logging_first_step=True,
36
+ logging_steps=50,
37
+ evaluation_strategy="steps",
38
+ save_strategy="steps",
39
+ eval_steps=200,
40
+ save_total_limit=2,
41
+ dataloader_num_workers=2,
42
+ )
43
+
44
+ # Initialize the trainer using our model, training args & dataset, and train
45
+ trainer = Trainer(
46
+ model=model,
47
+ args=args,
48
+ train_dataset=dataset["train"],
49
+ eval_dataset=dataset["validation"],
50
+ )
51
+ trainer.train()
52
+ trainer.save_model(f"models/span_marker_bert_base_acronyms/checkpoint-final")
53
+
54
+ # Compute & save the metrics on the test set
55
+ metrics = trainer.evaluate()
56
+ trainer.save_metrics("validation", metrics)
57
+ trainer.create_model_card()
58
+
59
+
60
+ if __name__ == "__main__":
61
+ main()