File size: 2,072 Bytes
f8f4265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71

from llmware.prompts import Prompt


def load_rag_benchmark_tester_ds():
    """Fetch the LLMWare RAG benchmark test set and return it as a list.

    Loads the "llmware/rag_instruct_benchmark_tester" dataset (the 200
    question RAG benchmark from the LLMWare HuggingFace repo) through the
    `datasets` package, prints the loaded dataset object, and collects every
    sample of its "train" split.

    Returns:
        list: the samples of the "train" split, in iteration order.
    """

    # function-scope import: `datasets` is only needed when this is called
    from datasets import load_dataset

    benchmark = load_dataset("llmware/rag_instruct_benchmark_tester")

    print("update: loading RAG Benchmark test dataset - ", benchmark)

    # to view the samples, enumerate the returned list and print each entry
    return list(benchmark["train"])


def run_test(model_name, prompt_list):
    """Run each benchmark sample through a model and print evidence checks.

    For every sample, the model is prompted with the sample's query and
    context, the model response and the gold answer are printed, and the
    prompter's three evidence tools (number check, comparison stats, source
    review) are run on the response and printed.

    Args:
        model_name: name of the model to load with `Prompt().load_model`.
        prompt_list: iterable of samples; each sample is read for its
            "query", "context" and "answer" keys.

    Returns:
        int: always 0.
    """

    print("\nupdate: Starting RAG Benchmark Inference Test - ", model_name)

    # pull DRAGON / BLING model directly from catalog, e.g., no from_hf=True
    prompter = Prompt().load_model(model_name)

    for i, sample in enumerate(prompt_list):

        response = prompter.prompt_main(
            sample["query"],
            context=sample["context"],
            prompt_name="default_with_context",
            temperature=0.3,
        )

        print("\nupdate: model inference output - ", i, response["llm_response"])
        print("update: gold_answer              - ", i, sample["answer"])

        fact_checks = prompter.evidence_check_numbers(response)
        comparison_stats = prompter.evidence_comparison_stats(response)
        source_reviews = prompter.evidence_check_sources(response)

        print("\nFact-Checking Tools")

        # Each evidence record gets its own loop name so the current `sample`
        # is never rebound while its results are being printed.
        for fact_record in fact_checks:
            for f, facts in enumerate(fact_record["fact_check"]):
                print("update: fact check - ", f, facts)

        for stats_record in comparison_stats:
            print("update: comparison stats - ", stats_record["comparison_stats"])

        for source_record in source_reviews:
            for s, sources in enumerate(source_record["source_review"]):
                print("update: sources - ", s, sources)

    return 0


if __name__ == "__main__":

    # one of the 7 gpu dragon models
    gpu_model_name = "llmware/dragon-mistral-7b-v0"

    # fetch the benchmark samples, then run all of them through the model
    core_test_set = load_rag_benchmark_tester_ds()
    output = run_test(gpu_model_name, core_test_set)