system HF staff committed on
Commit
8a9d09e
1 Parent(s): e25535a

Update train_script.py

Browse files
Files changed (1) hide show
  1. train_script.py +70 -0
train_script.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This example loads a previously saved sentence transformer model from the local path ../saved_models.
3
+ It then continues training this model for some epochs on the Quora dataset.
4
+ """
5
+ from torch.utils.data import DataLoader
6
+ import math
7
+ from sentence_transformers import SentenceTransformer, SentencesDataset, LoggingHandler, losses
8
+ from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
9
+ from sentence_transformers.readers import STSDataReader
10
+ import logging
11
+ from datetime import datetime
12
+
13
+
14
+ #### Just some code to print debug information to stdout
15
+ logging.basicConfig(format='%(asctime)s - %(message)s',
16
+ datefmt='%Y-%m-%d %H:%M:%S',
17
+ level=logging.INFO,
18
+ handlers=[LoggingHandler()])
19
+ #### /print debug information to stdout
20
+
21
+ # Training configuration and dataset reader
22
+ #model_name = 'bert-base-nli-stsb-mean-tokens'
23
+ model_name = "../saved_models"
24
+ train_batch_size = 32
25
+ num_epochs = 4
26
+ model_save_path = 'output/quora_continue_training-'+model_name+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
27
+ sts_reader = STSDataReader('../data/quora', normalize_scores=True, s1_col_idx=4, s2_col_idx=5, score_col_idx=6, max_score=1)
28
+
29
+ # Load a pre-trained sentence transformer model
30
+ model = SentenceTransformer(model_name)
31
+
32
+ # Convert the dataset to a DataLoader ready for training
33
+ logging.info("Read Quora train dataset")
34
+ train_data = SentencesDataset(sts_reader.get_examples('train.csv'), model)
35
+ train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
36
+ train_loss = losses.CosineSimilarityLoss(model=model)
37
+
38
+
39
+ logging.info("Read Quora dev dataset")
40
+ dev_data = SentencesDataset(examples=sts_reader.get_examples('dev.csv'), model=model)
41
+ dev_dataloader = DataLoader(dev_data, shuffle=False, batch_size=train_batch_size)
42
+ evaluator = EmbeddingSimilarityEvaluator(dev_dataloader)
43
+
44
+
45
+ # Configure the training. The model is evaluated on the dev set during training.
46
+ warmup_steps = math.ceil(len(train_data)*num_epochs/train_batch_size*0.1) #10% of train data for warm-up
47
+ logging.info("Warmup-steps: {}".format(warmup_steps))
48
+
49
+
50
+ # Train the model
51
+ model.fit(train_objectives=[(train_dataloader, train_loss)],
52
+ evaluator=evaluator,
53
+ epochs=num_epochs,
54
+ evaluation_steps=1000,
55
+ warmup_steps=warmup_steps,
56
+ output_path=model_save_path)
57
+
58
+
59
+ ##############################################################################
60
+ #
61
+ # Load the stored model and evaluate its performance on STS benchmark dataset
62
+ #
63
+ ##############################################################################
64
+ #
65
+ # model = SentenceTransformer(model_save_path)
66
+ # test_data = SentencesDataset(examples=sts_reader.get_examples("sts-test.csv"), model=model)
67
+ # test_dataloader = DataLoader(test_data, shuffle=False, batch_size=train_batch_size)
68
+ # evaluator = EmbeddingSimilarityEvaluator(test_dataloader)
69
+ # model.evaluate(evaluator)
70
+