|
import pytest |
|
from utils import * |
|
|
|
# Module-level server handle read by every test in this file. The create_server
# fixture rebinds this name (via `global server`) to a new ServerPreset.tinyllama2().
server = ServerPreset.tinyllama2() |
|
|
|
|
|
# Long prompt sent as "prompt" by test_ctx_shift_enabled and
# test_ctx_shift_disabled_long_prompt. Surrounding whitespace is removed by .strip().
# NOTE(review): test_ctx_shift_enabled asserts an exact prompt token count (109), so
# presumably any edit to this text changes that number -- keep the literal as is.
LONG_TEXT = """ |
|
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. |
|
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. |
|
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. |
|
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. |
|
""".strip() |
|
|
|
@pytest.fixture(autouse=True)
def create_server():
    """Rebind the module-level ``server`` to a fresh preset before every test.

    The fixture uses pytest's default function scope rather than module scope:
    the tests in this file mutate the shared object (``disable_ctx_shift``,
    ``n_predict``), and with a module-scoped fixture those mutations would leak
    into whichever test runs next, making results depend on execution order.
    """
    global server  # tests read the module-level name, so it must be rebound here
    server = ServerPreset.tinyllama2()
    # NOTE(review): the token counts asserted by the tests in this file presumably
    # depend on these two settings -- confirm before changing either value.
    server.n_ctx = 256
    server.n_slots = 2
|
|
|
|
|
def test_ctx_shift_enabled():
    """Complete LONG_TEXT without touching ``disable_ctx_shift`` and check the counts.

    The test relies on whatever the preset configures for context shifting
    (presumably enabled, going by the test name -- confirm against ServerPreset).
    The request must succeed, report 109 evaluated prompt tokens, produce all
    64 requested tokens, and be flagged as truncated.
    """
    server.start()

    payload = {
        "n_predict": 64,
        "prompt": LONG_TEXT,
    }
    response = server.make_request("POST", "/completion", data=payload)

    # Check the status first so a failed request is reported as such, not as a KeyError.
    assert response.status_code == 200
    timings = response.body["timings"]
    assert timings["prompt_n"] == 109
    assert timings["predicted_n"] == 64
    assert response.body["truncated"] is True
|
|
|
|
|
@pytest.mark.parametrize(
    "n_predict,n_token_output,truncated",
    [
        (64, 64, False),
        (-1, 120, True),
    ],
)
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
    """Complete a short prompt with context shift disabled.

    The server is started with ``disable_ctx_shift = True`` and a server-level
    ``n_predict`` of -1. The per-request ``n_predict`` is parametrized:

    * 64  -> exactly 64 tokens are predicted and the result is not truncated;
    * -1  -> 120 tokens are predicted and the result is flagged truncated
      (presumably generation stops once the available context is full -- confirm).
    """
    server.disable_ctx_shift = True
    server.n_predict = -1
    server.start()

    payload = {
        "n_predict": n_predict,
        "prompt": "Hi how are you",
    }
    response = server.make_request("POST", "/completion", data=payload)

    assert response.status_code == 200
    assert response.body["timings"]["predicted_n"] == n_token_output
    assert response.body["truncated"] == truncated
|
|
|
|
|
def test_ctx_shift_disabled_long_prompt():
    """Send LONG_TEXT with context shift disabled and expect an error response.

    The request must not return HTTP 200, and the body must carry an ``error``
    object whose message mentions that the context size was exceeded.
    """
    server.disable_ctx_shift = True
    server.start()

    payload = {
        "n_predict": 64,
        "prompt": LONG_TEXT,
    }
    response = server.make_request("POST", "/completion", data=payload)

    assert response.status_code != 200
    body = response.body
    assert "error" in body
    assert "exceeds the available context size" in body["error"]["message"]
|
|