File size: 6,518 Bytes
3943768 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import pytest
from tests.utils import wrap_test_forked, get_llama
from src.enums import DocumentSubset
@wrap_test_forked
def test_cli(monkeypatch):
    """Single non-looping CLI generation with the ``gptj`` base model.

    ``builtins.input`` is replaced so that any prompt the CLI issues receives a
    fixed question. Exactly one generation is expected, and it must contain a
    known sentence about the Earth.
    """
    question = "What is the Earth?"
    monkeypatch.setattr('builtins.input', lambda _: question)

    from src.gen import main

    # score_model is the *string* 'None', as in the other CLI tests here.
    generations, _sources = main(
        base_model='gptj',
        cli=True,
        cli_loop=False,
        score_model='None',
    )

    assert len(generations) == 1
    expected_phrase = "The Earth is a planet in our solar system"
    assert expected_phrase in generations[0]
@pytest.mark.parametrize("base_model", ['gptj', 'gpt4all_llama'])
@wrap_test_forked
def test_cli_langchain(base_model, monkeypatch):
    """Single CLI generation over the test ``UserData`` collection.

    ``builtins.input`` is patched to return a fixed question about the cat.
    ``main`` runs with LangChain mode ``UserData`` over the path from
    ``make_user_path_test()``, using the ``Relevant`` document subset. The test
    asserts that exactly one generation is produced and that it contains at
    least one of several accepted phrasings.

    :param base_model: model name, parametrized over ``gptj`` and
        ``gpt4all_llama``.
    :param monkeypatch: pytest fixture used to patch ``builtins.input``.
    """
    from tests.utils import make_user_path_test
    user_path = make_user_path_test()
    query = "What is the cat doing?"
    monkeypatch.setattr('builtins.input', lambda _: query)
    from src.gen import main
    all_generations, all_sources = main(base_model=base_model, cli=True, cli_loop=False, score_model='None',
                                        langchain_mode='UserData',
                                        user_path=user_path,
                                        langchain_modes=['UserData', 'MyData'],
                                        document_subset=DocumentSubset.Relevant.name,
                                        verbose=True)
    print(all_generations)
    assert len(all_generations) == 1
    generation = all_generations[0]
    # Sources are no longer included in the generated text, so the source
    # filename is not asserted here.
    # Wording differs between the parametrized models, so any one of these
    # phrases is accepted. "cat is sitting" also matches the former
    # "cat is sitting on a window" alternative, which was redundant.
    accepted_phrases = (
        "looking out the window",
        "staring out the window at the city skyline",
        "what the cat is doing",
        "question about a cat",
        "The prompt asks for an answer to a question",
        "The prompt asks what the cat in the scenario is doing",
        # NOTE(review): this phrase is not about the cat at all; it accepts an
        # off-topic answer. Confirm it is still intended.
        "The prompt asks why H2O.ai",
        "cat is sitting",
    )
    assert any(phrase in generation for phrase in accepted_phrases), \
        f"none of {accepted_phrases!r} found in generation {generation!r}"
@pytest.mark.need_tokens
@wrap_test_forked
def test_cli_langchain_llamacpp(monkeypatch):
    """Single llama.cpp CLI generation over the test ``UserData`` collection.

    ``get_llama()`` supplies the prompt type and model path. The CLI's input
    prompt is patched to ask about the cat. The test checks four things:
    there is one generation, the first source mentions the expected test
    image, and the answer uses one of several accepted phrasings.
    """
    prompt_type, full_path = get_llama()

    from tests.utils import make_user_path_test
    docs_dir = make_user_path_test()

    question = "What is the cat doing?"
    monkeypatch.setattr('builtins.input', lambda _: question)

    from src.gen import main
    generations, sources = main(
        base_model='llama',
        cli=True,
        cli_loop=False,
        score_model='None',
        langchain_mode='UserData',
        model_path_llama=full_path,
        prompt_type=prompt_type,
        user_path=docs_dir,
        langchain_modes=['UserData', 'MyData'],
        document_subset=DocumentSubset.Relevant.name,
        verbose=True,
    )
    print(generations)

    assert len(generations) == 1
    # Presumably docs_dir contains this image; the retrieval should cite it.
    assert "pexels-evg-kowalievska-1170986_small.jpg" in str(sources[0])

    answer = generations[0]
    assert ("the cat is sitting" in answer
            or "staring out the window at the city skyline" in answer
            or "The cat is likely relaxing and enjoying" in answer
            or "cat in the image is" in answer
            or "cat is sitting on a window sill" in answer)
@pytest.mark.need_tokens
@wrap_test_forked
def test_cli_llamacpp(monkeypatch):
    """Single llama.cpp CLI generation with LangChain turned off.

    ``langchain_mode`` is ``'Disabled'``, which is also the only entry in
    ``langchain_modes``, and no ``user_path`` is given. The patched input asks
    "Who are you?". Exactly one generation is expected, and it must contain
    one of several accepted self-descriptions.
    """
    prompt_type, full_path = get_llama()
    question = "Who are you?"
    monkeypatch.setattr('builtins.input', lambda _: question)
    from src.gen import main

    mode = 'Disabled'
    generations, _sources = main(
        base_model='llama',
        cli=True,
        cli_loop=False,
        score_model='None',
        langchain_mode=mode,
        prompt_type=prompt_type,
        model_path_llama=full_path,
        user_path=None,
        langchain_modes=[mode],
        document_subset=DocumentSubset.Relevant.name,
        verbose=True,
    )
    print(generations)

    assert len(generations) == 1
    reply = generations[0]
    assert ("I'm a software engineer with a passion for building scalable" in reply
            or "how can I assist" in reply
            or "am a virtual assistant" in reply
            or "My name is John." in reply
            or "I am a student" in reply
            or "I'm LLaMA" in reply
            or "Hello! I'm just an AI assistant" in reply)
@wrap_test_forked
def test_cli_h2ogpt(monkeypatch):
    """Single non-looping CLI generation with the 6.9B h2oGPT OIG/OASST1 model.

    The patched input asks "What is the Earth?". The match is
    case-insensitive: the generation and the expected phrases are both
    lower-cased before comparing.
    """
    question = "What is the Earth?"
    monkeypatch.setattr('builtins.input', lambda _: question)

    from src.gen import main

    generations, _sources = main(
        base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b',
        cli=True,
        cli_loop=False,
        score_model='None',
    )

    assert len(generations) == 1
    reply = generations[0].lower()
    assert ("The Earth is a planet in the Solar System".lower() in reply
            or "The Earth is the third planet".lower() in reply)
@wrap_test_forked
def test_cli_langchain_h2ogpt(monkeypatch):
    """Single h2oGPT CLI generation over the test ``UserData`` collection.

    This is the same scenario as the other cat-question LangChain CLI tests,
    but using the 6.9B h2oGPT OIG/OASST1 model and a shorter list of accepted
    phrasings.
    """
    from tests.utils import make_user_path_test
    docs_dir = make_user_path_test()

    question = "What is the cat doing?"
    monkeypatch.setattr('builtins.input', lambda _: question)

    from src.gen import main
    generations, _sources = main(
        base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b',
        cli=True,
        cli_loop=False,
        score_model='None',
        langchain_mode='UserData',
        user_path=docs_dir,
        langchain_modes=['UserData', 'MyData'],
        document_subset=DocumentSubset.Relevant.name,
        verbose=True,
    )
    print(generations)

    assert len(generations) == 1
    answer = generations[0]
    assert ("looking out the window" in answer
            or "staring out the window at the city skyline" in answer
            or "cat is sitting" in answer)
|