Upload 4 files
upload testset inference
- final-testset.ipynb +200 -0
- nrms_model.epoch0.step10001.pth +3 -0
- nrms_model.epoch0.step20001.pth +3 -0
- testset.py +198 -0
final-testset.ipynb
ADDED
@@ -0,0 +1,200 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 35,
+   "id": "0562d8a6-e8e3-4659-ab21-e99d76adcf3c",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "35982,9796527,0.12911629676818848\n",
+      "\n",
+      "35982,9796527,0.12911629676818848\n",
+      "\n",
+      "35982,9796527,0.12911629676818848\n",
+      "\n",
+      "35982,9796527,0.12911629676818848\n",
+      "\n",
+      "35982,9796527,0.12911629676818848\n",
+      "\n",
+      "35982,9796527,0.12911629676818848\n",
+      "\n",
+      "35982,9796527,0.12911629676818848\n",
+      "\n",
+      "35982,9796527,0.12911629676818848\n",
+      "\n",
+      "35982,9796527,0.12911629676818848\n",
+      "\n",
+      "35982,9796527,0.12911629676818848\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "for i in range(10):\n",
+    "    with open(\"test_set.txt\") as f:\n",
+    "        print(f.readline())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "9e72123e-5a81-4fd1-a07b-f847aee5a590",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "test_behavior_path = \"/work/Blue/ebnerd/ebnerd_testset/test/behaviors.parquet\"\n",
+    "\n",
+    "import polars as pl\n",
+    "\n",
+    "test_behavior_df = pl.read_parquet(test_behavior_path)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 95,
+   "id": "7c337f1c-8a0e-4a61-9916-0c86887f320e",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|██████████| 13536710/13536710 [18:13<00:00, 12380.33it/s]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Zipping predictions.txt to predictions.zip\n"
+     ]
+    }
+   ],
+   "source": [
+    "from tqdm import tqdm\n",
+    "import numpy as np\n",
+    "from pathlib import Path\n",
+    "import zipfile\n",
+    "\n",
+    "\n",
+    "def transform_list(input_list):\n",
+    "    # Convert the input list to a NumPy array.\n",
+    "    arr = np.array(input_list)\n",
+    "\n",
+    "    # Get the indices that sort the array in descending order.\n",
+    "    sorted_indices = np.argsort(-arr)\n",
+    "\n",
+    "    # Assign ranks (starting from 1).\n",
+    "    ranks = np.empty_like(sorted_indices)\n",
+    "    ranks[sorted_indices] = np.arange(1, len(arr) + 1)\n",
+    "\n",
+    "    return ranks.tolist()\n",
+    "\n",
+    "def zip_submission_file(\n",
+    "    path: Path,\n",
+    "    filename_zip: str = None,\n",
+    "    verbose: bool = True,\n",
+    "    rm_file: bool = True,\n",
+    ") -> None:\n",
+    "    \"\"\"\n",
+    "    Compresses a specified file into a ZIP archive within the same directory.\n",
+    "\n",
+    "    Args:\n",
+    "        path (Path): The path of the file to be compressed; the ZIP archive is created in the same directory.\n",
+    "        filename_zip (str, optional): The name of the output ZIP file. Defaults to the input\n",
+    "            filename with a \".zip\" suffix.\n",
+    "        verbose (bool, optional): If set to True, the function will print the process details. Defaults to True.\n",
+    "        rm_file (bool, optional): If set to True, the original file will be removed after compression. Defaults to True.\n",
+    "\n",
+    "    Returns:\n",
+    "        None: This function does not return any value.\n",
+    "    \"\"\"\n",
+    "    path = Path(path)\n",
+    "    if filename_zip:\n",
+    "        path_zip = path.parent.joinpath(filename_zip)\n",
+    "    else:\n",
+    "        path_zip = path.with_suffix(\".zip\")\n",
+    "\n",
+    "    if path_zip.suffix != \".zip\":\n",
+    "        raise ValueError(f\"suffix for {path_zip.name} has to be '.zip'\")\n",
+    "    if verbose:\n",
+    "        print(f\"Zipping {path} to {path_zip}\")\n",
+    "    f = zipfile.ZipFile(path_zip, \"w\", zipfile.ZIP_DEFLATED)\n",
+    "    f.write(path, arcname=path.name)\n",
+    "    f.close()\n",
+    "    if rm_file:\n",
+    "        path.unlink()\n",
+    "\n",
+    "with open(\"predictions.txt\", 'w') as wf:\n",
+    "    with open(\"test_set.txt\", 'r') as f:\n",
+    "        behaviors_iter = test_behavior_df.select(\"impression_id\", \"user_id\", \"article_ids_inview\").iter_rows()\n",
+    "        index = 0\n",
+    "        for data in tqdm(behaviors_iter, total=len(test_behavior_df)):\n",
+    "            impression_id = data[0]\n",
+    "            user_id = data[1]\n",
+    "            article_ids_inview = data[2]\n",
+    "\n",
+    "            scores = []\n",
+    "\n",
+    "            for article_id in article_ids_inview:\n",
+    "                preds = f.readline().split(\",\")\n",
+    "\n",
+    "                p_user_id = preds[0]\n",
+    "                p_article_id = preds[1]\n",
+    "                p_score = preds[2]\n",
+    "\n",
+    "                if str(article_id) == str(p_article_id):\n",
+    "                    scores.append(float(p_score))\n",
+    "                else:\n",
+    "                    print(\"Different 0.0\")\n",
+    "                    scores.append(float(0.0))\n",
+    "\n",
+    "            index_ranked = transform_list(scores)\n",
+    "            preds = \"[\" + \",\".join([str(ir) for ir in index_ranked]) + \"]\"\n",
+    "\n",
+    "            wf.write(\" \".join([str(impression_id), preds]) + \"\\n\")\n",
+    "\n",
+    "zip_submission_file(path=Path(\"predictions.txt\"), rm_file=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2d5c1bcc-e4b0-4217-93ec-4ca3e24dc6ab",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "blue",
+   "language": "python",
+   "name": "blue"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
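
For reference, the transform_list helper above turns raw scores into the dense 1-based ranks that each submission line encodes. A minimal standalone check; the scores below are illustrative, not values from the actual run:

import numpy as np

def transform_list(input_list):
    # The highest score receives rank 1, and so on in descending order.
    arr = np.array(input_list)
    sorted_indices = np.argsort(-arr)
    ranks = np.empty_like(sorted_indices)
    ranks[sorted_indices] = np.arange(1, len(arr) + 1)
    return ranks.tolist()

print(transform_list([0.12, 0.87, 0.05, 0.33]))  # -> [3, 1, 4, 2]
# The corresponding submission line would be: "<impression_id> [3,1,4,2]"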
nrms_model.epoch0.step10001.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0fb936151025cc80032c0b77f62c4c9264dc4bcea4dcbd1cfe9b5ff1c9b2f5c7
+size 324331194
nrms_model.epoch0.step20001.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0785e0dc9d9b39b97988b06bbd2135bc1f1065aab60de50717ada15d4e10e6a4
+size 324331194
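
Once the LFS objects are fetched, the downloaded checkpoints can be checked against the pointers above. A minimal sketch using only the oid and size fields recorded in this commit:

import hashlib
import os

def lfs_oid(path, chunk_size=1 << 20):
    # Git LFS oids are the SHA-256 of the file contents, streamed here in chunks.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while block := f.read(chunk_size):
            h.update(block)
    return h.hexdigest()

path = "nrms_model.epoch0.step20001.pth"
print(os.path.getsize(path))  # expected: 324331194
print(lfs_oid(path))          # expected: 0785e0dc9d9b39b97988b06bbd2135bc1f1065aab60de50717ada15d4e10e6a4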
testset.py
ADDED
@@ -0,0 +1,198 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+# In[1]:
+
+
+import os
+from datetime import datetime
+
+from pathlib import Path
+
+import polars as pl
+import torch
+from transformers import AutoModel, AutoTokenizer
+from transformers import Trainer, TrainingArguments
+from accelerate import Accelerator, DistributedType
+from torch.optim import AdamW
+from torch.utils.data import DataLoader
+
+from utils._constants import *
+from utils._nlp import get_transformers_word_embeddings
+from utils._polars import concat_str_columns, slice_join_dataframes
+from utils._articles import (
+    convert_text2encoding_with_transformers,
+    create_article_id_to_value_mapping
+)
+from utils._python import make_lookup_objects
+from utils._behaviors import (
+    create_binary_labels_column,
+    sampling_strategy_wu2019,
+    truncate_history,
+)
+from utils._articles_behaviors import map_list_article_id_to_value
+from dataset.pytorch_dataloader import (
+    ebnerd_from_path,
+    NRMSDataset,
+    NewsrecDataset,
+)
+from evaluation import (
+    MetricEvaluator,
+    AucScore,
+    NdcgScore,
+    MrrScore,
+    F1Score,
+    LogLossScore,
+    RootMeanSquaredError,
+    AccuracyScore
+)
+from models.nrms import NRMSModel
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+# In[2]:
+
+
+TEST_DATA_PATH = "merged_0412_final.parquet"
+
+
+# In[3]:
+
+
+test_df = pl.read_parquet(TEST_DATA_PATH).with_columns(pl.Series("labels", [[]]))
+
+
+# In[4]:
+
+
+from transformers import AutoModel, AutoTokenizer
+
+model_name = "Maltehb/danish-bert-botxo"
+model = AutoModel.from_pretrained(model_name)
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+word2vec_embedding = get_transformers_word_embeddings(model)
+
+
+# In[5]:
+
+
+ARTICLES_DATA_PATH = "/work/Blue/ebnerd/ebnerd_testset/articles.parquet"
+ARTICLE_COLUMNS = [DEFAULT_TITLE_COL, DEFAULT_SUBTITLE_COL]
+TEXT_MAX_LENGTH = 30
+
+articles_df = pl.read_parquet(ARTICLES_DATA_PATH)
+df_articles, cat_col = concat_str_columns(articles_df, columns=ARTICLE_COLUMNS)
+df_articles, token_col_title = convert_text2encoding_with_transformers(
+    df_articles, tokenizer, cat_col, max_length=TEXT_MAX_LENGTH
+)
+article_mapping = create_article_id_to_value_mapping(df=df_articles, value_col=token_col_title)
+
+
+# In[6]:
+
+
+from dataclasses import dataclass, field
+import numpy as np
+
+@dataclass
+class NRMSTestDataset(NewsrecDataset):
+    def __post_init__(self):
+        """
+        Post-initialization method. Loads the data and sets additional attributes.
+        """
+        self.lookup_article_index = {id: i for i, id in enumerate(self.article_dict, start=1)}
+        self.lookup_article_matrix = np.array(list(self.article_dict.values()))
+        UNKNOWN_ARRAY = np.zeros(self.lookup_article_matrix.shape[1], dtype=self.lookup_article_matrix.dtype)
+        self.lookup_article_matrix = np.vstack([UNKNOWN_ARRAY, self.lookup_article_matrix])
+
+        self.unknown_index = [0]
+        self.X, self.y = self.load_data()
+        if self.kwargs is not None:
+            self.set_kwargs(self.kwargs)
+
+    def __getitem__(self, idx) -> dict:
+        """
+        history_input_tensor: (samples, history_size, document_dimension)
+        candidate_input_title: (samples, npratio, document_dimension)
+        label: (samples, npratio)
+        """
+        batch_X = self.X[idx]
+        article_id_fixed = [self.lookup_article_index.get(f, 0) for f in batch_X["article_id_fixed"].to_list()[0]]
+        history_input_tensor = self.lookup_article_matrix[article_id_fixed]
+
+        article_id_inview = [self.lookup_article_index.get(f, 0) for f in batch_X["article_ids_inview"].to_list()[0]]
+        candidate_input_title = self.lookup_article_matrix[article_id_inview]
+
+        return {
+            "user_id": self.X[idx]["user_id"][0],
+            "history_input_tensor": history_input_tensor,
+            "candidate_article_id": self.X[idx]["article_ids_inview"][0][0],
+            "candidate_input_title": candidate_input_title,
+            "labels": np.int32(0)
+        }
+
+
+# In[7]:
+
+
+test_dataset = NRMSTestDataset(
+    behaviors=test_df,
+    history_column=DEFAULT_HISTORY_ARTICLE_ID_COL,
+    article_dict=article_mapping,
+    unknown_representation="zeros",
+    eval_mode=False,
+)
+
+
+# In[8]:
+
+
+nrms_model = NRMSModel(
+    pretrained_weight=torch.tensor(word2vec_embedding),
+    emb_dim=768,
+    num_heads=16,
+    hidden_dim=128,
+    item_dim=64,
+)
+state_dict = torch.load("nrms_model.epoch0.step20001.pth")
+nrms_model = torch.compile(nrms_model)
+nrms_model.load_state_dict(state_dict["model"])
+nrms_model.to("cuda:1")
+
+
+# In[ ]:
+
+
+import torch._dynamo
+from tqdm import tqdm
+import os
+from torch.utils.data import DataLoader
+
+BATCH_SIZE = 256
+test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=60)
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+torch._dynamo.config.suppress_errors = True
+
+nrms_model.eval()
+
+with open("test_set.txt", 'w') as f:
+    with torch.no_grad():
+        for i, batch in enumerate(tqdm(test_dataloader)):
+            user_id = batch["user_id"].cpu().tolist()
+            candidate_article_id = batch["candidate_article_id"].cpu().tolist()
+            history_input_tensor = batch["history_input_tensor"].to("cuda:1")
+            candidate_input_title = batch["candidate_input_title"].to("cuda:1")
+
+            output_logits = nrms_model(history_input_tensor, candidate_input_title, None)[:, 0].cpu().tolist()
+
+            for j in range(len(user_id)):
+                line = f"{user_id[j]},{candidate_article_id[j]},{output_logits[j]}\n"
+                f.write(line)
+
+
+# In[ ]:
+
+
+
+
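
testset.py writes one "user_id,article_id,score" line per scored candidate, and final-testset.ipynb then consumes those lines strictly in order while walking the behaviors frame. A quick sanity check one might run between the two steps; this is a sketch, and the expected total of 13,536,710 lines is taken from the tqdm output in the notebook above:

# Validate test_set.txt before running the ranking notebook.
n_lines = 0
with open("test_set.txt") as f:
    for line in f:
        user_id, article_id, score = line.rstrip("\n").split(",")
        float(score)  # raises ValueError if a score is malformed
        n_lines += 1
print(n_lines)  # expected: 13536710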