Upload 8 files
Browse files
- Makefile +26 -0
- README.md +60 -0
- build_dataset.py +81 -0
- config.yaml +11 -0
- environment.yaml +16 -0
- fine-tuning.py +73 -0
- inference.py +57 -0
- utils.py +57 -0
Makefile
ADDED
@@ -0,0 +1,26 @@
.PHONY: data train eval inference run clean zip

data:
	@echo "Creating dataset from google/sentence_compression.."
	python -m build_dataset

train:
	@echo "Training google/t5-small model for sentence compression.."
	python -m fine-tuning

eval:
	@echo "Evaluating on the test set.."
	python -m utils

inference:
	@echo "Performing model inference on evaluation data.."
	python -m inference

run: clean data train eval inference

clean:
	@find . -name "*.pyc" -exec rm {} \;
	@rm -rf dataset/preprocessed/* checkpoints/* results/*

zip:
	@tar --exclude=".[^/]*" -czvf "AnshuKumar-RingCentral-$(shell date +"%Y%m%d").tar.gz" *
README.md
ADDED
@@ -0,0 +1,60 @@
## Getting Started

### Installation

1. Create the conda environment and activate it:
```
conda env create --name NAME --file=environment.yaml
conda activate NAME
```
The project is organized around several scripts that follow a typical machine learning workflow: data preparation, model training, evaluation, and inference. The `google/t5-small` model is fine-tuned on the sentence-compression dataset for `10` epochs. Inference is then run on the evaluation data, and the performance metrics and evaluation results are stored in the `results` subdirectory of the project directory.

I added a Makefile that can be used to run the Python scripts separately with the following commands.

```bash
make data
make train
make eval
make inference
```

`run` is a make target that runs the entire project end to end.

```bash
make run
```

`clean` is a make target that removes the artifacts of previous runs.

```bash
make clean
```

Performance metrics are stored in the `performance.json` file inside the `results` directory.

```json
{
    "rouge1": 0.79689240266461,
    "rouge2": 0.7606140631154827,
    "rougeL": 0.7733855633904199,
    "rougeLsum": 0.7734703253159519
}
```
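These scores are produced with the `evaluate` library's ROUGE metric, the same call used in `inference.py`. A minimal sketch of the computation, with made-up strings standing in for real predictions and references:

```python
import evaluate

# Load the ROUGE metric (backed by the rouge_score package pinned in environment.yaml).
rouge = evaluate.load("rouge")

# Hypothetical prediction/reference pair, purely to illustrate the call.
results = rouge.compute(
    predictions=["the cat sat on the mat"],
    references=["a cat sat on the mat"],
)
print(results)  # {'rouge1': ..., 'rouge2': ..., 'rougeL': ..., 'rougeLsum': ...}
```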
The predictions on the evaluation file are also written to `eval_result.csv`:

| original  | compressed | predictions |
|-----------|------------|-------------|
| sentence1 | compress1  | prediction1 |
| sentence2 | compress2  | prediction2 |
| :         | :          | :           |
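A quick way to inspect these predictions, assuming the default `results` output directory from `config.yaml`:

```python
import pandas as pd

# Load the predictions written by inference.py.
df = pd.read_csv("results/eval_result.csv")
print(df.head())
```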
### References:
1. https://github.com/google-research-datasets/sentence-compression
2. https://huggingface.co/docs/transformers/en/tasks/summarization

### Note:
Download the trained checkpoint from the following drive link: [checkpoint](https://drive.google.com/drive/folders/1yrl0VtmM9BtT4aU2Z5vLs6doz35MMxvM?usp=drive_link)
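Once downloaded, the checkpoint can be used directly with the same `summarization` pipeline that `inference.py` builds. A minimal sketch, assuming the checkpoint is extracted to a hypothetical `checkpoints/checkpoint-XXXX` directory:

```python
from transformers import pipeline

# Hypothetical path to the extracted checkpoint directory.
summarizer = pipeline("summarization", model="checkpoints/checkpoint-XXXX")

# Match the training prefix used in fine-tuning.py.
prompt = "summarize the following sentence: "
sentence = "The quick brown fox jumped over the lazy dog sleeping in the garden."
print(summarizer(prompt + sentence)[0]["summary_text"])
```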
build_dataset.py
ADDED
@@ -0,0 +1,81 @@
import csv
import glob
import gzip
import json
import os
import shutil

import yaml
from git import Repo


fieldnames = ['original', 'compressed']

def to_csv_record(writer, buffer):
    # Each record in the raw files is a blank-line-separated JSON object.
    record = json.loads(buffer)
    writer.writerow(dict(
        original=record['graph']['sentence'],
        compressed=record['compression']['text']))

def build_dataset(rawdata_dir, preprocessed_data_dir):
    print("Data Preparation...")
    os.makedirs(preprocessed_data_dir, exist_ok=True)
    with open(os.path.join(preprocessed_data_dir, 'training_data.csv'), 'w') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for rawdata_file in glob.glob(f'{rawdata_dir}/data/*train*.json'):
            with open(rawdata_file) as raw_contents:
                buffer = ''
                for line in raw_contents:
                    if line.strip() == '':
                        to_csv_record(writer, buffer)
                        buffer = ''
                    else:
                        buffer += line
                if len(buffer) > 0:
                    to_csv_record(writer, buffer)

    with open(os.path.join(preprocessed_data_dir, 'eval_data.csv'), 'w') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        with open(f'{rawdata_dir}/data/comp-data.eval.json') as raw_contents:
            buffer = ''
            for line in raw_contents:
                if line.strip() == '':
                    to_csv_record(writer, buffer)
                    buffer = ''
                else:
                    buffer += line
            if len(buffer) > 0:
                to_csv_record(writer, buffer)

def decompress_rawdata(rawdata_dir):
    print("Decompression...")
    compressed_files = glob.glob(rawdata_dir + "/data/*.json.gz")
    for compressed_file_path in compressed_files:
        # Strip the .gz suffix to get the output filename.
        output_file_path = os.path.splitext(compressed_file_path)[0]
        with gzip.open(compressed_file_path, 'rb') as comp_file:
            compressed_content = comp_file.read()
        with open(output_file_path, 'wb') as output_file:
            output_file.write(compressed_content)
        os.remove(compressed_file_path)

def download_rawdata(git_url, rawdata_dir):
    os.makedirs(rawdata_dir, exist_ok=True)
    print("Data Cloning...")
    current_dir = os.getcwd()
    try:
        os.chdir(rawdata_dir)
        Repo.clone_from(git_url, '.')
    except Exception as e:
        print("Error:", e)
    finally:
        os.chdir(current_dir)

if __name__ == "__main__":
    config = yaml.safe_load(open("config.yaml", "r"))
    # PROJECT_DIR is stored in config.yaml as a Python expression (os.getcwd()).
    PROJECT_DIR = eval(config["SENTENCE_COMPRESSION"]["PROJECT_DIR"])
    rawdata_git = config["SENTENCE_COMPRESSION"]["DATA"]["RAW_DATA"]
    preprocessed_data_dir = os.path.join(PROJECT_DIR, config["SENTENCE_COMPRESSION"]["DATA"]["CLEAN_DATA"])
    rawdata_dir = os.path.join(PROJECT_DIR, config["SENTENCE_COMPRESSION"]["DATA"]["RAW_DIR"])
    download_rawdata(rawdata_git, rawdata_dir)
    decompress_rawdata(rawdata_dir)
    build_dataset(rawdata_dir, preprocessed_data_dir)
    # Remove the cloned raw data once the CSVs are built.
    shutil.rmtree(rawdata_dir)
config.yaml
ADDED
@@ -0,0 +1,11 @@
SENTENCE_COMPRESSION:
  PROJECT_DIR: os.getcwd()
  DATA:
    RAW_DATA: https://github.com/google-research-datasets/sentence-compression.git
    RAW_DIR: dataset/rawdata
    CLEAN_DATA: dataset/preprocessed
  TRAINING:
  INFERENCE:
    MODEL_PATH: checkpoints
  OUTPUT:
    RESULT: results
environment.yaml
ADDED
@@ -0,0 +1,16 @@
name: compression
dependencies:
  - python=3.10.10
  - pip=23.3.1
  - pip:
      - transformers==4.37.2
      - seaborn==0.13.2
      - scikit-learn==1.4.0
      - pandas==2.2.1
      - GitPython==3.1.43
      - torch==2.2.2
      - evaluate==0.4.2
      - accelerate==0.27.0
      - absl-py==2.1.0
      - nltk==3.8.1
      - rouge_score==0.1.2
fine-tuning.py
ADDED
@@ -0,0 +1,73 @@
import os

import evaluate
import numpy as np
import pandas as pd
import yaml
from sklearn.model_selection import train_test_split
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)


checkpoint = "google-t5/t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

prefix = "summarize the following sentence: "
def preprocess_function(examples):
    # Tokenize the prefixed source sentence and the compressed target.
    inputs = prefix + examples["original"]
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True)
    labels = tokenizer(text_target=examples["compressed"], max_length=128, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 (positions ignored by the loss) with the pad token id before decoding.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    rouge = evaluate.load("rouge")
    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

def main():
    print("Data Loading...")
    config = yaml.safe_load(open("config.yaml", "r"))
    # PROJECT_DIR is stored in config.yaml as a Python expression (os.getcwd()).
    PROJECT_DIR = eval(config["SENTENCE_COMPRESSION"]["PROJECT_DIR"])
    data_dir = os.path.join(PROJECT_DIR, config["SENTENCE_COMPRESSION"]["DATA"]["CLEAN_DATA"])
    data = pd.read_csv(os.path.join(data_dir, 'training_data.csv'))
    print("Tokenization started...")
    data_preprocessed = data.apply(preprocess_function, axis=1)
    print("Test data preprocessing...")
    train_tokenized, test_tokenized = train_test_split(data_preprocessed, test_size=0.2)
    print("Model Loading...")
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    # Pass the model object (not the checkpoint string) so the collator can
    # prepare decoder input ids from the labels.
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
    training_args = Seq2SeqTrainingArguments(
        output_dir="checkpoints",
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=4,
        weight_decay=0.01,
        save_total_limit=3,
        num_train_epochs=10,
        predict_with_generate=True,
        fp16=True,  # requires a CUDA-capable GPU
        push_to_hub=False,
    )
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_tokenized.tolist(),
        eval_dataset=test_tokenized.tolist(),
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )
    trainer.train()

if __name__ == "__main__":
    main()
inference.py
ADDED
@@ -0,0 +1,57 @@
import json
import os

import evaluate
import pandas as pd
import yaml
from transformers import pipeline


def load_pipeline(model_path):
    # device=0 places the pipeline on the first GPU.
    summarizer = pipeline("summarization", model=model_path, device=0)
    return summarizer

def run_inference(summarizer, eval_data):
    # Match the training prefix used in fine-tuning.py (note the trailing space).
    prompt = "summarize the following sentence: "
    sentences = eval_data['original'].tolist()
    compressed = eval_data['compressed'].tolist()
    predictions = []
    for sent in sentences:
        out = summarizer(prompt + sent)
        predictions.append(out[0]['summary_text'])
    return {"original": sentences, "compressed": compressed, "predictions": predictions}

def compute_performance(eval_data):
    rouge = evaluate.load('rouge')
    predictions = eval_data['predictions']
    references = eval_data['compressed']
    # Compute the ROUGE score of the predictions against the reference compressions.
    results = rouge.compute(predictions=predictions, references=references)
    print(results)
    return results

def get_latest_checkpoint(checkpoint_dir):
    # Checkpoint directories are named "checkpoint-<step>"; pick the highest step.
    subdirs = [name for name in os.listdir(checkpoint_dir)
               if os.path.isdir(os.path.join(checkpoint_dir, name)) and name.startswith("checkpoint-")]
    checkpoint_numbers = [int(subdir.split("-")[1]) for subdir in subdirs]
    latest_checkpoint = "checkpoint-" + str(max(checkpoint_numbers))
    return latest_checkpoint

if __name__ == "__main__":
    config = yaml.safe_load(open("config.yaml", "r"))
    # PROJECT_DIR is stored in config.yaml as a Python expression (os.getcwd()).
    PROJECT_DIR = eval(config["SENTENCE_COMPRESSION"]["PROJECT_DIR"])
    data_dir = os.path.join(PROJECT_DIR, config["SENTENCE_COMPRESSION"]["DATA"]["CLEAN_DATA"])
    model_checkpoint = config["SENTENCE_COMPRESSION"]["INFERENCE"]["MODEL_PATH"]
    latest_checkpoint = get_latest_checkpoint(os.path.join(PROJECT_DIR, model_checkpoint))
    model_path = os.path.join(PROJECT_DIR, model_checkpoint, latest_checkpoint)
    summarizer = load_pipeline(model_path)
    eval_data = pd.read_csv(os.path.join(data_dir, 'eval_data.csv'))
    eval_data_res = run_inference(summarizer, eval_data)
    output_dir = os.path.join(PROJECT_DIR, config["SENTENCE_COMPRESSION"]["OUTPUT"]["RESULT"])
    os.makedirs(output_dir, exist_ok=True)
    eval_res_df = pd.DataFrame(eval_data_res)
    eval_res_df.to_csv(os.path.join(output_dir, "eval_result.csv"), index=False)
    result = compute_performance(eval_data_res)
    json.dump(result, open(os.path.join(output_dir, "performance.json"), "w"), indent=4)
utils.py
ADDED
@@ -0,0 +1,57 @@
import json
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import yaml

from inference import get_latest_checkpoint


def process_loss(loss, final_loss):
    epoch = int(loss["epoch"])
    final_loss["epoch"].append(epoch)
    # Training log entries carry "loss"; evaluation entries carry the "eval_*" keys.
    for key in ["loss", "eval_loss", "eval_rouge1", "eval_rouge2"]:
        try:
            value = loss[key]
            final_loss[key].append(value)
        except KeyError:
            pass

def loss_function(losses):
    final_loss = {
        "epoch": [],
        "loss": [],
        "eval_loss": [],
        "eval_rouge1": [],
        "eval_rouge2": []
    }
    for loss_steps in losses:
        # Keep only the entries logged at whole epochs.
        if float(loss_steps.get("epoch", 0)) % 1 == 0:
            process_loss(loss_steps, final_loss)
    # Each epoch is logged twice (train and eval); deduplicate and keep epochs ordered.
    final_loss["epoch"] = sorted(set(final_loss["epoch"]))
    return final_loss

def plot_loss(data, output_dir):
    df = pd.DataFrame(data)
    # Melt to long format so seaborn draws one line per metric.
    df_melted = pd.melt(df, id_vars=['epoch'], var_name='metric', value_name='value')
    plt.figure(figsize=(10, 6))
    sns.lineplot(data=df_melted, x='epoch', y='value', hue='metric', marker='o')
    plt.legend(title='Metric')
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.title('Metrics vs Epoch')
    plt.savefig(os.path.join(output_dir, 'metrics_vs_epoch.png'))


if __name__ == "__main__":
    config = yaml.safe_load(open("config.yaml", "r"))
    # PROJECT_DIR is stored in config.yaml as a Python expression (os.getcwd()).
    PROJECT_DIR = eval(config["SENTENCE_COMPRESSION"]["PROJECT_DIR"])
    checkpoint_dir = config["SENTENCE_COMPRESSION"]["INFERENCE"]["MODEL_PATH"]
    latest_checkpoint = get_latest_checkpoint(os.path.join(PROJECT_DIR, checkpoint_dir))
    logfile_dir = os.path.join(PROJECT_DIR, checkpoint_dir, latest_checkpoint)
    logfile_path = os.path.join(logfile_dir, "trainer_state.json")
    logs = json.load(open(logfile_path))
    final_loss = loss_function(logs["log_history"])
    output_dir = config["SENTENCE_COMPRESSION"]["OUTPUT"]["RESULT"]
    os.makedirs(output_dir, exist_ok=True)
    plot_loss(final_loss, output_dir)