tomaarsen (HF staff) committed · verified
Commit ac307cc · 1 parent: 8b67887

Upload train_st_gooaq.py

Files changed (1):
  train_st_gooaq.py +87 -0
train_st_gooaq.py ADDED
@@ -0,0 +1,87 @@
# Copyright 2024 onwards Answer.AI, LightOn, and contributors
# License: Apache-2.0

import argparse

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import NanoBEIREvaluator
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

def main():
    # Parse the lr & model name
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=8e-5)
    parser.add_argument("--model_name", type=str, default="answerdotai/ModernBERT-base")
    args = parser.parse_args()
    lr = args.lr
    model_name = args.model_name
    model_shortname = model_name.split("/")[-1]

    # 1. Load a model to finetune
    model = SentenceTransformer(model_name)

    # 2. Load a dataset to finetune on
    dataset = load_dataset("sentence-transformers/gooaq", split="train")
    dataset_dict = dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset_dict["train"]
    eval_dataset = dataset_dict["test"]

    # 3. Define a loss function
    loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=128)  # Increase mini_batch_size if you have enough VRAM

    run_name = f"{model_shortname}-gooaq-{lr}"
    # 4. (Optional) Specify training arguments
    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=f"output/{model_shortname}/{run_name}",
        # Optional training parameters:
        num_train_epochs=1,
        per_device_train_batch_size=2048,
        per_device_eval_batch_size=2048,
        learning_rate=lr,
        warmup_ratio=0.05,
        fp16=False,  # Set to False if GPU can't handle FP16
        bf16=True,  # Set to True if GPU supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # (Cached)MultipleNegativesRankingLoss benefits from no duplicates
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=50,
        save_strategy="steps",
        save_steps=50,
        save_total_limit=2,
        logging_steps=10,
        run_name=run_name,  # Used in `wandb`, `tensorboard`, `neptune`, etc. if installed
    )

    # 5. (Optional) Create an evaluator & evaluate the base model
    dev_evaluator = NanoBEIREvaluator(dataset_names=["NQ", "MSMARCO"])
    dev_evaluator(model)

    # 6. Create a trainer & train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=dev_evaluator,
    )
    trainer.train()

    # 7. (Optional) Evaluate the trained model on the evaluator after training
    dev_evaluator(model)

    # 8. Save the model
    model.save_pretrained(f"output/{model_shortname}/{run_name}/final")

    # 9. (Optional) Push it to the Hugging Face Hub
    model.push_to_hub(run_name, private=False)

if __name__ == "__main__":
    main()
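
Usage note: both command-line flags are optional and fall back to the defaults defined in the script, so a default run is simply `python train_st_gooaq.py`, and a custom run looks like `python train_st_gooaq.py --lr 2e-5 --model_name answerdotai/ModernBERT-base`.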
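
Once training completes, the checkpoint saved in step 8 can be loaded back like any other Sentence Transformers model. Below is a minimal inference sketch, assuming a default run (lr=8e-5, so the f-strings above render the path shown); the question and candidate answers are purely illustrative, and `model.similarity` requires Sentence Transformers v3+:

from sentence_transformers import SentenceTransformer

# Load the finetuned checkpoint saved in step 8 (path for the default run)
model = SentenceTransformer("output/ModernBERT-base/ModernBERT-base-gooaq-8e-05/final")

# Embed a question and a few candidate answers
query_embedding = model.encode("how do I finetune a sentence embedding model?")
answer_embeddings = model.encode([
    "You can finetune embedding models with the SentenceTransformerTrainer and a suitable loss.",
    "GooAQ is a question-answer dataset collected from Google search.",
])

# Similarity between the query and each answer; higher means more relevant
similarities = model.similarity(query_embedding, answer_embeddings)
print(similarities)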