# nrms/upload_final_test_dataset.py
# Uploaded by pko89403 (commit ee41ae8, verified): "Upload upload_final_test_dataset.py"
# File size: 2.7 kB
import os
from datetime import datetime
from pathlib import Path
import polars as pl
import torch
from transformers import AutoModel, AutoTokenizer
from transformers import Trainer, TrainingArguments
from accelerate import Accelerator, DistributedType
from torch.optim import AdamW
from torch.utils.data import DataLoader
from utils._constants import *
from utils._nlp import get_transformers_word_embeddings
from utils._polars import concat_str_columns, slice_join_dataframes
from utils._articles import (
convert_text2encoding_with_transformers,
create_article_id_to_value_mapping
)
from utils._behaviors import (
create_binary_labels_column,
sampling_strategy_wu2019,
truncate_history,
)
from dataset.pytorch_dataloader import (
ebnerd_from_path,
NRMSDataset,
)
from evaluation import (
MetricEvaluator,
AucScore,
NdcgScore,
MrrScore,
F1Score,
LogLossScore,
RootMeanSquaredError,
AccuracyScore
)
from models.nrms import NRMSModel
from datasets import Dataset, DatasetDict
import pyarrow as pa
import pyarrow.parquet as pq
import polars as pl
# Columns kept from the joined test set; the DEFAULT_* names come from
# utils._constants (star-imported above).
COLUMNS = ["impression_id", DEFAULT_USER_COL, DEFAULT_HISTORY_ARTICLE_ID_COL, DEFAULT_INVIEW_ARTICLES_COL]
test_first_df = pl.read_parquet("testset_joined.parquet")
# Arrow schema shared by both output files.
# NOTE(review): int32 assumes every id fits in 32 bits — confirm against the
# actual id ranges in testset_joined.parquet.
schema = pa.schema([
    ("impression_id", pa.int32()),
    ("user_id", pa.int32()),
    ("article_id_fixed", pa.list_(pa.int32())),
    ("article_ids_inview", pa.list_(pa.int32())),
])
# Use the writers as context managers so both files are closed (and their
# Parquet footers written) even if a slice fails mid-loop; the original
# code leaked both writers on any exception inside the loop.
with pq.ParquetWriter("merged_0412_final.parquet", schema) as exp_writer, \
     pq.ParquetWriter("merged_0412_joined_only.parquet", schema) as only_writer:
    for idx, rows in enumerate(test_first_df.select(COLUMNS).iter_slices()):
        print(idx, "\n")  # progress marker, one line per slice
        # Write the slice unchanged ("joined only" variant).
        org_table = pa.Table.from_pandas(rows.to_pandas(), schema=schema)
        only_writer.write_table(org_table)
        # Exploded variant: one row per inview article; each exploded id is
        # re-wrapped in a single-element list to satisfy the list<int32> schema.
        df = rows.explode("article_ids_inview").with_columns(pl.col("article_ids_inview").map_elements(lambda x: [x]))
        exp_table = pa.Table.from_pandas(df.to_pandas(), schema=schema)
        exp_writer.write_table(exp_table)
# Drop the large source frame and the Arrow schema before loading the
# written Parquet files back as HF datasets.
del test_first_df, schema

# Publish the non-exploded ("joined only") split to the Hub under the
# join_test config.
merged_0412_joined_only_df = Dataset.from_parquet("merged_0412_joined_only.parquet")
ebnerd_testset = DatasetDict(testset=merged_0412_joined_only_df)
ebnerd_testset.push_to_hub(
    repo_id="mbhr/EB-NeRD",
    config_name="join_test",
    data_dir="data/join_test",
)

# Release this split before loading the (larger) exploded one.
del merged_0412_joined_only_df, ebnerd_testset
# Publish the exploded ("final") split to the same Hub repo under its own
# join_test_exp config.
merged_0412_final_df = Dataset.from_parquet("merged_0412_final.parquet")
ebnerd_testset = DatasetDict(testset=merged_0412_final_df)
ebnerd_testset.push_to_hub(
    repo_id="mbhr/EB-NeRD",
    config_name="join_test_exp",
    data_dir="data/join_test_exp",
)