new vLLM
Changed files:
- api_server.py          +112 -79
- protocol.py            +501 -66
- serving_completion.py  +453 -248
- serving_engine.py      +143 -76
api_server.py
CHANGED
@@ -1,30 +1,44 @@
-import argparse
 import asyncio
-import json
+import importlib
+import inspect
+import re
 from contextlib import asynccontextmanager
-from
-from
+from http import HTTPStatus
+from typing import Optional, Set
+
 import fastapi
 import uvicorn
-from http import HTTPStatus
 from fastapi import Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import JSONResponse,
+from fastapi.responses import JSONResponse, Response, StreamingResponse
+from prometheus_client import make_asgi_app
+from starlette.routing import Mount
 
+import vllm
+import vllm.envs as envs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.
-from protocol import
+from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              ChatCompletionResponse,
+                                              CompletionRequest,
+                                              EmbeddingRequest, ErrorResponse)
+from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from vllm.logger import init_logger
-from
-from serving_completion import OpenAIServingCompletion
+from vllm.usage.usage_lib import UsageContext
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds
 
-openai_serving_chat: OpenAIServingChat
-openai_serving_completion: OpenAIServingCompletion
+openai_serving_chat: OpenAIServingChat
+openai_serving_completion: OpenAIServingCompletion
+openai_serving_embedding: OpenAIServingEmbedding
+
+logger = init_logger('vllm.entrypoints.openai.api_server')
+
+_running_tasks: Set[asyncio.Task] = set()
 
 
 @asynccontextmanager
@@ -36,7 +50,9 @@ async def lifespan(app: fastapi.FastAPI):
            await engine.do_log_stats()
 
    if not engine_args.disable_log_stats:
-        asyncio.create_task(_force_log())
+        task = asyncio.create_task(_force_log())
+        _running_tasks.add(task)
+        task.add_done_callback(_running_tasks.remove)
 
    yield
 
@@ -45,62 +61,15 @@ app = fastapi.FastAPI(lifespan=lifespan)
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(
-        description="vLLM OpenAI-Compatible RESTful API server.")
-    parser.add_argument("--host", type=str, default=None, help="host name")
-    parser.add_argument("--port", type=int, default=8000, help="port number")
-    parser.add_argument("--allow-credentials",
-                        action="store_true",
-                        help="allow credentials")
-    parser.add_argument("--allowed-origins",
-                        type=json.loads,
-                        default=["*"],
-                        help="allowed origins")
-    parser.add_argument("--allowed-methods",
-                        type=json.loads,
-                        default=["*"],
-                        help="allowed methods")
-    parser.add_argument("--allowed-headers",
-                        type=json.loads,
-                        default=["*"],
-                        help="allowed headers")
-    parser.add_argument("--served-model-name",
-                        type=str,
-                        default=None,
-                        help="The model name used in the API. If not "
-                        "specified, the model name will be the same as "
-                        "the huggingface name.")
-    parser.add_argument("--chat-template",
-                        type=str,
-                        default=None,
-                        help="The file path to the chat template, "
-                        "or the template in single-line form "
-                        "for the specified model")
-    parser.add_argument("--response-role",
-                        type=str,
-                        default="assistant",
-                        help="The role name to return if "
-                        "`request.add_generation_prompt=true`.")
-    parser.add_argument("--ssl-keyfile",
-                        type=str,
-                        default=None,
-                        help="The file path to the SSL key file")
-    parser.add_argument("--ssl-certfile",
-                        type=str,
-                        default=None,
-                        help="The file path to the SSL cert file")
-    parser.add_argument(
-        "--root-path",
-        type=str,
-        default=None,
-        help="FastAPI root_path when app is behind a path based routing proxy")
-
-    parser = AsyncEngineArgs.add_cli_args(parser)
+    parser = make_arg_parser()
    return parser.parse_args()
 
 
+# Add prometheus asgi middleware to route /metrics requests
+route = Mount("/metrics", make_asgi_app())
+# Workaround for 307 Redirect for /metrics
+route.path_regex = re.compile('^/metrics(?P<path>.*)$')
+app.routes.append(route)
 
 
 @app.exception_handler(RequestValidationError)
@@ -112,6 +81,7 @@ async def validation_exception_handler(_, exc):
 @app.get("/health")
 async def health() -> Response:
    """Health check."""
+    await openai_serving_chat.engine.check_health()
    return Response(status_code=200)
 
 
@@ -121,6 +91,12 @@ async def show_available_models():
    return JSONResponse(content=models.model_dump())
 
 
+@app.get("/version")
+async def show_version():
+    ver = {"version": vllm.__version__}
+    return JSONResponse(content=ver)
+
+
 @app.post("/api/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
@@ -133,6 +109,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
+        assert isinstance(generator, ChatCompletionResponse)
        return JSONResponse(content=generator.model_dump())
 
 
@@ -150,6 +127,17 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
        return JSONResponse(content=generator.model_dump())
 
 
+@app.post("/api/v1/embeddings")
+async def create_embedding(request: EmbeddingRequest, raw_request: Request):
+    generator = await openai_serving_embedding.create_embedding(
+        request, raw_request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    else:
+        return JSONResponse(content=generator.model_dump())
+
+
 if __name__ == "__main__":
    args = parse_args()
 
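Worth calling out in the lifespan hunk above: the bare asyncio.create_task(_force_log()) call is replaced by a task parked in the module-level _running_tasks set. The event loop keeps only weak references to tasks, so a fire-and-forget task can be garbage-collected before it finishes; holding it in a set, and removing it via add_done_callback once it completes, keeps it alive. A standalone sketch of the pattern (the toy _force_log body and timings are invented, not from the diff):

import asyncio
from typing import Set

_running_tasks: Set[asyncio.Task] = set()

async def _force_log():
    # Stand-in for the periodic engine.do_log_stats() loop.
    for _ in range(3):
        await asyncio.sleep(0.01)
        print("log stats")

async def main():
    task = asyncio.create_task(_force_log())
    _running_tasks.add(task)                       # strong reference: no GC
    task.add_done_callback(_running_tasks.remove)  # drop it once finished
    await asyncio.sleep(0.1)

asyncio.run(main())

The final hunk of api_server.py, covering the __main__ block, follows.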
@@ -161,28 +149,73 @@ if __name__ == "__main__":
        allow_headers=args.allowed_headers,
    )
 
+    if token := envs.VLLM_API_KEY or args.api_key:
+
+        @app.middleware("http")
+        async def authentication(request: Request, call_next):
+            root_path = "" if args.root_path is None else args.root_path
+            if request.method == "OPTIONS":
+                return await call_next(request)
+            if not request.url.path.startswith(f"{root_path}/v1"):
+                return await call_next(request)
+            if request.headers.get("Authorization") != "Bearer " + token:
+                return JSONResponse(content={"error": "Unauthorized"},
+                                    status_code=401)
+            return await call_next(request)
+
+    for middleware in args.middleware:
+        module_path, object_name = middleware.rsplit(".", 1)
+        imported = getattr(importlib.import_module(module_path), object_name)
+        if inspect.isclass(imported):
+            app.add_middleware(imported)
+        elif inspect.iscoroutinefunction(imported):
+            app.middleware("http")(imported)
+        else:
+            raise ValueError(f"Invalid middleware {middleware}. "
+                             f"Must be a function or a class.")
+
+    logger.info("vLLM API server version %s", vllm.__version__)
+    logger.info("args: %s", args)
 
    if args.served_model_name is not None:
+        served_model_names = args.served_model_name
    else:
+        served_model_names = [args.model]
 
    engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(
+    engine = AsyncLLMEngine.from_engine_args(
+        engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
+
+    event_loop: Optional[asyncio.AbstractEventLoop]
+    try:
+        event_loop = asyncio.get_running_loop()
+    except RuntimeError:
+        event_loop = None
+
+    if event_loop is not None and event_loop.is_running():
+        # If the current process is instanced by Ray Serve,
+        # there is already a running event loop
+        model_config = event_loop.run_until_complete(engine.get_model_config())
+    else:
+        # When using single vLLM without engine_use_ray
+        model_config = asyncio.run(engine.get_model_config())
+
+    openai_serving_chat = OpenAIServingChat(engine, model_config,
+                                            served_model_names,
                                            args.response_role,
+                                            args.lora_modules,
                                            args.chat_template)
-    openai_serving_completion = OpenAIServingCompletion(
+    openai_serving_completion = OpenAIServingCompletion(
+        engine, model_config, served_model_names, args.lora_modules)
+    openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
+                                                      served_model_names)
    app.root_path = args.root_path
    uvicorn.run(app,
                host=args.host,
                port=args.port,
-                log_level=
+                log_level=args.uvicorn_log_level,
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
-                ssl_certfile=args.ssl_certfile
+                ssl_certfile=args.ssl_certfile,
+                ssl_ca_certs=args.ssl_ca_certs,
+                ssl_cert_reqs=args.ssl_cert_reqs)
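Taken together, the rewritten __main__ block adds version reporting, optional bearer-token auth (VLLM_API_KEY or --api-key), pluggable --middleware, and an embeddings route. A client-side sketch against a locally started server follows; the host, port, model name, and token are placeholders, not part of the diff. Note also that the authentication middleware above only guards paths under {root_path}/v1, so the /api/v1/... routes registered in this file pass through it unchecked.

import requests

BASE = "http://localhost:8000"                # placeholder host/port
HEADERS = {"Authorization": "Bearer secret"}  # matches --api-key secret

# New /version endpoint: returns {"version": vllm.__version__}.
print(requests.get(f"{BASE}/version").json())

# New /api/v1/embeddings endpoint, served by OpenAIServingEmbedding.
resp = requests.post(f"{BASE}/api/v1/embeddings",
                     headers=HEADERS,
                     json={"model": "my-model", "input": "hello world"})
print(resp.status_code, resp.json())

# Prometheus metrics mounted at /metrics by make_asgi_app().
print(requests.get(f"{BASE}/metrics").text[:200])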
protocol.py
CHANGED
@@ -1,15 +1,58 @@
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
 import time
-from typing import Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 
+import openai.types.chat
+import torch
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+# pydantic needs the TypedDict from typing_extensions
+from typing_extensions import Annotated, Required, TypedDict
 
-from vllm.
+from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
+from vllm.utils import random_uuid
+
+
+class CustomChatCompletionContentPartParam(TypedDict, total=False):
+    __pydantic_config__ = ConfigDict(extra="allow")  # type: ignore
+
+    type: Required[str]
+    """The type of the content part."""
+
+
+ChatCompletionContentPartParam = Union[
+    openai.types.chat.ChatCompletionContentPartParam,
+    CustomChatCompletionContentPartParam]
+
+
+class CustomChatCompletionMessageParam(TypedDict, total=False):
+    """Enables custom roles in the Chat Completion API."""
+    role: Required[str]
+    """The role of the message's author."""
+
+    content: Union[str, List[ChatCompletionContentPartParam]]
+    """The contents of the message."""
+
+    name: str
+    """An optional name for the participant.
+
+    Provides the model information to differentiate between participants of the
+    same role.
+    """
+
+
+ChatCompletionMessageParam = Union[
+    openai.types.chat.ChatCompletionMessageParam,
+    CustomChatCompletionMessageParam]
+
+
+class OpenAIBaseModel(BaseModel):
+    # OpenAI API does not allow extra fields
+    model_config = ConfigDict(extra="forbid")
 
 
-class ErrorResponse(
+class ErrorResponse(OpenAIBaseModel):
    object: str = "error"
    message: str
    type: str
@@ -17,7 +60,7 @@ class ErrorResponse(BaseModel):
    code: int
 
 
-class ModelPermission(
+class ModelPermission(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
@@ -29,57 +72,176 @@ class ModelPermission(BaseModel):
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
-    is_blocking:
+    is_blocking: bool = False
 
 
-class ModelCard(
+class ModelCard(OpenAIBaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: Optional[str] = None
    parent: Optional[str] = None
+    max_model_len: Optional[int] = None
    permission: List[ModelPermission] = Field(default_factory=list)
 
 
-class ModelList(
+class ModelList(OpenAIBaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)
 
 
-class UsageInfo(
+class UsageInfo(OpenAIBaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0
 
 
-class
+class ResponseFormat(OpenAIBaseModel):
+    # type must be "json_object" or "text"
+    type: Literal["text", "json_object"]
+
+
+class FunctionDefinition(OpenAIBaseModel):
+    name: str
+    description: Optional[str] = None
+    parameters: Optional[Dict[str, Any]] = None
+
+
+class ChatCompletionToolsParam(OpenAIBaseModel):
+    type: Literal["function"] = "function"
+    function: FunctionDefinition
+
+
+class ChatCompletionNamedFunction(OpenAIBaseModel):
+    name: str
+
+
+class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
+    function: ChatCompletionNamedFunction
+    type: Literal["function"] = "function"
+
+
+class ChatCompletionRequest(OpenAIBaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/chat/create
+    messages: List[ChatCompletionMessageParam]
    model: str
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    logprobs: Optional[bool] = False
+    top_logprobs: Optional[int] = 0
    max_tokens: Optional[int] = None
+    n: Optional[int] = 1
+    presence_penalty: Optional[float] = 0.0
+    response_format: Optional[ResponseFormat] = None
+    seed: Optional[int] = Field(None,
+                                ge=torch.iinfo(torch.long).min,
+                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
+    temperature: Optional[float] = 0.7
+    top_p: Optional[float] = 1.0
+    tools: Optional[List[ChatCompletionToolsParam]] = None
+    tool_choice: Optional[Union[Literal["none"],
+                                ChatCompletionNamedToolChoiceParam]] = "none"
    user: Optional[str] = None
+
+    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
+    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
+    min_p: Optional[float] = 0.0
+    repetition_penalty: Optional[float] = 1.0
+    length_penalty: Optional[float] = 1.0
+    early_stopping: Optional[bool] = False
    ignore_eos: Optional[bool] = False
+    min_tokens: Optional[int] = 0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
+    # doc: end-chat-completion-sampling-params
+
+    # doc: begin-chat-completion-extra-params
+    echo: Optional[bool] = Field(
+        default=False,
+        description=(
+            "If true, the new message will be prepended with the last message "
+            "if they belong to the same role."),
+    )
+    add_generation_prompt: Optional[bool] = Field(
+        default=True,
+        description=
+        ("If true, the generation prompt will be added to the chat template. "
+         "This is a parameter used by chat template in tokenizer config of the "
+         "model."),
+    )
+    add_special_tokens: Optional[bool] = Field(
+        default=False,
+        description=(
+            "If true, special tokens (e.g. BOS) will be added to the prompt "
+            "on top of what is added by the chat template. "
+            "For most models, the chat template takes care of adding the "
+            "special tokens so this should be set to False (as is the "
+            "default)."),
+    )
+    include_stop_str_in_output: Optional[bool] = Field(
+        default=False,
+        description=(
+            "Whether to include the stop string in the output. "
+            "This is only applied when the stop or stop_token_ids is set."),
+    )
+    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
+        default=None,
+        description=("If specified, the output will follow the JSON schema."),
+    )
+    guided_regex: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, the output will follow the regex pattern."),
+    )
+    guided_choice: Optional[List[str]] = Field(
+        default=None,
+        description=(
+            "If specified, the output will be exactly one of the choices."),
+    )
+    guided_grammar: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, the output will follow the context free grammar."),
+    )
+    guided_decoding_backend: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, will override the default guided decoding backend "
+            "of the server for this specific request. If set, must be either "
+            "'outlines' / 'lm-format-enforcer'"))
+    guided_whitespace_pattern: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, will override the default whitespace pattern "
+            "for guided json decoding."))
+
+    # doc: end-chat-completion-extra-params
 
    def to_sampling_params(self) -> SamplingParams:
+        # We now allow logprobs being true without top_logprobs.
+
+        logits_processors = None
+        if self.logit_bias:
+
+            def logit_bias_logits_processor(
+                    token_ids: List[int],
+                    logits: torch.Tensor) -> torch.Tensor:
+                assert self.logit_bias is not None
+                for token_id, bias in self.logit_bias.items():
+                    # Clamp the bias between -100 and 100 per OpenAI API spec
+                    bias = min(100, max(-100, bias))
+                    logits[int(token_id)] += bias
+                return logits
+
+            logits_processors = [logit_bias_logits_processor]
+
        return SamplingParams(
            n=self.n,
            presence_penalty=self.presence_penalty,
@@ -88,49 +250,173 @@ class ChatCompletionRequest(BaseModel):
            temperature=self.temperature,
            top_p=self.top_p,
            min_p=self.min_p,
+            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=self.max_tokens,
+            min_tokens=self.min_tokens,
+            logprobs=self.top_logprobs if self.logprobs else None,
+            prompt_logprobs=self.top_logprobs if self.echo else None,
            best_of=self.best_of,
            top_k=self.top_k,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
+            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
+            include_stop_str_in_output=self.include_stop_str_in_output,
+            length_penalty=self.length_penalty,
+            logits_processors=logits_processors,
        )
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_guided_decoding_count(cls, data):
+        guide_count = sum([
+            "guided_json" in data and data["guided_json"] is not None,
+            "guided_regex" in data and data["guided_regex"] is not None,
+            "guided_choice" in data and data["guided_choice"] is not None
+        ])
+        # you can only use one kind of guided decoding
+        if guide_count > 1:
+            raise ValueError(
+                "You can only use one kind of guided decoding "
+                "('guided_json', 'guided_regex' or 'guided_choice').")
+        # you can only either use guided decoding or tools, not both
+        if guide_count > 1 and "tool_choice" in data and data[
+                "tool_choice"] != "none":
+            raise ValueError(
+                "You can only either use guided decoding or tools, not both.")
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_tool_choice(cls, data):
+        if "tool_choice" in data and data["tool_choice"] != "none":
+            if not isinstance(data["tool_choice"], dict):
+                raise ValueError("Currently only named tools are supported.")
+            if "tools" not in data or data["tools"] is None:
+                raise ValueError(
+                    "When using `tool_choice`, `tools` must be set.")
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_logprobs(cls, data):
+        if "top_logprobs" in data and data["top_logprobs"] is not None:
+            if "logprobs" not in data or data["logprobs"] is False:
+                raise ValueError(
+                    "when using `top_logprobs`, `logprobs` must be set to true."
+                )
+            elif not 0 <= data["top_logprobs"] <= 20:
+                raise ValueError(
+                    "`top_logprobs` must be a value in the interval [0, 20].")
+        return data
+
+
+class CompletionRequest(OpenAIBaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
-    # a string, array of strings, array of tokens, or array of token arrays
    prompt: Union[List[int], List[List[int]], str, List[str]]
-    max_tokens: Optional[int] = 16
-    temperature: Optional[float] = 1.0
-    top_p: Optional[float] = 1.0
-    n: Optional[int] = 1
-    stream: Optional[bool] = False
-    logprobs: Optional[int] = None
+    best_of: Optional[int] = None
    echo: Optional[bool] = False
-    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
-    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
-    best_of: Optional[int] = None
    logit_bias: Optional[Dict[str, float]] = None
+    logprobs: Optional[int] = None
+    max_tokens: Optional[int] = 16
+    n: int = 1
+    presence_penalty: Optional[float] = 0.0
+    seed: Optional[int] = Field(None,
+                                ge=torch.iinfo(torch.long).min,
+                                le=torch.iinfo(torch.long).max)
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    stream: Optional[bool] = False
+    suffix: Optional[str] = None
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
    user: Optional[str] = None
-    ignore_eos: Optional[bool] = False
+
+    # doc: begin-completion-sampling-params
    use_beam_search: Optional[bool] = False
+    top_k: Optional[int] = -1
+    min_p: Optional[float] = 0.0
+    repetition_penalty: Optional[float] = 1.0
+    length_penalty: Optional[float] = 1.0
+    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    ignore_eos: Optional[bool] = False
+    min_tokens: Optional[int] = 0
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    # doc: end-completion-sampling-params
+
+    # doc: begin-completion-extra-params
+    include_stop_str_in_output: Optional[bool] = Field(
+        default=False,
+        description=(
+            "Whether to include the stop string in the output. "
+            "This is only applied when the stop or stop_token_ids is set."),
+    )
+    response_format: Optional[ResponseFormat] = Field(
+        default=None,
+        description=
+        ("Similar to chat completion, this parameter specifies the format of "
+         "output. Only {'type': 'json_object'} or {'type': 'text' } is "
+         "supported."),
+    )
+    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
+        default=None,
+        description=("If specified, the output will follow the JSON schema."),
+    )
+    guided_regex: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, the output will follow the regex pattern."),
+    )
+    guided_choice: Optional[List[str]] = Field(
+        default=None,
+        description=(
+            "If specified, the output will be exactly one of the choices."),
+    )
+    guided_grammar: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, the output will follow the context free grammar."),
+    )
+    guided_decoding_backend: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, will override the default guided decoding backend "
+            "of the server for this specific request. If set, must be one of "
+            "'outlines' / 'lm-format-enforcer'"))
+    guided_whitespace_pattern: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, will override the default whitespace pattern "
+            "for guided json decoding."))
+
+    # doc: end-completion-extra-params
 
    def to_sampling_params(self):
        echo_without_generation = self.echo and self.max_tokens == 0
 
+        logits_processors = None
+        if self.logit_bias:
+
+            def logit_bias_logits_processor(
+                    token_ids: List[int],
+                    logits: torch.Tensor) -> torch.Tensor:
+                assert self.logit_bias is not None
+                for token_id, bias in self.logit_bias.items():
+                    # Clamp the bias between -100 and 100 per OpenAI API spec
+                    bias = min(100, max(-100, bias))
+                    logits[int(token_id)] += bias
+                return logits
+
+            logits_processors = [logit_bias_logits_processor]
+
        return SamplingParams(
            n=self.n,
            best_of=self.best_of,
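The model_validator hooks in the hunk above run before field parsing and enforce cross-field constraints. A quick sketch of what they reject, assuming this module is importable as protocol (upstream vLLM ships the same classes as vllm.entrypoints.openai.protocol); the model name and message contents are made up:

from pydantic import ValidationError

from protocol import ChatCompletionRequest  # assumed import path

msgs = [{"role": "user", "content": "hi"}]

# check_logprobs: top_logprobs requires logprobs=True.
try:
    ChatCompletionRequest(model="m", messages=msgs, top_logprobs=5)
except ValidationError as e:
    print(e.errors()[0]["msg"])

# check_guided_decoding_count: at most one guided-decoding mode per request.
try:
    ChatCompletionRequest(model="m", messages=msgs,
                          guided_regex=r"\d+", guided_choice=["a", "b"])
except ValidationError as e:
    print(e.errors()[0]["msg"])

# OpenAIBaseModel sets extra="forbid", so unknown fields are rejected too.
try:
    ChatCompletionRequest(model="m", messages=msgs, banana=1)
except ValidationError as e:
    print(e.errors()[0]["msg"])

The remaining protocol.py hunks follow.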
@@ -141,33 +427,88 @@ class CompletionRequest(BaseModel):
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
+            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
            max_tokens=self.max_tokens if not echo_without_generation else 1,
+            min_tokens=self.min_tokens,
            logprobs=self.logprobs,
            use_beam_search=self.use_beam_search,
+            early_stopping=self.early_stopping,
            prompt_logprobs=self.logprobs if self.echo else None,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=(self.spaces_between_special_tokens),
+            include_stop_str_in_output=self.include_stop_str_in_output,
+            length_penalty=self.length_penalty,
+            logits_processors=logits_processors,
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
        )
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_guided_decoding_count(cls, data):
+        guide_count = sum([
+            "guided_json" in data and data["guided_json"] is not None,
+            "guided_regex" in data and data["guided_regex"] is not None,
+            "guided_choice" in data and data["guided_choice"] is not None
+        ])
+        if guide_count > 1:
+            raise ValueError(
+                "You can only use one kind of guided decoding "
+                "('guided_json', 'guided_regex' or 'guided_choice').")
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_logprobs(cls, data):
+        if "logprobs" in data and data[
+                "logprobs"] is not None and not 0 <= data["logprobs"] <= 5:
+            raise ValueError(("if passed, `logprobs` must be a value",
+                              " in the interval [0, 5]."))
+        return data
+
+
+class EmbeddingRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/embeddings
+    model: str
+    input: Union[List[int], List[List[int]], str, List[str]]
+    encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
+    dimensions: Optional[int] = None
+    user: Optional[str] = None
+
+    # doc: begin-embedding-pooling-params
+    additional_data: Optional[Any] = None
+    # doc: end-embedding-pooling-params
+
+    def to_pooling_params(self):
+        return PoolingParams(additional_data=self.additional_data)
+
 
-class
+class CompletionLogProbs(OpenAIBaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
-    top_logprobs: Optional[List[Optional[Dict[
+    top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
 
 
-class CompletionResponseChoice(
+class CompletionResponseChoice(OpenAIBaseModel):
    index: int
    text: str
-    logprobs: Optional[
-    finish_reason: Optional[
+    logprobs: Optional[CompletionLogProbs] = None
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = Field(
+        default=None,
+        description=(
+            "The stop string or token id that caused the completion "
+            "to stop, None if the completion finished for some other reason "
+            "including encountering the EOS token"),
+    )
+
+
+class CompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
@@ -176,14 +517,21 @@ class CompletionResponse(BaseModel):
    usage: UsageInfo
 
 
-class CompletionResponseStreamChoice(
+class CompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    text: str
-    logprobs: Optional[
-    finish_reason: Optional[
+    logprobs: Optional[CompletionLogProbs] = None
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = Field(
+        default=None,
+        description=(
+            "The stop string or token id that caused the completion "
+            "to stop, None if the completion finished for some other reason "
+            "including encountering the EOS token"),
+    )
+
+
+class CompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
@@ -192,41 +540,128 @@ class CompletionStreamResponse(BaseModel):
    usage: Optional[UsageInfo] = Field(default=None)
 
 
-class
+class EmbeddingResponseData(BaseModel):
+    index: int
+    object: str = "embedding"
+    embedding: List[float]
+
+
+class EmbeddingResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
+    object: str = "list"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    data: List[EmbeddingResponseData]
+    usage: UsageInfo
+
+
+class FunctionCall(OpenAIBaseModel):
+    name: str
+    arguments: str
+
+
+class ToolCall(OpenAIBaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
+    type: Literal["function"] = "function"
+    function: FunctionCall
+
+
+class ChatMessage(OpenAIBaseModel):
    role: str
    content: str
+    tool_calls: List[ToolCall] = Field(default_factory=list)
+
+
+class ChatCompletionLogProb(OpenAIBaseModel):
+    token: str
+    logprob: float = -9999.0
+    bytes: Optional[List[int]] = None
+
+
+class ChatCompletionLogProbsContent(ChatCompletionLogProb):
+    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)
+
+
+class ChatCompletionLogProbs(OpenAIBaseModel):
+    content: Optional[List[ChatCompletionLogProbsContent]] = None
 
 
-class ChatCompletionResponseChoice(
+class ChatCompletionResponseChoice(OpenAIBaseModel):
    index: int
    message: ChatMessage
+    logprobs: Optional[ChatCompletionLogProbs] = None
+    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
+    stop_reason: Optional[Union[int, str]] = None
 
 
-class ChatCompletionResponse(
+class ChatCompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
-    object:
+    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
 
 
-class DeltaMessage(
+class DeltaMessage(OpenAIBaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
+    tool_calls: List[ToolCall] = Field(default_factory=list)
 
 
-class ChatCompletionResponseStreamChoice(
+class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    delta: DeltaMessage
+    logprobs: Optional[ChatCompletionLogProbs] = None
+    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
+    stop_reason: Optional[Union[int, str]] = None
 
 
-class ChatCompletionStreamResponse(
+class ChatCompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
-    object:
+    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
-    usage: Optional[UsageInfo] = Field(default=None)
+    usage: Optional[UsageInfo] = Field(default=None)
+
+
+class BatchRequestInput(OpenAIBaseModel):
+    """
+    The per-line object of the batch input file.
+
+    NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
+    """
+
+    # A developer-provided per-request id that will be used to match outputs to
+    # inputs. Must be unique for each request in a batch.
+    custom_id: str
+
+    # The HTTP method to be used for the request. Currently only POST is
+    # supported.
+    method: str
+
+    # The OpenAI API relative URL to be used for the request. Currently
+    # /v1/chat/completions is supported.
+    url: str
+
+    # The parameters of the request.
+    body: Union[ChatCompletionRequest, ]
+
+
+class BatchRequestOutput(OpenAIBaseModel):
+    """
+    The per-line object of the batch output and error files
+    """
+
+    id: str
+
+    # A developer-provided per-request id that will be used to match outputs to
+    # inputs.
+    custom_id: str
+
+    response: Optional[ChatCompletionResponse]
+
+    # For requests that failed with a non-HTTP error, this will contain more
+    # information on the cause of the failure.
+    error: Optional[Any]
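Both request classes now build an identical logit_bias logits processor inside to_sampling_params. A standalone sketch of what that closure does at sampling time; the vocabulary size and bias values are invented:

import torch

logit_bias = {"50256": -250.0, "11": 2.5}  # raw biases; spec range is [-100, 100]

def logit_bias_logits_processor(token_ids, logits):
    for token_id, bias in logit_bias.items():
        bias = min(100, max(-100, bias))   # clamp per OpenAI API spec
        logits[int(token_id)] += bias
    return logits

logits = torch.zeros(50257)
out = logit_bias_logits_processor([], logits)
print(out[50256].item(), out[11].item())  # -100.0 2.5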
serving_completion.py
CHANGED
@@ -1,290 +1,495 @@
+import codecs
 import time
+from dataclasses import dataclass
+from typing import (AsyncGenerator, AsyncIterator, Dict, Iterable, List,
+                    Optional)
+from typing import Sequence as GenericSequence
+from typing import TypedDict, Union, cast, final
+
 from fastapi import Request
-from
+from openai.types.chat import ChatCompletionContentPartTextParam
+
-from vllm.
+from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from protocol import (
-    UsageInfo
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionContentPartParam, ChatCompletionLogProb,
+    ChatCompletionLogProbs, ChatCompletionLogProbsContent,
+    ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam,
+    ChatCompletionRequest, ChatCompletionResponse,
+    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
+    FunctionCall, ToolCall, UsageInfo)
+from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
+                                                    OpenAIServing)
+from vllm.logger import init_logger
+from vllm.model_executor.guided_decoding import (
+    get_guided_decoding_logits_processor)
 from vllm.outputs import RequestOutput
-from
+from vllm.sequence import Logprob
+from vllm.utils import random_uuid
 
 logger = init_logger(__name__)
 
 
-                            finish_reason=finish_reason,
-                        )
-                    ]).model_dump_json(exclude_unset=True)
-                yield f"data: {response_json}\n\n"
-
-            if output.finish_reason is not None:
-                logprobs = LogProbs() if request.logprobs is not None else None
-                prompt_tokens = len(res.prompt_token_ids)
-                completion_tokens = len(output.token_ids)
-                final_usage = UsageInfo(
-                    prompt_tokens=prompt_tokens,
-                    completion_tokens=completion_tokens,
-                    total_tokens=prompt_tokens + completion_tokens,
-                )
-                response_json = CompletionStreamResponse(
-                    id=request_id,
-                    created=created_time,
-                    model=model_name,
-                    choices=[
-                        CompletionResponseStreamChoice(
-                            index=i,
-                            text="",
-                            logprobs=logprobs,
-                            finish_reason=output.finish_reason,
-                        )
-                    ],
-                    usage=final_usage,
-                ).model_dump_json(exclude_unset=True)
-                yield f"data: {response_json}\n\n"
-
-    yield "data: [DONE]\n\n"
-
-
-def parse_prompt_format(prompt) -> tuple[bool, list]:
-    # get the prompt, openai supports the following
-    # "a string, array of strings, array of tokens, or array of token arrays."
-    prompt_is_tokens = False
-    prompts = [prompt]  # case 1: a string
-    if isinstance(prompt, list):
-        if len(prompt) == 0:
-            raise ValueError("please provide at least one prompt")
-        elif isinstance(prompt[0], str):
-            prompt_is_tokens = False
-            prompts = prompt  # case 2: array of strings
-        elif isinstance(prompt[0], int):
-            prompt_is_tokens = True
-            prompts = [prompt]  # case 3: array of tokens
-        elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
-            prompt_is_tokens = True
-            prompts = prompt  # case 4: array of token arrays
    else:
 
-        for output in final_res.outputs:
-            if request.logprobs is not None:
-                if not echo_without_generation:
-                    token_ids = output.token_ids
-                    top_logprobs = output.logprobs
-                    if request.echo:
-                        token_ids = prompt_token_ids + token_ids
-                        top_logprobs = prompt_logprobs + top_logprobs
            else:
-                    top_logprobs = prompt_logprobs
-                logprobs = create_logprobs_fn(
-                    token_ids=token_ids,
-                    top_logprobs=top_logprobs,
-                    num_output_top_logprobs=request.logprobs,
-                )
-            else:
-                logprobs = None
-            if not echo_without_generation:
-                output_text = output.text
-                if request.echo:
-                    output_text = prompt_text + output_text
-            else:
-                output_text = prompt_text
-            choice_data = CompletionResponseChoice(
-                index=output.index,
-                text=output_text,
-                logprobs=logprobs,
-                finish_reason=output.finish_reason,
-            )
-            choices.append(choice_data)
 
-        num_generated_tokens = sum(
-            len(output.token_ids) for output in final_res.outputs)
-        usage = UsageInfo(
-            prompt_tokens=num_prompt_tokens,
-            completion_tokens=num_generated_tokens,
-            total_tokens=num_prompt_tokens + num_generated_tokens,
-        )
 
-            id=request_id,
-            created=created_time,
-            model=model_name,
-            choices=choices,
-            usage=usage,
-        )
 
-        super().__init__(engine=engine, served_model=served_model)
 
-    async def
        """Completion API similar to OpenAI's API.
 
-        See https://platform.openai.com/docs/api-reference/
-        for the API specification. This API mimics the OpenAI
 
-        NOTE: Currently we do not support the following
-        -
-            suffix)
-        - logit_bias (to be supported by vLLM engine)
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret
 
-        # OpenAI API supports echoing the prompt when max_tokens is 0.
-        echo_without_generation = request.echo and request.max_tokens == 0
-
-        # Return error for unsupported features.
-        if request.suffix is not None:
-            return self.create_error_response(
-                "suffix is not currently supported")
-        if request.logit_bias is not None and len(request.logit_bias) > 0:
-            return self.create_error_response(
-                "logit_bias is not currently supported")
-
-        model_name = request.model
-        request_id = f"cmpl-{random_uuid()}"
-        created_time = int(time.monotonic())
-
-        # Schedule the request and get the result generator.
        try:
-                raise ValueError(
-                    "Batching in completion API is not supported.")
-            prompt = prompts[0]
 
        except ValueError as e:
            return self.create_error_response(str(e))
 
        # Streaming response
-        if stream:
-            return
 
        async for res in result_generator:
-            if await raw_request.is_disconnected():
                # Abort the request if the client disconnects.
                await self.engine.abort(request_id)
                return self.create_error_response("Client disconnected")
            final_res = res
-            final_res, request, echo_without_generation, self._create_logprobs,
-            request_id, created_time, model_name)
 
-        return response
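The removed code above streamed completions by hand-building server-sent events: one "data: <json>" frame per chunk, closed by "data: [DONE]". A minimal sketch of that framing (the chunk payloads are invented); the rewritten implementation, added below, keeps the same wire format:

import json

def sse_events(chunks):
    for chunk in chunks:
        yield f"data: {json.dumps(chunk)}\n\n"
    yield "data: [DONE]\n\n"

for event in sse_events([{"choices": [{"index": 0, "text": "Hel"}]},
                         {"choices": [{"index": 0, "text": "lo"}]}]):
    print(event, end="")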
+import codecs
 import time
+from dataclasses import dataclass
+from typing import (AsyncGenerator, AsyncIterator, Dict, Iterable, List,
+                    Optional)
+from typing import Sequence as GenericSequence
+from typing import TypedDict, Union, cast, final
+
 from fastapi import Request
+from openai.types.chat import ChatCompletionContentPartTextParam
+
+from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionContentPartParam, ChatCompletionLogProb,
+    ChatCompletionLogProbs, ChatCompletionLogProbsContent,
+    ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam,
+    ChatCompletionRequest, ChatCompletionResponse,
+    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
+    FunctionCall, ToolCall, UsageInfo)
+from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
+                                                    OpenAIServing)
+from vllm.logger import init_logger
+from vllm.model_executor.guided_decoding import (
+    get_guided_decoding_logits_processor)
 from vllm.outputs import RequestOutput
+from vllm.sequence import Logprob
+from vllm.utils import random_uuid
 
 logger = init_logger(__name__)
 
 
+@final  # So that it should be compatible with Dict[str, str]
+class ConversationMessage(TypedDict):
+    role: str
+    content: str
+
+
+@dataclass(frozen=True)
+class ChatMessageParseResult:
+    messages: List[ConversationMessage]
+
+
+class OpenAIServingChat(OpenAIServing):
+
+    def __init__(self,
+                 engine: AsyncLLMEngine,
+                 model_config: ModelConfig,
+                 served_model_names: List[str],
+                 response_role: str,
+                 lora_modules: Optional[List[LoRAModulePath]] = None,
+                 chat_template: Optional[str] = None):
+        super().__init__(engine=engine,
+                         model_config=model_config,
+                         served_model_names=served_model_names,
+                         lora_modules=lora_modules)
+
+        self.response_role = response_role
+        self._load_chat_template(chat_template)
+
+    def _load_chat_template(self, chat_template: Optional[str]):
+        tokenizer = self.tokenizer
+
+        if chat_template is not None:
+            try:
+                with open(chat_template, "r") as f:
+                    tokenizer.chat_template = f.read()
+            except OSError as e:
+                JINJA_CHARS = "{}\n"
+                if not any(c in chat_template for c in JINJA_CHARS):
+                    msg = (f"The supplied chat template ({chat_template}) "
+                           f"looks like a file path, but it failed to be "
+                           f"opened. Reason: {e}")
+                    raise ValueError(msg) from e
+
+                # If opening a file fails, set chat template to be args to
+                # ensure we decode so our escape are interpreted correctly
+                tokenizer.chat_template = codecs.decode(
+                    chat_template, "unicode_escape")
+
+            logger.info("Using supplied chat template:\n%s",
+                        tokenizer.chat_template)
+        elif tokenizer.chat_template is not None:
+            logger.info("Using default chat template:\n%s",
+                        tokenizer.chat_template)
         else:
+            logger.warning(
+                "No chat template provided. Chat API will not work.")
+
+    def _parse_chat_message_content_parts(
+        self,
+        role: str,
+        parts: Iterable[ChatCompletionContentPartParam],
+    ) -> ChatMessageParseResult:
+        texts: List[str] = []
+
+        for _, part in enumerate(parts):
+            part_type = part["type"]
+            if part_type == "text":
+                text = cast(ChatCompletionContentPartTextParam, part)["text"]
+
+                texts.append(text)
             else:
+                raise NotImplementedError(f"Unknown part type: {part_type}")
 
+        messages = [ConversationMessage(role=role, content="\n".join(texts))]
 
+        return ChatMessageParseResult(messages=messages)
 
+    def _parse_chat_message_content(
+        self,
+        message: ChatCompletionMessageParam,
+    ) -> ChatMessageParseResult:
+        role = message["role"]
+        content = message.get("content")
 
+        if content is None:
+            return ChatMessageParseResult(messages=[])
+        if isinstance(content, str):
+            messages = [ConversationMessage(role=role, content=content)]
+            return ChatMessageParseResult(messages=messages)
 
+        return self._parse_chat_message_content_parts(role, content)
 
+    async def create_chat_completion(
+        self,
+        request: ChatCompletionRequest,
+        raw_request: Optional[Request] = None
+    ) -> Union[ErrorResponse, AsyncGenerator[str, None],
+               ChatCompletionResponse]:
         """Completion API similar to OpenAI's API.
 
+        See https://platform.openai.com/docs/api-reference/chat/create
+        for the API specification. This API mimics the OpenAI
+        ChatCompletion API.
 
+        NOTE: Currently we do not support the following feature:
+            - function_call (Users should implement this by themselves)
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret
 
         try:
+            conversation: List[ConversationMessage] = []
 
+            for msg in request.messages:
+                parsed_msg = self._parse_chat_message_content(msg)
 
+                conversation.extend(parsed_msg.messages)
 
+            prompt = self.tokenizer.apply_chat_template(
+                conversation=conversation,
+                tokenize=False,
+                add_generation_prompt=request.add_generation_prompt,
+            )
+        except Exception as e:
+            logger.error("Error in applying chat template from request: %s", e)
+            return self.create_error_response(str(e))
 
+        request_id = f"cmpl-{random_uuid()}"
+        try:
+            # Tokenize/detokenize depending on prompt format (string/token list)
+            prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
+                request,
+                prompt=prompt,
+                add_special_tokens=request.add_special_tokens)
+            sampling_params = request.to_sampling_params()
+            lora_request = self._maybe_get_lora(request)
+            decoding_config = await self.engine.get_decoding_config()
+            guided_decoding_backend = request.guided_decoding_backend \
+                or decoding_config.guided_decoding_backend
+            guided_decode_logits_processor = (
+                await get_guided_decoding_logits_processor(
+                    guided_decoding_backend, request, await
+                    self.engine.get_tokenizer()))
+            if guided_decode_logits_processor:
+                if sampling_params.logits_processors is None:
+                    sampling_params.logits_processors = []
+                sampling_params.logits_processors.append(
+                    guided_decode_logits_processor)
         except ValueError as e:
             return self.create_error_response(str(e))
 
+        result_generator = self.engine.generate(
+            {
+                "prompt": prompt_text,
+                "prompt_token_ids": prompt_ids
+            },
+            sampling_params,
+            request_id,
+            lora_request,
+        )
         # Streaming response
+        if request.stream:
+            return self.chat_completion_stream_generator(
+                request, result_generator, request_id, conversation)
+        else:
+            try:
+                return await self.chat_completion_full_generator(
+                    request, raw_request, result_generator, request_id,
+                    conversation)
+            except ValueError as e:
+                # TODO: Use a vllm-specific Validation Error
+                return self.create_error_response(str(e))
+
+    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
+        if request.add_generation_prompt:
+            return self.response_role
+        else:
+            return request.messages[-1]["role"]
+
+    async def chat_completion_stream_generator(
+            self, request: ChatCompletionRequest,
+            result_generator: AsyncIterator[RequestOutput], request_id: str,
+            conversation: List[ConversationMessage]
+    ) -> AsyncGenerator[str, None]:
+        model_name = self.served_model_names[0]
+        created_time = int(time.time())
+        chunk_object_type = "chat.completion.chunk"
+        first_iteration = True
+
+        # Send response for each token for each request.n (index)
+        assert request.n is not None
+        previous_texts = [""] * request.n
+        previous_num_tokens = [0] * request.n
+        finish_reason_sent = [False] * request.n
+        try:
+            async for res in result_generator:
+                # We need to do it here, because if there are exceptions in
+                # the result_generator, it needs to be sent as the FIRST
+                # response (by the try...catch).
+                if first_iteration:
+                    # Send first response for each request.n (index) with
+                    # the role
+                    role = self.get_chat_request_role(request)
+                    for i in range(request.n):
+                        choice_data = ChatCompletionResponseStreamChoice(
+                            index=i,
+                            delta=DeltaMessage(role=role),
+                            logprobs=None,
+                            finish_reason=None)
+                        chunk = ChatCompletionStreamResponse(
+                            id=request_id,
+                            object=chunk_object_type,
+                            created=created_time,
+                            choices=[choice_data],
+                            model=model_name)
+                        data = chunk.model_dump_json(exclude_unset=True)
+                        yield f"data: {data}\n\n"
+
+                    # Send response to echo the input portion of the
+                    # last message
+                    if request.echo:
+                        last_msg_content = ""
+                        if conversation and conversation[-1].get(
+                                "content") and conversation[-1].get(
+                                    "role") == role:
+                            last_msg_content = conversation[-1]["content"]
+
+                        if last_msg_content:
+                            for i in range(request.n):
+                                choice_data = (
+                                    ChatCompletionResponseStreamChoice(
+                                        index=i,
+                                        delta=DeltaMessage(
+                                            content=last_msg_content),
+                                        finish_reason=None))
+                                chunk = ChatCompletionStreamResponse(
+                                    id=request_id,
+                                    object=chunk_object_type,
+                                    created=created_time,
+                                    choices=[choice_data],
+                                    logprobs=None,
+                                    model=model_name)
+                                data = chunk.model_dump_json(
+                                    exclude_unset=True)
+                                yield f"data: {data}\n\n"
+                    first_iteration = False
+
+                for output in res.outputs:
+                    i = output.index
+
+                    if finish_reason_sent[i]:
+                        continue
+
+                    delta_token_ids = output.token_ids[previous_num_tokens[i]:]
+                    top_logprobs = output.logprobs[
+                        previous_num_tokens[i]:] if output.logprobs else None
+
+                    if request.logprobs:
+                        logprobs = self._create_chat_logprobs(
+                            token_ids=delta_token_ids,
+                            top_logprobs=top_logprobs,
+                            num_output_top_logprobs=request.top_logprobs,
+                        )
+                    else:
+                        logprobs = None
+
+                    delta_text = output.text[len(previous_texts[i]):]
+                    previous_texts[i] = output.text
+                    previous_num_tokens[i] = len(output.token_ids)
+
+                    if request.tool_choice and type(
+                            request.tool_choice
+                    ) is ChatCompletionNamedToolChoiceParam:
+                        delta_message = DeltaMessage(tool_calls=[
+                            ToolCall(function=FunctionCall(
+                                name=request.tool_choice.function.name,
+                                arguments=delta_text))
+                        ])
+                    else:
+                        delta_message = DeltaMessage(content=delta_text)
+
+                    if output.finish_reason is None:
+                        # Send token-by-token response for each request.n
+
+                        choice_data = ChatCompletionResponseStreamChoice(
+                            index=i,
+                            delta=delta_message,
+                            logprobs=logprobs,
+                            finish_reason=None)
+                        chunk = ChatCompletionStreamResponse(
+                            id=request_id,
+                            object=chunk_object_type,
+                            created=created_time,
+                            choices=[choice_data],
+                            model=model_name)
+                        data = chunk.model_dump_json(exclude_unset=True)
+                        yield f"data: {data}\n\n"
+                    else:
+                        # Send the finish response for each request.n only once
+                        prompt_tokens = len(res.prompt_token_ids)
+                        final_usage = UsageInfo(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=previous_num_tokens[i],
+                            total_tokens=prompt_tokens +
+                            previous_num_tokens[i],
+                        )
+                        choice_data = ChatCompletionResponseStreamChoice(
+                            index=i,
+                            delta=delta_message,
+                            logprobs=logprobs,
+                            finish_reason=output.finish_reason,
+                            stop_reason=output.stop_reason)
+                        chunk = ChatCompletionStreamResponse(
+                            id=request_id,
+                            object=chunk_object_type,
+                            created=created_time,
+                            choices=[choice_data],
+                            model=model_name)
+                        if final_usage is not None:
+                            chunk.usage = final_usage
+                        data = chunk.model_dump_json(exclude_unset=True,
+                                                     exclude_none=True)
+                        yield f"data: {data}\n\n"
+                        finish_reason_sent[i] = True
+        except ValueError as e:
+            # TODO: Use a vllm-specific Validation Error
+            data = self.create_streaming_error_response(str(e))
+            yield f"data: {data}\n\n"
+        # Send the final done message after all response.n are finished
+        yield "data: [DONE]\n\n"
+
+    async def chat_completion_full_generator(
+            self, request: ChatCompletionRequest,
+            raw_request: Optional[Request],
+            result_generator: AsyncIterator[RequestOutput], request_id: str,
+            conversation: List[ConversationMessage]
+    ) -> Union[ErrorResponse, ChatCompletionResponse]:
+
+        model_name = self.served_model_names[0]
+        created_time = int(time.time())
+        final_res: Optional[RequestOutput] = None
+
         async for res in result_generator:
+            if raw_request is not None and await raw_request.is_disconnected():
                 # Abort the request if the client disconnects.
                 await self.engine.abort(request_id)
                 return self.create_error_response("Client disconnected")
             final_res = res
+        assert final_res is not None
 
+        choices = []
+
+        role = self.get_chat_request_role(request)
+        for output in final_res.outputs:
+            token_ids = output.token_ids
+            top_logprobs = output.logprobs
 
+            if request.logprobs:
+                logprobs = self._create_chat_logprobs(
+                    token_ids=token_ids,
+                    top_logprobs=top_logprobs,
+                    num_output_top_logprobs=request.top_logprobs,
+                )
+            else:
+                logprobs = None
 
+            if request.tool_choice and type(
+                    request.tool_choice) is ChatCompletionNamedToolChoiceParam:
+                message = ChatMessage(
+                    role=role,
+                    content="",
+                    tool_calls=[
+                        ToolCall(function=FunctionCall(
+                            name=request.tool_choice.function.name,
+                            arguments=output.text))
+                    ])
+            elif not request.tool_choice or request.tool_choice == "none":
+                message = ChatMessage(role=role, content=output.text)
+
+            choice_data = ChatCompletionResponseChoice(
+                index=output.index,
+                message=message,
+                logprobs=logprobs,
+                finish_reason=output.finish_reason,
+                stop_reason=output.stop_reason)
+            choices.append(choice_data)
+
+        if request.echo:
+            last_msg_content = ""
+            if conversation and conversation[-1].get(
+                    "content") and conversation[-1].get("role") == role:
+                last_msg_content = conversation[-1]["content"]
+
+            for choice in choices:
+                full_message = last_msg_content + choice.message.content
+                choice.message.content = full_message
+
+        num_prompt_tokens = len(final_res.prompt_token_ids)
+        num_generated_tokens = sum(
+            len(output.token_ids) for output in final_res.outputs)
+        usage = UsageInfo(
+            prompt_tokens=num_prompt_tokens,
+            completion_tokens=num_generated_tokens,
+            total_tokens=num_prompt_tokens + num_generated_tokens,
+        )
+        response = ChatCompletionResponse(
+            id=request_id,
+            created=created_time,
+            model=model_name,
+            choices=choices,
+            usage=usage,
+        )
 
+        return response
+
+    def _get_top_logprobs(
+            self, logprobs: Dict[int, Logprob],
+            top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
+        return [
+            ChatCompletionLogProb(
+                token=self._get_decoded_token(p[1], p[0]),
+                logprob=max(p[1].logprob, -9999.0),
+                bytes=list(
+                    self._get_decoded_token(p[1],
+                                            p[0]).encode("utf-8",
+                                                         errors="replace")))
+            for i, p in enumerate(logprobs.items())
+            if top_logprobs and i < top_logprobs
+        ]
+
+    def _create_chat_logprobs(
+        self,
+        token_ids: GenericSequence[int],
+        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
+        num_output_top_logprobs: Optional[int] = None,
+    ) -> ChatCompletionLogProbs:
+        """Create OpenAI-style logprobs."""
+
+        logprobs_content = []
+
+        for i, token_id in enumerate(token_ids):
+            step_top_logprobs = top_logprobs[i]
+            if step_top_logprobs is None:
+                logprobs_content.append(
+                    ChatCompletionLogProbsContent(
+                        token=self.tokenizer.decode(token_id),
+                        bytes=list(
+                            self.tokenizer.decode(token_id).encode(
+                                "utf-8", errors="replace"))))
+            else:
+                logprobs_content.append(
+                    ChatCompletionLogProbsContent(
+                        token=step_top_logprobs[token_id].decoded_token,
+                        logprob=max(step_top_logprobs[token_id].logprob,
+                                    -9999.0),
+                        bytes=list(
+                            step_top_logprobs[token_id].decoded_token.encode(
+                                "utf-8", errors="replace")),
+                        top_logprobs=self._get_top_logprobs(
+                            step_top_logprobs, num_output_top_logprobs)))
+
+        return ChatCompletionLogProbs(content=logprobs_content)
serving_engine.py
CHANGED
@@ -1,92 +1,81 @@
-import asyncio
 from http import HTTPStatus
-from typing import Dict, List, Optional, Union
-
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from protocol import (
                                               ModelCard, ModelList,
                                               ModelPermission)
 
 logger = init_logger(__name__)
 
 
-class OpenAIServing:
 
-    def __init__(self, engine: AsyncLLMEngine, served_model: str):
-        self.engine = engine
-        self.served_model = served_model
-
-        self.max_model_len = 0
-        self.tokenizer = None
 
-        try:
-            event_loop = asyncio.get_running_loop()
-        except RuntimeError:
-            event_loop = None
 
-        if event_loop is not None and event_loop.is_running():
-            event_loop.create_task(self._post_init())
-        else:
-            asyncio.run(self._post_init())
 
-    async def _post_init(self):
-        engine_model_config = await self.engine.get_model_config()
-        self.max_model_len = engine_model_config.max_model_len
 
         # A separate tokenizer to map token IDs to strings.
         self.tokenizer = get_tokenizer(
-            engine_model_config.tokenizer,
-            tokenizer_mode=engine_model_config.tokenizer_mode,
-            trust_remote_code=engine_model_config.trust_remote_code)
 
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
-            ModelCard(id=self.served_model,
-                      root=self.served_model,
                       permission=[ModelPermission()])
         ]
         return ModelList(data=model_cards)
 
-    def _create_logprobs(
-        self,
-        token_ids: List[int],
-        top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
-        num_output_top_logprobs: Optional[int] = None,
-        initial_text_offset: int = 0,
-    ) -> LogProbs:
-        """Create OpenAI-style logprobs."""
-        logprobs = LogProbs()
-        last_token_len = 0
-        if num_output_top_logprobs:
-            logprobs.top_logprobs = []
-        for i, token_id in enumerate(token_ids):
-            step_top_logprobs = top_logprobs[i]
-            if step_top_logprobs is not None:
-                token_logprob = step_top_logprobs[token_id]
-            else:
-                token_logprob = None
-            token = self.tokenizer.convert_ids_to_tokens(token_id)
-            logprobs.tokens.append(token)
-            logprobs.token_logprobs.append(token_logprob)
-            if len(logprobs.text_offset) == 0:
-                logprobs.text_offset.append(initial_text_offset)
-            else:
-                logprobs.text_offset.append(logprobs.text_offset[-1] +
-                                            last_token_len)
-            last_token_len = len(token)
-
-            if num_output_top_logprobs:
-                logprobs.top_logprobs.append({
-                    self.tokenizer.convert_ids_to_tokens(i): p
-                    for i, p in step_top_logprobs.items()
-                } if step_top_logprobs else None)
-        return logprobs
-
     def create_error_response(
             self,
             message: str,
@@ -96,38 +85,116 @@ class OpenAIServing:
                              type=err_type,
                              code=status_code.value)
 
         return self.create_error_response(
             message=f"The model `{request.model}` does not exist.",
             err_type="NotFoundError",
             status_code=HTTPStatus.NOT_FOUND)
 
     def _validate_prompt_and_tokenize(
             self,
-            request: Union[ChatCompletionRequest, CompletionRequest],
             prompt: Optional[str] = None,
-            prompt_ids: Optional[List[int]] = None) -> List[int]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")
         if (prompt and prompt_ids):
             raise ValueError(
                 "Only one of prompt or prompt_ids should be provided.")
 
-        input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
-            prompt).input_ids
         token_num = len(input_ids)
 
         if request.max_tokens is None:
             request.max_tokens = self.max_model_len - token_num
 
         if token_num + request.max_tokens > self.max_model_len:
             raise ValueError(
-                f"This model's maximum context length is {self.max_model_len} tokens. "
-                f"However, you requested {request.max_tokens + token_num} tokens "
                 f"({token_num} in the messages, "
                 f"{request.max_tokens} in the completion). "
                 f"Please reduce the length of the messages or completion.", )
         else:
-            return input_ids

+import json
+from dataclasses import dataclass
 from http import HTTPStatus
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from pydantic import Field
+from typing_extensions import Annotated
+
+from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              CompletionRequest,
+                                              EmbeddingRequest, ErrorResponse,
                                               ModelCard, ModelList,
                                               ModelPermission)
+from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
+from vllm.sequence import Logprob
+from vllm.transformers_utils.tokenizer import get_tokenizer
 
 logger = init_logger(__name__)
 
 
+@dataclass
+class LoRAModulePath:
+    name: str
+    local_path: str
 
+class OpenAIServing:
 
+    def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
+                 served_model_names: List[str],
+                 lora_modules: Optional[List[LoRAModulePath]]):
+        super().__init__()
 
+        self.engine = engine
+        self.max_model_len = model_config.max_model_len
 
         # A separate tokenizer to map token IDs to strings.
         self.tokenizer = get_tokenizer(
+            model_config.tokenizer,
+            tokenizer_mode=model_config.tokenizer_mode,
+            tokenizer_revision=model_config.tokenizer_revision,
+            trust_remote_code=model_config.trust_remote_code,
+            truncation_side="left")
+
+        self.served_model_names = served_model_names
+
+        if lora_modules is None:
+            self.lora_requests = []
+        else:
+            self.lora_requests = [
+                LoRARequest(
+                    lora_name=lora.name,
+                    lora_int_id=i,
+                    lora_local_path=lora.local_path,
+                ) for i, lora in enumerate(lora_modules, start=1)
+            ]
 
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
+            ModelCard(id=served_model_name,
+                      max_model_len=self.max_model_len,
+                      root=self.served_model_names[0],
+                      permission=[ModelPermission()])
+            for served_model_name in self.served_model_names
+        ]
+        lora_cards = [
+            ModelCard(id=lora.lora_name,
+                      root=self.served_model_names[0],
                       permission=[ModelPermission()])
+            for lora in self.lora_requests
         ]
+        model_cards.extend(lora_cards)
         return ModelList(data=model_cards)
 
     def create_error_response(
             self,
             message: str,
             err_type: str = "BadRequestError",
             status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
         return ErrorResponse(message=message,
                              type=err_type,
                              code=status_code.value)
 
+    def create_streaming_error_response(
+            self,
+            message: str,
+            err_type: str = "BadRequestError",
+            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
+        json_str = json.dumps({
+            "error":
+            self.create_error_response(message=message,
+                                       err_type=err_type,
+                                       status_code=status_code).model_dump()
+        })
+        return json_str
+
+    async def _check_model(
+        self, request: Union[CompletionRequest, ChatCompletionRequest,
+                             EmbeddingRequest]
+    ) -> Optional[ErrorResponse]:
+        if request.model in self.served_model_names:
+            return None
+        if request.model in [lora.lora_name for lora in self.lora_requests]:
+            return None
         return self.create_error_response(
             message=f"The model `{request.model}` does not exist.",
             err_type="NotFoundError",
             status_code=HTTPStatus.NOT_FOUND)
 
+    def _maybe_get_lora(
+        self, request: Union[CompletionRequest, ChatCompletionRequest,
+                             EmbeddingRequest]
+    ) -> Optional[LoRARequest]:
+        if request.model in self.served_model_names:
+            return None
+        for lora in self.lora_requests:
+            if request.model == lora.lora_name:
+                return lora
+        # if _check_model has been called earlier, this will be unreachable
+        raise ValueError(f"The model `{request.model}` does not exist.")
+
     def _validate_prompt_and_tokenize(
         self,
+        request: Union[ChatCompletionRequest, CompletionRequest,
+                       EmbeddingRequest],
         prompt: Optional[str] = None,
+        prompt_ids: Optional[List[int]] = None,
+        truncate_prompt_tokens: Optional[Annotated[int,
+                                                   Field(ge=1)]] = None,
+        add_special_tokens: Optional[bool] = True
+    ) -> Tuple[List[int], str]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")
         if (prompt and prompt_ids):
             raise ValueError(
                 "Only one of prompt or prompt_ids should be provided.")
 
+        if prompt_ids is None:
+            # When using OpenAIServingChat for chat completions, for
+            # most models the special tokens (e.g., BOS) have already
+            # been added by the chat template. Therefore, we do not
+            # need to add them again.
+            # Set add_special_tokens to False (by default) to avoid
+            # adding the BOS tokens again.
+            tokenizer_kwargs: Dict[str, Any] = {
+                "add_special_tokens": add_special_tokens
+            }
+            if truncate_prompt_tokens is not None:
+                tokenizer_kwargs.update({
+                    "truncation": True,
+                    "max_length": truncate_prompt_tokens,
+                })
+            input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
+        elif truncate_prompt_tokens is not None:
+            input_ids = prompt_ids[-truncate_prompt_tokens:]
+        else:
+            input_ids = prompt_ids
+
+        input_text = prompt if prompt is not None else self.tokenizer.decode(
+            prompt_ids)
         token_num = len(input_ids)
 
+        # Note: EmbeddingRequest doesn't have max_tokens
+        if isinstance(request, EmbeddingRequest):
+            if token_num > self.max_model_len:
+                raise ValueError(
+                    f"This model's maximum context length is "
+                    f"{self.max_model_len} tokens. However, you requested "
+                    f"{token_num} tokens in the input for embedding "
+                    f"generation. Please reduce the length of the input.", )
+            return input_ids, input_text
+
         if request.max_tokens is None:
+            if token_num >= self.max_model_len:
+                raise ValueError(
+                    f"This model's maximum context length is "
+                    f"{self.max_model_len} tokens. However, you requested "
+                    f"{token_num} tokens in the messages, "
+                    f"Please reduce the length of the messages.", )
             request.max_tokens = self.max_model_len - token_num
 
         if token_num + request.max_tokens > self.max_model_len:
             raise ValueError(
+                f"This model's maximum context length is "
+                f"{self.max_model_len} tokens. However, you requested "
+                f"{request.max_tokens + token_num} tokens "
                 f"({token_num} in the messages, "
                 f"{request.max_tokens} in the completion). "
                 f"Please reduce the length of the messages or completion.", )
         else:
+            return input_ids, input_text
+
+    def _get_decoded_token(self, logprob: Logprob, token_id: int) -> str:
+        if logprob.decoded_token is not None:
+            return logprob.decoded_token
+        return self.tokenizer.decode(token_id)