import pytest

from tests.utils import wrap_test_forked, get_llama
from src.enums import DocumentSubset


@wrap_test_forked
def test_cli(monkeypatch):
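    """Smoke test for CLI mode: ask one question of the gptj base model and check the answer text."""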
    query = "What is the Earth?"
    monkeypatch.setattr('builtins.input', lambda _: query)

    from src.gen import main
    all_generations, all_sources = main(base_model='gptj', cli=True, cli_loop=False, score_model='None')

    assert len(all_generations) == 1
    assert "The Earth is a planet in our solar system" in all_generations[0]


@pytest.mark.parametrize("base_model", ['gptj', 'gpt4all_llama'])
@wrap_test_forked
def test_cli_langchain(base_model, monkeypatch):
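    """CLI with LangChain UserData enabled: query the test user documents about the cat and accept any of several plausible answers."""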
    from tests.utils import make_user_path_test
    user_path = make_user_path_test()

    query = "What is the cat doing?"
    monkeypatch.setattr('builtins.input', lambda _: query)

    from src.gen import main
    all_generations, all_sources = main(base_model=base_model, cli=True, cli_loop=False, score_model='None',
                                        langchain_mode='UserData',
                                        user_path=user_path,
                                        langchain_modes=['UserData', 'MyData'],
                                        document_subset=DocumentSubset.Relevant.name,
                                        verbose=True)

    print(all_generations)
    assert len(all_generations) == 1
    # no sources in output now
    # assert "pexels-evg-kowalievska-1170986_small.jpg" in all_generations[0]
    assert "looking out the window" in all_generations[0] or \
           "staring out the window at the city skyline" in all_generations[0] or \
           "what the cat is doing" in all_generations[0] or \
           "question about a cat" in all_generations[0] or \
           "The prompt asks for an answer to a question" in all_generations[0] or \
           "The prompt asks what the cat in the scenario is doing" in all_generations[0] or \
           "The prompt asks why H2O.ai" in all_generations[0] or \
           "cat is sitting on a window" in all_generations[0] or \
           "cat is sitting" in all_generations[0]


@pytest.mark.need_tokens
@wrap_test_forked
def test_cli_langchain_llamacpp(monkeypatch):
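    """CLI with LangChain UserData using a llama.cpp model; also checks that the cat image appears in the returned sources."""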
    prompt_type, full_path = get_llama()

    from tests.utils import make_user_path_test
    user_path = make_user_path_test()

    query = "What is the cat doing?"
    monkeypatch.setattr('builtins.input', lambda _: query)

    from src.gen import main
    all_generations, all_sources = main(base_model='llama', cli=True, cli_loop=False, score_model='None',
                                        langchain_mode='UserData',
                                        model_path_llama=full_path,
                                        prompt_type=prompt_type,
                                        user_path=user_path,
                                        langchain_modes=['UserData', 'MyData'],
                                        document_subset=DocumentSubset.Relevant.name,
                                        verbose=True)

    print(all_generations)
    assert len(all_generations) == 1
    assert "pexels-evg-kowalievska-1170986_small.jpg" in str(all_sources[0])
    assert "the cat is sitting" in all_generations[0] or \
           "staring out the window at the city skyline" in all_generations[0] or \
           "The cat is likely relaxing and enjoying" in all_generations[0] or \
           "cat in the image is" in all_generations[0] or \
           "cat is sitting on a window sill" in all_generations[0]


@pytest.mark.need_tokens
@wrap_test_forked
def test_cli_llamacpp(monkeypatch):
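    """Plain CLI (LangChain disabled) with a llama.cpp model answering a simple identity question."""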
    prompt_type, full_path = get_llama()

    query = "Who are you?"
    monkeypatch.setattr('builtins.input', lambda _: query)

    from src.gen import main
    langchain_mode = 'Disabled'
    all_generations, all_sources = main(base_model='llama', cli=True, cli_loop=False, score_model='None',
                                        langchain_mode=langchain_mode,
                                        prompt_type=prompt_type,
                                        model_path_llama=full_path,
                                        user_path=None,
                                        langchain_modes=[langchain_mode],
                                        document_subset=DocumentSubset.Relevant.name,
                                        verbose=True)

    print(all_generations)
    assert len(all_generations) == 1
    assert "I'm a software engineer with a passion for building scalable" in all_generations[0] or \
           "how can I assist" in all_generations[0] or \
           "am a virtual assistant" in all_generations[0] or \
           "My name is John." in all_generations[0] or \
           "I am a student" in all_generations[0] or \
           "I'm LLaMA" in all_generations[0] or \
           "Hello! I'm just an AI assistant" in all_generations[0]


@wrap_test_forked
def test_cli_h2ogpt(monkeypatch):
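    """CLI with the h2oGPT OIG/OASST1 6.9B model answering a single factual question."""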
    query = "What is the Earth?"
    monkeypatch.setattr('builtins.input', lambda _: query)

    from src.gen import main
    all_generations, all_sources = main(base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b', cli=True, cli_loop=False,
                                        score_model='None')

    assert len(all_generations) == 1
    assert "The Earth is a planet in the Solar System".lower() in all_generations[0].lower() or \
           "The Earth is the third planet".lower() in all_generations[0].lower()


@wrap_test_forked
def test_cli_langchain_h2ogpt(monkeypatch):
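    """CLI with LangChain UserData using the h2oGPT OIG/OASST1 6.9B model, queried about the cat documents."""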
    from tests.utils import make_user_path_test
    user_path = make_user_path_test()

    query = "What is the cat doing?"
    monkeypatch.setattr('builtins.input', lambda _: query)

    from src.gen import main
    all_generations, all_sources = main(base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b',
                                        cli=True, cli_loop=False, score_model='None',
                                        langchain_mode='UserData',
                                        user_path=user_path,
                                        langchain_modes=['UserData', 'MyData'],
                                        document_subset=DocumentSubset.Relevant.name,
                                        verbose=True)

    print(all_generations)
    assert len(all_generations) == 1
    assert "looking out the window" in all_generations[0] or \
           "staring out the window at the city skyline" in all_generations[0] or \
           "cat is sitting" in all_generations[0]