File size: 3,475 Bytes
3ccff6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd3b2e0
3ccff6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import gradio as gr
from datasets import load_dataset
import re


def load_dataset_demo(name):
    """Load the ``train`` split of dataset *name* for the demo.

    When the split has a ``flags`` column, only the rows whose ``flags``
    value is truthy are kept; otherwise the split is returned unfiltered.

    Args:
        name: Dataset identifier, passed unchanged to ``load_dataset``.

    Returns:
        The ``train`` split, filtered on ``flags`` when that column exists.
    """
    train_split = load_dataset(name)["train"]
    # Test for the column up front rather than catching a failed filter:
    # genuine loading errors are no longer masked by a blanket ``except``,
    # and the dataset is loaded only once.
    if "flags" not in train_split.column_names:
        return train_split
    return train_split.filter(lambda row: row["flags"])


# Dataset identifiers offered in the demo's dropdown.
NAME_DATASETS = [
    "Self-GRIT/selfrag_dataset-embed_query_instruct-Meta-Llama-3-70B-Instruct_temp-0.01",
    "Self-GRIT/selfrag_dataset-embed_query_instruct-Meta-Llama-3-70B-Instruct_temp-1.0",
    "Self-GRIT/selfrag_dataset_mini-embed_query_instruct-Meta-Llama-3-70B-Instruct_temp-0.01",
    "Self-GRIT/selfrag_dataset_mini-embed_query_instruct-Meta-Llama-3-8B-Instruct",
]

# Every listed dataset is loaded eagerly at import time, keyed by its name.
DATASETS = dict(zip(NAME_DATASETS, map(load_dataset_demo, NAME_DATASETS)))

# Column names read from each example by ``output_fn``.
INSTRUCTION_COL = "instruction"
OUTPUT_COL = "output"
OUTPUT_ORIGIN = "output_origin"


def extract_pairs(text):
    """Return every ``(embed, passage)`` pair found in *text*.

    A pair is an ``<embed>...</embed>`` element immediately followed by a
    ``<passage>...</passage>`` element. Both contents are matched lazily
    and may span several lines.

    Returns:
        A list of 2-tuples of strings, in order of appearance; empty when
        nothing matches.
    """
    pair_regex = re.compile(
        r"<embed>(.*?)</embed><passage>(.*?)</passage>",
        re.DOTALL,  # let "." cross newlines inside the tags
    )
    return pair_regex.findall(text)


def preprocess_qa_pairs(text):
    """Format the query-passage pairs contained in *text* for display.

    The pairs are obtained with ``extract_pairs``. Each one is rendered as
    a numbered section (starting at 1) with the stripped query and the
    stripped passage.

    Returns:
        The concatenated sections, or a fixed notice when *text* contains
        no pair.
    """
    pairs = extract_pairs(text)
    if not pairs:
        return "No query-passage pairs found."

    sections = []
    for number, (query, passage) in enumerate(pairs, start=1):
        sections.append(
            f"========================== QP-Pair {number} =============================\n"
            f"Query:\n{query.strip()}\n"
            f"Passage:\n{passage.strip()}\n\n"
        )
    return "".join(sections)


def output_fn(dropdown, slider):
    """Collect the values displayed for one example of the selected dataset.

    Args:
        dropdown: Name of the selected dataset; must be a key of ``DATASETS``.
        slider: Position of the example in that dataset (converted with
            ``int``).

    Returns:
        A 4-tuple ``(instruction, original output, output, formatted pairs)``
        where the last item is ``preprocess_qa_pairs`` applied to the output.

    Raises:
        gr.Error: If no known dataset is selected, or if the requested
            position does not exist in the selected dataset.
    """
    # The dropdown yields ``None`` until the user picks something; report
    # that in the UI instead of failing with a bare ``KeyError``.
    if dropdown not in DATASETS:
        raise gr.Error("Please select a dataset first.")
    dataset = DATASETS[dropdown]

    # The slider range is shared by all datasets, so the requested position
    # may exceed the size of the one that is selected.
    index = int(slider)
    if not 0 <= index < len(dataset):
        raise gr.Error(
            f"Example #{index} is out of range: this dataset has "
            f"{len(dataset)} examples (indices start at 0)."
        )

    example = dataset[index]
    generated = example[OUTPUT_COL]
    return (
        example[INSTRUCTION_COL],
        example[OUTPUT_ORIGIN],
        generated,
        preprocess_qa_pairs(generated),
    )


# Build the UI: a dataset selector and example slider at the top, then the
# instruction, the two model outputs side by side, and the extracted
# query-passage pairs.
with gr.Blocks() as demo:
    gr.Markdown("# Explore Self-RAG Datasets")
    with gr.Group():
        with gr.Row():
            with gr.Column():
                dropdown = gr.Dropdown(
                    NAME_DATASETS,
                    # Preselect the first dataset so that Submit works
                    # before the user has touched the dropdown.
                    value=NAME_DATASETS[0],
                    multiselect=False,
                    label="Dataset",
                    info="Select the dataset name",
                )
            with gr.Column():
                # Indices are 0-based, so the last valid position of the
                # largest dataset is its length minus one (never below 0).
                largest_size = max(
                    (len(dataset) for dataset in DATASETS.values()), default=0
                )
                slider = gr.Slider(
                    minimum=0,
                    maximum=max(largest_size - 1, 0),
                    step=1,
                    label="#example",
                    value=0,
                )
                button = gr.Button(value="Submit", variant="primary")
    with gr.Group():
        with gr.Row():
            output_instruction = gr.Textbox(
                label="Instruction", placeholder="Instruction", type="text"
            )
        with gr.Row():
            with gr.Row():
                output_self_rag = gr.Textbox(
                    label="SELF-RAG output", placeholder="SELF-RAG output", type="text"
                )
                output_self_grit = gr.Textbox(
                    label="SELF-GRIT output",
                    placeholder="SELF-GRIT output",
                    type="text",
                )
    with gr.Group():
        output_qps = gr.Textbox(
            label="Query-Passage Pairs", placeholder="Query-Passage Pairs", type="text"
        )
    # Submit fills all four text boxes from the selected dataset/example.
    button.click(
        fn=output_fn,
        inputs=[dropdown, slider],
        outputs=[output_instruction, output_self_rag, output_self_grit, output_qps],
    )

# Launch only after the Blocks context has been closed, i.e. once the
# layout is fully defined.
demo.launch(share=True, debug=True)