# aiben / tests/test_cli.py
# Author: abugaber
# Commit 3943768 (verified): "Upload folder using huggingface_hub"
import pytest
from tests.utils import wrap_test_forked, get_llama
from src.enums import DocumentSubset
@wrap_test_forked
def test_cli(monkeypatch):
    """Run the non-looping CLI with the ``gptj`` base model on one canned query.

    ``builtins.input`` is patched to always return the query, so the CLI
    reads it without user interaction. Expects exactly one generation that
    mentions the Earth being a planet in our solar system.
    """
    query = "What is the Earth?"
    monkeypatch.setattr('builtins.input', lambda _: query)

    # Deferred import: the generation entry point is only imported inside the test.
    from src.gen import main
    generations, _sources = main(base_model='gptj', cli=True, cli_loop=False, score_model='None')

    assert len(generations) == 1
    first = generations[0]
    assert "The Earth is a planet in our solar system" in first
@pytest.mark.parametrize("base_model", ['gptj', 'gpt4all_llama'])
@wrap_test_forked
def test_cli_langchain(base_model, monkeypatch):
    """Run the non-looping CLI with LangChain over the test ``UserData`` collection.

    Parametrized over ``base_model``. ``builtins.input`` is patched to return a
    question about a cat, the collection is built from ``make_user_path_test()``,
    and the ``Relevant`` document subset is requested. Expects exactly one
    generation containing one of several accepted phrasings.
    """
    from tests.utils import make_user_path_test
    user_path = make_user_path_test()
    query = "What is the cat doing?"
    monkeypatch.setattr('builtins.input', lambda _: query)
    from src.gen import main
    all_generations, all_sources = main(base_model=base_model, cli=True, cli_loop=False, score_model='None',
                                        langchain_mode='UserData',
                                        user_path=user_path,
                                        langchain_modes=['UserData', 'MyData'],
                                        document_subset=DocumentSubset.Relevant.name,
                                        verbose=True)
    print(all_generations)
    assert len(all_generations) == 1
    # Sources are no longer included in the generated text, so the source
    # document name (pexels-evg-kowalievska-1170986_small.jpg) is not checked here.

    # Model output varies across base models and versions; accept any
    # previously observed phrasing.
    expected_phrases = (
        "looking out the window",
        "staring out the window at the city skyline",
        "what the cat is doing",
        "question about a cat",
        "The prompt asks for an answer to a question",
        "The prompt asks what the cat in the scenario is doing",
        "The prompt asks why H2O.ai",
        "cat is sitting on a window",
        "cat is sitting",
    )
    generation = all_generations[0]
    assert any(phrase in generation for phrase in expected_phrases), \
        f"unexpected generation: {generation!r}"
@pytest.mark.need_tokens
@wrap_test_forked
def test_cli_langchain_llamacpp(monkeypatch):
    """Run the non-looping CLI with a llama.cpp model and LangChain over ``UserData``.

    The model path and prompt type come from ``get_llama()``. ``builtins.input``
    is patched to return a question about a cat. Expects exactly one
    generation, a first source that references the cat image file, and one of
    several accepted phrasings in the generated text.
    """
    prompt_type, full_path = get_llama()
    from tests.utils import make_user_path_test
    user_path = make_user_path_test()
    query = "What is the cat doing?"
    monkeypatch.setattr('builtins.input', lambda _: query)
    from src.gen import main
    all_generations, all_sources = main(base_model='llama', cli=True, cli_loop=False, score_model='None',
                                        langchain_mode='UserData',
                                        model_path_llama=full_path,
                                        prompt_type=prompt_type,
                                        user_path=user_path,
                                        langchain_modes=['UserData', 'MyData'],
                                        document_subset=DocumentSubset.Relevant.name,
                                        verbose=True)
    print(all_generations)
    assert len(all_generations) == 1
    assert "pexels-evg-kowalievska-1170986_small.jpg" in str(all_sources[0]), \
        f"expected cat image in sources: {all_sources!r}"

    # Model output varies across versions; accept any previously observed phrasing.
    expected_phrases = (
        "the cat is sitting",
        "staring out the window at the city skyline",
        "The cat is likely relaxing and enjoying",
        "cat in the image is",
        "cat is sitting on a window sill",
    )
    generation = all_generations[0]
    assert any(phrase in generation for phrase in expected_phrases), \
        f"unexpected generation: {generation!r}"
@pytest.mark.need_tokens
@wrap_test_forked
def test_cli_llamacpp(monkeypatch):
    """Run the non-looping CLI with a llama.cpp model and LangChain disabled.

    The model path and prompt type come from ``get_llama()``. ``builtins.input``
    is patched to return "Who are you?". Expects exactly one generation
    containing one of several accepted self-descriptions.
    """
    prompt_type, full_path = get_llama()
    query = "Who are you?"
    monkeypatch.setattr('builtins.input', lambda _: query)
    from src.gen import main
    langchain_mode = 'Disabled'
    all_generations, all_sources = main(base_model='llama', cli=True, cli_loop=False, score_model='None',
                                        langchain_mode=langchain_mode,
                                        prompt_type=prompt_type,
                                        model_path_llama=full_path,
                                        user_path=None,
                                        langchain_modes=[langchain_mode],
                                        document_subset=DocumentSubset.Relevant.name,
                                        verbose=True)
    print(all_generations)
    assert len(all_generations) == 1

    # Open-ended prompt: model output varies, so accept any previously
    # observed answer.
    expected_phrases = (
        "I'm a software engineer with a passion for building scalable",
        "how can I assist",
        "am a virtual assistant",
        "My name is John.",
        "I am a student",
        "I'm LLaMA",
        "Hello! I'm just an AI assistant",
    )
    generation = all_generations[0]
    assert any(phrase in generation for phrase in expected_phrases), \
        f"unexpected generation: {generation!r}"
@wrap_test_forked
def test_cli_h2ogpt(monkeypatch):
    """Run the non-looping CLI with the h2oGPT 6.9B model on one canned query.

    Expects exactly one generation that, compared case-insensitively, describes
    the Earth as a Solar System planet or as the third planet.
    """
    query = "What is the Earth?"
    monkeypatch.setattr('builtins.input', lambda _: query)
    from src.gen import main
    generations, _sources = main(base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b', cli=True, cli_loop=False,
                                 score_model='None')
    assert len(generations) == 1

    # Case-insensitive match: lowercase the generation once and compare
    # against lowercased phrases.
    lowered = generations[0].lower()
    accepted = [
        "The Earth is a planet in the Solar System".lower(),
        "The Earth is the third planet".lower(),
    ]
    assert any(candidate in lowered for candidate in accepted)
@wrap_test_forked
def test_cli_langchain_h2ogpt(monkeypatch):
    """Run the non-looping CLI with the h2oGPT 6.9B model and LangChain over ``UserData``.

    ``builtins.input`` is patched to return a question about a cat, the
    collection is built from ``make_user_path_test()``, and the ``Relevant``
    document subset is requested. Expects exactly one generation containing
    one of several accepted phrasings.
    """
    from tests.utils import make_user_path_test
    user_path = make_user_path_test()
    query = "What is the cat doing?"
    monkeypatch.setattr('builtins.input', lambda _: query)
    from src.gen import main
    all_generations, all_sources = main(base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b',
                                        cli=True, cli_loop=False, score_model='None',
                                        langchain_mode='UserData',
                                        user_path=user_path,
                                        langchain_modes=['UserData', 'MyData'],
                                        document_subset=DocumentSubset.Relevant.name,
                                        verbose=True)
    print(all_generations)
    assert len(all_generations) == 1

    # Model output varies across versions; accept any previously observed phrasing.
    expected_phrases = (
        "looking out the window",
        "staring out the window at the city skyline",
        "cat is sitting",
    )
    generation = all_generations[0]
    assert any(phrase in generation for phrase in expected_phrases), \
        f"unexpected generation: {generation!r}"