|
import pytest |
|
from openai import OpenAI |
|
from utils import * |
|
|
|
# Module-level server handle that the tests in this file start and send
# requests to. It is rebound by the `create_server` fixture.
server = ServerPreset.tinyllama2()
|
|
|
|
|
@pytest.fixture(autouse=True, scope="module")
def create_server():
    """Rebind the module-level ``server`` to ``ServerPreset.tinyllama2()``.

    The fixture is ``autouse`` with module scope, so pytest invokes it for
    this module without any test having to request it by name.
    """
    global server
    server = ServerPreset.tinyllama2()
|
|
|
|
|
@pytest.mark.parametrize(
    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
    [
        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
        (
            "codellama70b",
            "You are a coding assistant.",
            "Write the fibonacci function in c++.",
            128,
            "(Aside|she|felter|alonger)+",
            104,
            64,
            False,
        ),
    ],
)
def test_chat_completion(
    model,
    system_prompt,
    user_prompt,
    max_tokens,
    re_content,
    n_prompt,
    n_predicted,
    truncated,
):
    """Check a non-streaming ``/chat/completions`` request.

    Starts the server, POSTs a system + user conversation and asserts that:

    * the HTTP status is 200;
    * ``usage.prompt_tokens`` / ``usage.completion_tokens`` equal
      ``n_prompt`` / ``n_predicted``;
    * the first choice is an ``assistant`` message whose content matches
      ``re_content``;
    * ``finish_reason`` is ``"length"`` when ``truncated`` is true and
      ``"stop"`` otherwise.
    """
    server.start()

    payload = {
        "model": model,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    }
    res = server.make_request("POST", "/chat/completions", data=payload)
    assert res.status_code == 200

    usage = res.body["usage"]
    assert usage["prompt_tokens"] == n_prompt
    assert usage["completion_tokens"] == n_predicted

    choice = res.body["choices"][0]
    message = choice["message"]
    assert "assistant" == message["role"]
    assert match_regex(re_content, message["content"])

    expected_finish = "length" if truncated else "stop"
    assert choice["finish_reason"] == expected_finish
|
|
|
|
|
@pytest.mark.parametrize(
    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
    [
        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
        (
            "codellama70b",
            "You are a coding assistant.",
            "Write the fibonacci function in c++.",
            128,
            "(Aside|she|felter|alonger)+",
            104,
            64,
            False,
        ),
    ],
)
def test_chat_completion_stream(
    model,
    system_prompt,
    user_prompt,
    max_tokens,
    re_content,
    n_prompt,
    n_predicted,
    truncated,
):
    """Check a streaming ``/chat/completions`` request chunk by chunk.

    Chunks whose first choice has ``finish_reason`` of ``"stop"`` or
    ``"length"`` are treated as terminal: they must carry usage counts equal
    to ``n_prompt`` / ``n_predicted``, must not have ``content`` in their
    delta, and the text accumulated so far must match ``re_content``.
    Every other chunk must have ``finish_reason`` ``None`` and contributes
    its ``delta.content`` to the accumulated text.

    The test also requires that at least one terminal chunk was seen;
    otherwise none of the usage/content assertions would have executed.

    ``truncated`` belongs to the parametrization but is not checked here.
    """
    server.start()

    res = server.make_stream_request("POST", "/chat/completions", data={
        "model": model,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "stream": True,
    })

    pieces = []
    saw_terminal_chunk = False
    for data in res:
        choice = data["choices"][0]
        if choice["finish_reason"] in ("stop", "length"):
            saw_terminal_chunk = True
            assert data["usage"]["prompt_tokens"] == n_prompt
            assert data["usage"]["completion_tokens"] == n_predicted
            assert "content" not in choice["delta"]
            assert match_regex(re_content, "".join(pieces))
        else:
            assert choice["finish_reason"] is None
            pieces.append(choice["delta"]["content"])

    # Guard against a vacuous pass: an empty stream, or one that never sends
    # a terminal chunk, would otherwise skip every assertion above.
    assert saw_terminal_chunk, "stream ended without a chunk carrying finish_reason"
|
|
|
|
|
def test_chat_completion_with_openai_library():
    """Send a chat completion through the ``openai`` client library.

    Points an ``OpenAI`` client at the server's host and port, requests a
    completion with a fixed seed and temperature, prints the response, and
    asserts that the first choice finished with ``"stop"`` and has
    non-``None`` content matching ``(Suddenly)+``.
    """
    server.start()

    base_url = f"http://{server.server_host}:{server.server_port}"
    client = OpenAI(api_key="dummy", base_url=base_url)

    conversation = [
        {"role": "system", "content": "Book"},
        {"role": "user", "content": "What is the best book"},
    ]
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        messages=conversation,
        max_tokens=8,
        seed=42,
        temperature=0.8,
    )
    print(res)

    first_choice = res.choices[0]
    assert first_choice.finish_reason == "stop"
    assert first_choice.message.content is not None
    assert match_regex("(Suddenly)+", first_choice.message.content)
|
|
|
|
|
@pytest.mark.parametrize(
    "response_format,n_predicted,re_content",
    [
        # Cases with a pattern are expected to succeed (HTTP 200).
        ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
        # NOTE(review): if match_regex treats its first argument as a regular
        # expression, "[ -3000 ]" is a character class, not a literal array
        # -- confirm this is intended.
        ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
        ({"type": "json_object"}, 10, "(\\{|John)+"),
        # Cases with ``None`` are expected to be rejected with an error body.
        ({"type": "sound"}, 0, None),
        ({"type": "json_object", "schema": 123}, 0, None),
        ({"type": "json_object", "schema": {"type": 123}}, 0, None),
        ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
    ],
)
def test_completion_with_response_format(
    response_format: dict,
    n_predicted: int,
    re_content: str | None,
):
    """Check how ``/chat/completions`` handles the ``response_format`` field.

    ``n_predicted`` is sent as ``max_tokens``. When ``re_content`` is
    ``None`` the request must fail: the status is not 200 and the body has an
    ``error`` key. Otherwise the status must be 200 and the first choice's
    message content must match ``re_content``.
    """
    server.start()

    request_body = {
        "max_tokens": n_predicted,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "response_format": response_format,
    }
    res = server.make_request("POST", "/chat/completions", data=request_body)

    if re_content is None:
        assert res.status_code != 200
        assert "error" in res.body
        return

    assert res.status_code == 200
    choice = res.body["choices"][0]
    assert match_regex(re_content, choice["message"]["content"])
|
|
|
|
|
@pytest.mark.parametrize(
    "messages",
    [
        None,
        "string",
        [123],
        [{}],
        [{"role": 123}],
        [{"role": "system", "content": 123}],
        [{"role": "system", "content": "test"}, {}],
    ],
)
def test_invalid_chat_completion_req(messages):
    """Check that malformed ``messages`` values are rejected.

    Each parametrized value is sent as the ``messages`` field of a
    ``/chat/completions`` request; the response must have status 400 or 500
    and a body containing an ``error`` key.
    """
    server.start()

    request_body = {"messages": messages}
    res = server.make_request("POST", "/chat/completions", data=request_body)

    assert res.status_code in (400, 500)
    assert "error" in res.body
|
|
|
|
|
def test_chat_completion_with_timings_per_token():
    """Check that streamed chunks carry timings when ``timings_per_token`` is set.

    Streams a ``/chat/completions`` request and asserts that every chunk has
    a ``timings`` object containing ``prompt_per_second``,
    ``predicted_per_second`` and ``predicted_n``, and that ``predicted_n``
    never exceeds the requested ``max_tokens``.

    The test also requires that at least one chunk was received; with an
    empty stream the per-chunk assertions would never run.
    """
    max_tokens = 10  # single source for both the request and the upper bound

    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": max_tokens,
        "messages": [{"role": "user", "content": "test"}],
        "stream": True,
        "timings_per_token": True,
    })

    received_any = False
    for data in res:
        received_any = True
        assert "timings" in data
        timings = data["timings"]
        assert "prompt_per_second" in timings
        assert "predicted_per_second" in timings
        assert "predicted_n" in timings
        assert timings["predicted_n"] <= max_tokens

    # Guard against a vacuous pass when the stream yields nothing.
    assert received_any, "stream produced no chunks"
|
|