File size: 2,703 Bytes
ee41ae8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
from datetime import datetime

from pathlib import Path

import polars as pl
import torch
from transformers import AutoModel, AutoTokenizer
from transformers import Trainer, TrainingArguments
from accelerate import Accelerator, DistributedType
from torch.optim import AdamW
from torch.utils.data import DataLoader

from utils._constants import *
from utils._nlp import get_transformers_word_embeddings
from utils._polars import concat_str_columns, slice_join_dataframes
from utils._articles import (
    convert_text2encoding_with_transformers,
    create_article_id_to_value_mapping
)
from utils._behaviors import (
    create_binary_labels_column,
    sampling_strategy_wu2019,
    truncate_history,
)
from dataset.pytorch_dataloader import (
    ebnerd_from_path,
    NRMSDataset,
)
from evaluation import (
    MetricEvaluator, 
    AucScore,
    NdcgScore,
    MrrScore,
    F1Score,
    LogLossScore,
    RootMeanSquaredError,
    AccuracyScore
)
from models.nrms import NRMSModel
from datasets import Dataset, DatasetDict
import pyarrow as pa
import pyarrow.parquet as pq
import polars as pl


# Columns needed to reproduce each test-set impression.
COLUMNS = ["impression_id", DEFAULT_USER_COL, DEFAULT_HISTORY_ARTICLE_ID_COL, DEFAULT_INVIEW_ARTICLES_COL]

test_first_df = pl.read_parquet("testset_joined.parquet")

# Arrow schema shared by both output files; to_pandas()/from_pandas below
# casts every slice to these exact types.
schema = pa.schema([
    ("impression_id", pa.int32()),
    ("user_id", pa.int32()),
    ("article_id_fixed", pa.list_(pa.int32())),
    ("article_ids_inview", pa.list_(pa.int32())),
])
exp_writer = pq.ParquetWriter("merged_0412_final.parquet", schema)
only_writer = pq.ParquetWriter("merged_0412_joined_only.parquet", schema)

try:
    # Stream the frame in slices so we never materialize the full exploded
    # table in memory at once.
    for idx, rows in enumerate(test_first_df.select(COLUMNS).iter_slices()):
        print(idx, "\n")
        # "joined only": one row per impression, lists kept intact.
        org_table = pa.Table.from_pandas(rows.to_pandas(), schema=schema)
        only_writer.write_table(org_table)

        # "final": one row per in-view article. After explode() the column is
        # scalar, so re-wrap each id in a single-element list to keep the
        # list<int32> schema.
        # NOTE(review): pl.concat_list("article_ids_inview") would vectorize
        # this, but it handles nulls differently ([null] vs null) — confirm
        # the column is null-free before switching.
        df = rows.explode("article_ids_inview").with_columns(pl.col("article_ids_inview").map_elements(lambda x: [x]))
        exp_table = pa.Table.from_pandas(df.to_pandas(), schema=schema)
        exp_writer.write_table(exp_table)
finally:
    # Always close the writers — the original leaked the handles and left the
    # parquet footers unwritten if any slice raised, corrupting both outputs.
    only_writer.close()
    exp_writer.close()

# Release the large source frame before the Dataset loads below.
del test_first_df
del schema

# Publish the non-exploded ("joined only") split to the Hugging Face Hub.
joined_only_ds = Dataset.from_parquet("merged_0412_joined_only.parquet")
joined_only_dict = DatasetDict({"testset": joined_only_ds})
joined_only_dict.push_to_hub(
    repo_id="mbhr/EB-NeRD",
    config_name="join_test",
    data_dir="data/join_test",
)

# Drop references before loading the (larger) exploded split below.
del joined_only_ds
del joined_only_dict

# Publish the exploded split (one in-view article per row) to the Hub.
exploded_ds = Dataset.from_parquet("merged_0412_final.parquet")
DatasetDict({"testset": exploded_ds}).push_to_hub(
    repo_id="mbhr/EB-NeRD",
    config_name="join_test_exp",
    data_dir="data/join_test_exp",
)