import hashlib
import os
import sys
import shutil
from functools import wraps, partial

import pytest

# allow imports from the repo-local 'src' directory as top-level modules
if 'src' not in sys.path:
    sys.path.append('src')

os.environ['HARD_ASSERTS'] = "1"

from src.utils import call_subprocess_onetask, makedirs, FakeTokenizer, download_simple, sanitize_filename


def get_inf_port():
    if os.getenv('HOST') is not None:
        inf_port = os.environ['HOST'].split(':')[-1]
    elif os.getenv('GRADIO_SERVER_PORT') is not None:
        inf_port = os.environ['GRADIO_SERVER_PORT']
    else:
        inf_port = str(7860)
    return int(inf_port)


def get_inf_server():
    if os.getenv('HOST') is not None:
        inf_server = os.environ['HOST']
    elif os.getenv('GRADIO_SERVER_PORT') is not None:
        inf_server = "http://localhost:%s" % os.environ['GRADIO_SERVER_PORT']
    else:
        raise ValueError("Expect tests to set HOST or GRADIO_SERVER_PORT")
    return inf_server


def get_mods():
    testtotalmod = int(os.getenv('TESTMODULOTOTAL', '1'))
    testmod = int(os.getenv('TESTMODULO', '0'))
    return testtotalmod, testmod


def do_skip_test(name):
    """
    Decide whether this shard should skip the named test.
    Note that skipping all tests does not fail a run; collecting no tests is what fails.
    :param name: test name, hashed to pick a shard
    :return: True if the test belongs to another shard and should be skipped
    """
    testtotalmod, testmod = get_mods()
    return int(get_sha(name), 16) % testtotalmod != testmod
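

# Sketch of how the sharding is meant to be driven (the pytest invocations
# below are an assumption, not taken from this module): give every parallel
# worker the same TESTMODULOTOTAL and a distinct TESTMODULO, and each test
# name then hashes into exactly one worker's bucket:
#
#   TESTMODULOTOTAL=3 TESTMODULO=0 pytest tests/
#   TESTMODULOTOTAL=3 TESTMODULO=1 pytest tests/
#   TESTMODULOTOTAL=3 TESTMODULO=2 pytest tests/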


def wrap_test_forked(func):
    """Decorate a test function so its body runs in a fresh subprocess"""

    @wraps(func)
    def f(*args, **kwargs):
        gradio_port = os.environ['GRADIO_SERVER_PORT'] = os.getenv('GRADIO_SERVER_PORT', str(7860))
        gradio_port = int(gradio_port)

        os.environ['HOST'] = os.getenv('HOST', "http://localhost:%s" % gradio_port)

        pytest_name = get_test_name()
        if do_skip_test(pytest_name):
            pytest.skip("[%s] TEST SKIPPED due to TESTMODULO" % pytest_name)
        # run the real test body in a one-task subprocess so state cannot leak
        func_new = partial(call_subprocess_onetask, func, args, kwargs)
        return run_test(func_new)

    return f
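

# Minimal usage sketch (hypothetical test, not part of this module): the
# decorated body is sharded or forked exactly as above, so environment and
# GPU state from one test cannot bleed into the next.
#
# @wrap_test_forked
# def test_ports_agree():
#     assert get_inf_server().endswith(str(get_inf_port()))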


def run_test(func, *args, **kwargs):
    return func(*args, **kwargs)


def get_sha(value):
    return hashlib.md5(str(value).encode('utf-8')).hexdigest()


def get_test_name():
    tn = os.environ['PYTEST_CURRENT_TEST'].split(':')[-1]
    tn = "_".join(tn.split(' ')[:-1])
    return sanitize_filename(tn)
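

# Worked example (hypothetical test id): pytest sets PYTEST_CURRENT_TEST to
# something like 'tests/test_x.py::test_foo[case-1] (call)'. Splitting on ':'
# and keeping the last piece gives 'test_foo[case-1] (call)'; dropping the
# trailing ' (call)' stage marker gives 'test_foo[case-1]', which
# sanitize_filename() then makes filesystem-safe.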


def make_user_path_test():
    # recreate user_path_test and db_dir_UserData from scratch
    user_path = makedirs('user_path_test', use_base=True)
    if os.path.isdir(user_path):
        shutil.rmtree(user_path)
    user_path = makedirs('user_path_test', use_base=True)
    db_dir = "db_dir_UserData"
    db_dir = makedirs(db_dir, use_base=True)
    if os.path.isdir(db_dir):
        shutil.rmtree(db_dir)
    db_dir = makedirs(db_dir, use_base=True)
    shutil.copy('data/pexels-evg-kowalievska-1170986_small.jpg', user_path)
    shutil.copy('README.md', user_path)
    shutil.copy('docs/FAQ.md', user_path)
    return user_path


def get_llama(llama_type=3):
    from huggingface_hub import hf_hub_download

    if llama_type == 1:
        file = 'ggml-model-q4_0_7b.bin'
        dest = 'models/7B/'
        prompt_type = 'plain'
    elif llama_type == 2:
        file = 'WizardLM-7B-uncensored.ggmlv3.q8_0.bin'
        dest = './'
        prompt_type = 'wizard2'
    elif llama_type == 3:
        file = download_simple('https://huggingface.co/TheBloke/Llama-2-7b-Chat-GGUF/resolve/main/llama-2-7b-chat.Q6_K.gguf?download=true')
        dest = './'
        prompt_type = 'llama2'
    else:
        raise ValueError("unknown llama_type=%s" % llama_type)

    makedirs(dest, exist_ok=True)
    full_path = os.path.join(dest, file)

    if not os.path.isfile(full_path):
        # token=True lets huggingface_hub fall back to the locally saved token
        token = os.getenv('HUGGING_FACE_HUB_TOKEN', True)
        out_path = hf_hub_download('h2oai/ggml', file, token=token, repo_type='model')
        shutil.copy(out_path, dest)
    return prompt_type, full_path
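

# Usage sketch: the default llama_type=3 fetches a Llama-2 GGUF file into the
# working directory on first call (a multi-GB download), so tests typically do:
#
#   prompt_type, model_path = get_llama()
#   assert os.path.isfile(model_path)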


def kill_weaviate(db_type):
    """
    Weaviate launches a detached server that accumulates entries in its db,
    but we want each test to start fresh.
    """
    if db_type == 'weaviate':
        os.system('pkill --signal 9 -f weaviate-embedded/weaviate')


def count_tokens_llm(prompt, base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b', tokenizer=None):
    import time
    if tokenizer is None:
        assert base_model is not None
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(base_model)
    t0 = time.time()
    a = len(tokenizer(prompt)['input_ids'])
    print('llm: ', a, time.time() - t0)
    return dict(llm=a)


def count_tokens(prompt, base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b'):
    tokenizer = FakeTokenizer()
    num_tokens = tokenizer.num_tokens_from_string(prompt)
    print(num_tokens)

    from transformers import AutoTokenizer

    # distilgpt2 shares GPT-2's BPE vocabulary, i.e. tiktoken's 'gpt2' encoding
    t = AutoTokenizer.from_pretrained("distilgpt2")
    llm_tokenizer = AutoTokenizer.from_pretrained(base_model)

    from InstructorEmbedding import INSTRUCTOR
    emb = INSTRUCTOR('hkunlp/instructor-large')

    import nltk

    def nltkTokenize(text):
        words = nltk.word_tokenize(text)
        return words

    import re

    WORD = re.compile(r'\w+')

    def regTokenize(text):
        words = WORD.findall(text)
        return words

    counts = {}
    import time

    t0 = time.time()
    a = len(regTokenize(prompt))
    print('reg: ', a, time.time() - t0)
    counts.update(dict(reg=a))

    t0 = time.time()
    a = len(nltkTokenize(prompt))
    print('nltk: ', a, time.time() - t0)
    counts.update(dict(nltk=a))

    t0 = time.time()
    a = len(t(prompt)['input_ids'])
    print('tiktoken: ', a, time.time() - t0)
    counts.update(dict(tiktoken=a))

    t0 = time.time()
    a = len(llm_tokenizer(prompt)['input_ids'])
    print('llm: ', a, time.time() - t0)
    counts.update(dict(llm=a))

    t0 = time.time()
    a = emb.tokenize([prompt])['input_ids'].shape[1]
    print('instructor-large: ', a, time.time() - t0)
    counts.update(dict(instructor=a))

    return counts
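

# Usage sketch (heavy on first call: downloads several tokenizers and the
# instructor-large embedding model):
#
#   counts = count_tokens("some prompt")
#   # -> dict with keys 'reg', 'nltk', 'tiktoken', 'llm', 'instructor'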