File size: 2,593 Bytes
a86046b
 
 
 
 
 
 
3c30fa3
a86046b
 
 
 
7b62017
 
 
 
 
 
 
 
a86046b
 
 
 
3c30fa3
 
 
a86046b
 
 
86e673e
7b62017
 
3c30fa3
 
 
 
 
 
 
7b62017
 
 
a86046b
7b62017
 
 
 
3c30fa3
 
 
7b62017
 
a86046b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b62017
 
 
 
 
 
 
 
 
 
d131aa3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from functools import partial

import numpy as np
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm

from perplexity_lenses import REGISTRY_DATASET
from perplexity_lenses.perplexity import KenlmModel


def _first_label(example):
    """Return the first entry of ``example["labels"]``, or ``"NONE"`` when missing, ``None`` or empty."""
    labels = example.get("labels")
    if labels is None or len(labels) == 0:
        return "NONE"  # Special case for registry dataset
    return labels[0]


def hub_dataset_to_dataframe(
    path: str,
    name: str,
    split: str,
    sample: int,
    text_column: str,
    model: KenlmModel,
    seed: int = 0,
    doc_type: str = "Whole document",
) -> pd.DataFrame:
    """Stream a Hub dataset, score its text with ``model`` and return a sample.

    The dataset is loaded in streaming mode, shuffled with a 10000-element
    buffer, and every example is mapped to a record holding the text, its
    perplexity and a label (first entry of the ``labels`` field, or ``"NONE"``).

    Args:
        path: Dataset path passed to ``load_dataset``.
        name: Optional configuration name. When ``path`` is the registry
            dataset it is also used to select ``data_files`` (``"<name>/*"``).
        split: Optional split passed to ``load_dataset``.
        sample: Maximum number of rows to collect. A non-positive value
            yields an empty frame instead of consuming the whole stream.
        text_column: Name of the column that holds the text.
        model: Model whose ``get_perplexity`` scores each text.
        seed: Seed for the streaming shuffle.
        doc_type: ``"sentence"`` (case-insensitive) scores every
            newline-separated line separately; anything else scores the
            whole document.

    Returns:
        A DataFrame with at most ``sample`` rows.
    """
    load_kwargs = {"path": path, "streaming": True}
    if name:
        load_kwargs["name"] = name
        # Special case for the registry dataset
        if path == REGISTRY_DATASET:
            load_kwargs["data_files"] = f"{name}/*"
    if split:
        load_kwargs["split"] = split
    dataset = load_dataset(**load_kwargs).shuffle(buffer_size=10000, seed=seed)

    def _score_sentences(example):
        # One record per line; the label is shared by the whole document.
        label = _first_label(example)
        return [
            {
                text_column: sentence,
                "perplexity": model.get_perplexity(sentence),
                "label": label,
            }
            for sentence in example[text_column].split("\n")
        ]

    def _score_document(example):
        return {
            text_column: example[text_column],
            "perplexity": model.get_perplexity(example[text_column]),
            "label": _first_label(example),
        }

    # NOTE(review): sentence mode relies on ``map`` yielding the returned list
    # as-is; confirm this against the pinned ``datasets`` version.
    if doc_type.lower() == "sentence":
        dataset = dataset.map(_score_sentences)
    else:
        dataset = dataset.map(_score_document)

    if sample <= 0:
        # Nothing requested: do not start iterating the (possibly huge) stream.
        return pd.DataFrame(columns=[text_column, "perplexity", "label"])

    instances = []
    for instance in tqdm(dataset, total=sample):
        # Sentence mode yields a list of records, document mode a single one.
        rows = instance if isinstance(instance, list) else [instance]
        instances.extend(rows[: sample - len(instances)])
        if len(instances) >= sample:
            break
    return pd.DataFrame(instances)


def documents_df_to_sentences_df(
    df: pd.DataFrame, text_column: str, sample: int, seed: int = 0
) -> pd.DataFrame:
    """Split documents into newline-separated sentences and sample them.

    Args:
        df: Frame whose ``text_column`` holds one document string per row.
        text_column: Name of the text column; also the name of the single
            column in the result.
        sample: Maximum number of sentences to return.
        seed: Random state used for sampling.

    Returns:
        A one-column DataFrame with ``min(sample, number_of_sentences)`` rows
        drawn at random, indexed by position in the flattened sentence list.
    """
    # Flatten in pure Python: documents rarely have the same number of lines,
    # and a ragged list of lists cannot be flattened through ``np.array``.
    sentences = [
        sentence
        for document in df[text_column]
        for sentence in document.split("\n")
    ]
    df_sentences = pd.DataFrame({text_column: sentences})
    return df_sentences.sample(min(sample, len(df_sentences)), random_state=seed)