Change to vLLM 0.5.3 to support Llama 3.1
- Dockerfile +1 -1
- api_server.py +109 -43
- protocol.py +61 -85
- serving_chat.py +119 -220
- serving_completion.py +90 -98
- serving_embedding.py +54 -25
- serving_engine.py +259 -74
- serving_tokenization.py +135 -0
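The new serving_tokenization.py wires /tokenize and /detokenize routes into api_server.py (see the diffs below). A minimal client sketch over HTTP, assuming a server already running on localhost:8000 with a Llama 3.1 checkpoint; the host, port, model name, and the DetokenizeRequest field name are assumptions for illustration, not part of this commit.

import requests

BASE = "http://localhost:8000"
MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # assumed model name

# TokenizeCompletionRequest shape: model + prompt (+ optional add_special_tokens).
resp = requests.post(f"{BASE}/tokenize",
                     json={"model": MODEL, "prompt": "Hello, Llama 3.1!"})
resp.raise_for_status()
body = resp.json()          # TokenizeResponse carries count, max_model_len, tokens
tokens = body["tokens"]

# Round-trip through /detokenize (assumed to accept the ids as "tokens").
resp = requests.post(f"{BASE}/detokenize",
                     json={"model": MODEL, "tokens": tokens})
resp.raise_for_status()
print(resp.json())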
Dockerfile
CHANGED
@@ -14,7 +14,7 @@ RUN pip3 install "torch==2.1.1"
 # This build is slow but NVIDIA does not provide binaries. Increase MAX_JOBS as needed.
 # RUN pip3 install "git+https://github.com/stanford-futuredata/megablocks.git"
 RUN pip3 install -U openai
-RUN pip3 install vllm==0.5.
+RUN pip3 install vllm==0.5.3
 RUN pip3 install -U pydantic
 RUN pip3 install -U aioprometheus
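Because the image installs the openai client alongside vllm==0.5.3, a quick smoke test of the upgrade is an OpenAI-compatible chat call. This is only a sketch: the base URL, API key, and served model name are assumptions, not part of the Dockerfile.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
out = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",  # assumed served model
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    max_tokens=32,
)
print(out.choices[0].message.content)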
api_server.py
CHANGED
@@ -8,7 +8,7 @@ from typing import Optional, Set

 import fastapi
 import uvicorn
-from fastapi import Request
+from fastapi import APIRouter, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -18,6 +18,7 @@ from starlette.routing import Mount
 import vllm.envs as envs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import make_arg_parser
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -33,15 +34,21 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_tokenization import (
+    OpenAIServingTokenization)
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
+from vllm.utils import FlexibleArgumentParser
 from vllm.version import __version__ as VLLM_VERSION

 TIMEOUT_KEEP_ALIVE = 5  # seconds

+engine: AsyncLLMEngine
+engine_args: AsyncEngineArgs
 openai_serving_chat: OpenAIServingChat
 openai_serving_completion: OpenAIServingCompletion
 openai_serving_embedding: OpenAIServingEmbedding
+openai_serving_tokenization: OpenAIServingTokenization

 logger = init_logger('vllm.entrypoints.openai.api_server')

@@ -64,37 +71,27 @@ async def lifespan(app: fastapi.FastAPI):
     yield


-
+router = APIRouter()


-def
-
-
-
-    route = Mount("/metrics", make_asgi_app())
-    # Workaround for 307 Redirect for /metrics
-    route.path_regex = re.compile('^/metrics(?P<path>.*)$')
-    app.routes.append(route)
-
-
-@app.exception_handler(RequestValidationError)
-async def validation_exception_handler(_, exc):
-    err = openai_serving_chat.create_error_response(message=str(exc))
-    return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
-
-
-@app.get("/health")
+def mount_metrics(app: fastapi.FastAPI):
+    # Add prometheus asgi middleware to route /metrics requests
+    metrics_route = Mount("/metrics", make_asgi_app())
+    # Workaround for 307 Redirect for /metrics
+    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
+    app.routes.append(metrics_route)
+
+
+@router.get("/health")
 async def health() -> Response:
     """Health check."""
     await openai_serving_chat.engine.check_health()
     return Response(status_code=200)


-@
+@router.post("/tokenize")
 async def tokenize(request: TokenizeRequest):
-    generator = await
+    generator = await openai_serving_tokenization.create_tokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -103,9 +100,9 @@ async def tokenize(request: TokenizeRequest):
     return JSONResponse(content=generator.model_dump())


-@
+@router.post("/detokenize")
 async def detokenize(request: DetokenizeRequest):
-    generator = await
+    generator = await openai_serving_tokenization.create_detokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -114,19 +111,19 @@ async def detokenize(request: DetokenizeRequest):
     return JSONResponse(content=generator.model_dump())


-@
+@router.get("/v1/models")
 async def show_available_models():
-    models = await
+    models = await openai_serving_completion.show_available_models()
     return JSONResponse(content=models.model_dump())


-@
+@router.get("/version")
 async def show_version():
     ver = {"version": VLLM_VERSION}
     return JSONResponse(content=ver)


-@
+@router.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
     generator = await openai_serving_chat.create_chat_completion(
@@ -142,7 +139,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
     return JSONResponse(content=generator.model_dump())


-@
+@router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
     generator = await openai_serving_completion.create_completion(
         request, raw_request)
@@ -156,7 +153,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     return JSONResponse(content=generator.model_dump())


-@
+@router.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     generator = await openai_serving_embedding.create_embedding(
         request, raw_request)
@@ -167,8 +164,12 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     return JSONResponse(content=generator.model_dump())


-
-
+def build_app(args):
+    app = fastapi.FastAPI(lifespan=lifespan)
+    app.include_router(router)
+    app.root_path = args.root_path
+
+    mount_metrics(app)

     app.add_middleware(
         CORSMiddleware,
@@ -178,6 +179,12 @@ if __name__ == "__main__":
         allow_headers=args.allowed_headers,
     )

+    @app.exception_handler(RequestValidationError)
+    async def validation_exception_handler(_, exc):
+        err = openai_serving_chat.create_error_response(message=str(exc))
+        return JSONResponse(err.model_dump(),
+                            status_code=HTTPStatus.BAD_REQUEST)
+
     if token := envs.VLLM_API_KEY or args.api_key:

         @app.middleware("http")
@@ -203,6 +210,12 @@ if __name__ == "__main__":
             raise ValueError(f"Invalid middleware {middleware}. "
                              f"Must be a function or a class.")

+    return app
+
+
+def run_server(args, llm_engine=None):
+    app = build_app(args)
+
     logger.info("vLLM API server version %s", VLLM_VERSION)
     logger.info("args: %s", args)

@@ -211,10 +224,12 @@ if __name__ == "__main__":
     else:
         served_model_names = [args.model]

-
+    global engine, engine_args

-
-
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = (llm_engine
+              if llm_engine is not None else AsyncLLMEngine.from_engine_args(
+                  engine_args, usage_context=UsageContext.OPENAI_API_SERVER))

     event_loop: Optional[asyncio.AbstractEventLoop]
     try:
@@ -230,16 +245,57 @@ if __name__ == "__main__":
         # When using single vLLM without engine_use_ray
        model_config = asyncio.run(engine.get_model_config())

-
-
-
-
-
+    if args.disable_log_requests:
+        request_logger = None
+    else:
+        request_logger = RequestLogger(max_log_len=args.max_log_len)
+
+    global openai_serving_chat
+    global openai_serving_completion
+    global openai_serving_embedding
+    global openai_serving_tokenization
+
+    openai_serving_chat = OpenAIServingChat(
+        engine,
+        model_config,
+        served_model_names,
+        args.response_role,
+        lora_modules=args.lora_modules,
+        prompt_adapters=args.prompt_adapters,
+        request_logger=request_logger,
+        chat_template=args.chat_template,
+    )
     openai_serving_completion = OpenAIServingCompletion(
-        engine,
-
-
+        engine,
+        model_config,
+        served_model_names,
+        lora_modules=args.lora_modules,
+        prompt_adapters=args.prompt_adapters,
+        request_logger=request_logger,
+    )
+    openai_serving_embedding = OpenAIServingEmbedding(
+        engine,
+        model_config,
+        served_model_names,
+        request_logger=request_logger,
+    )
+    openai_serving_tokenization = OpenAIServingTokenization(
+        engine,
+        model_config,
+        served_model_names,
+        lora_modules=args.lora_modules,
+        request_logger=request_logger,
+        chat_template=args.chat_template,
+    )
     app.root_path = args.root_path
+
+    logger.info("Available routes are:")
+    for route in app.routes:
+        if not hasattr(route, 'methods'):
+            continue
+        methods = ', '.join(route.methods)
+        logger.info("Route: %s, Methods: %s", route.path, methods)
+
     uvicorn.run(app,
                 host=args.host,
                 port=args.port,
@@ -248,4 +304,14 @@ if __name__ == "__main__":
                 ssl_keyfile=args.ssl_keyfile,
                 ssl_certfile=args.ssl_certfile,
                 ssl_ca_certs=args.ssl_ca_certs,
                 ssl_cert_reqs=args.ssl_cert_reqs)
+
+
+if __name__ == "__main__":
+    # NOTE(simon):
+    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
+    parser = FlexibleArgumentParser(
+        description="vLLM OpenAI-Compatible RESTful API server.")
+    parser = make_arg_parser(parser)
+    args = parser.parse_args()
+    run_server(args)
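The refactor above splits app construction (build_app) from startup (run_server) and lets a pre-built engine be injected through the new llm_engine parameter. A sketch of driving the entrypoint from Python rather than the CLI; the model name is an assumption, everything else uses the parser defaults from make_arg_parser.

from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser

parser = make_arg_parser(FlexibleArgumentParser())
args = parser.parse_args(["--model", "meta-llama/Meta-Llama-3.1-8B-Instruct"])
run_server(args)  # blocks; pass llm_engine=... to reuse an existing AsyncLLMEngine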
protocol.py
CHANGED
@@ -3,50 +3,16 @@
 import time
 from typing import Any, Dict, List, Literal, Optional, Union

-import openai.types.chat
 import torch
 from pydantic import BaseModel, ConfigDict, Field, model_validator
-
-from typing_extensions import Annotated, Required, TypedDict
+from typing_extensions import Annotated

+from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid


-class CustomChatCompletionContentPartParam(TypedDict, total=False):
-    __pydantic_config__ = ConfigDict(extra="allow")  # type: ignore
-
-    type: Required[str]
-    """The type of the content part."""
-
-
-ChatCompletionContentPartParam = Union[
-    openai.types.chat.ChatCompletionContentPartParam,
-    CustomChatCompletionContentPartParam]
-
-
-class CustomChatCompletionMessageParam(TypedDict, total=False):
-    """Enables custom roles in the Chat Completion API."""
-    role: Required[str]
-    """The role of the message's author."""
-
-    content: Union[str, List[ChatCompletionContentPartParam]]
-    """The contents of the message."""
-
-    name: str
-    """An optional name for the participant.
-
-    Provides the model information to differentiate between participants of the
-    same role.
-    """
-
-
-ChatCompletionMessageParam = Union[
-    openai.types.chat.ChatCompletionMessageParam,
-    CustomChatCompletionMessageParam]
-
-
 class OpenAIBaseModel(BaseModel):
     # OpenAI API does not allow extra fields
     model_config = ConfigDict(extra="forbid")
@@ -155,40 +121,42 @@ class ChatCompletionRequest(OpenAIBaseModel):

     # doc: begin-chat-completion-sampling-params
     best_of: Optional[int] = None
-    use_beam_search:
-    top_k:
-    min_p:
-    repetition_penalty:
-    length_penalty:
-    early_stopping:
-    ignore_eos: Optional[bool] = False
-    min_tokens: Optional[int] = 0
+    use_beam_search: bool = False
+    top_k: int = -1
+    min_p: float = 0.0
+    repetition_penalty: float = 1.0
+    length_penalty: float = 1.0
+    early_stopping: bool = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
-
-
+    include_stop_str_in_output: bool = False
+    ignore_eos: bool = False
+    min_tokens: int = 0
+    skip_special_tokens: bool = True
+    spaces_between_special_tokens: bool = True
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     # doc: end-chat-completion-sampling-params

     # doc: begin-chat-completion-extra-params
-    echo:
+    echo: bool = Field(
         default=False,
         description=(
             "If true, the new message will be prepended with the last message "
             "if they belong to the same role."),
     )
-    add_generation_prompt:
+    add_generation_prompt: bool = Field(
         default=True,
         description=
         ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
    )
-    add_special_tokens:
+    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
-            "special tokens so this should be set to
+            "special tokens so this should be set to false (as is the "
            "default)."),
    )
    documents: Optional[List[Dict[str, str]]] = Field(
@@ -212,12 +180,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
         description=("Additional kwargs to pass to the template renderer. "
                      "Will be accessible by the chat template."),
     )
-    include_stop_str_in_output: Optional[bool] = Field(
-        default=False,
-        description=(
-            "Whether to include the stop string in the output. "
-            "This is only applied when the stop or stop_token_ids is set."),
-    )
     guided_json: Optional[Union[str, dict, BaseModel]] = Field(
         default=None,
         description=("If specified, the output will follow the JSON schema."),
@@ -278,22 +240,22 @@ class ChatCompletionRequest(OpenAIBaseModel):

         return SamplingParams(
             n=self.n,
+            best_of=self.best_of,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
             repetition_penalty=self.repetition_penalty,
             temperature=self.temperature,
             top_p=self.top_p,
+            top_k=self.top_k,
             min_p=self.min_p,
             seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
-            max_tokens=self.max_tokens,
-            min_tokens=self.min_tokens,
             logprobs=self.top_logprobs if self.logprobs else None,
             prompt_logprobs=self.top_logprobs if self.echo else None,
-            best_of=self.best_of,
-            top_k=self.top_k,
             ignore_eos=self.ignore_eos,
+            max_tokens=self.max_tokens,
+            min_tokens=self.min_tokens,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
             skip_special_tokens=self.skip_special_tokens,
@@ -301,6 +263,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
             include_stop_str_in_output=self.include_stop_str_in_output,
             length_penalty=self.length_penalty,
             logits_processors=logits_processors,
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
         )

     @model_validator(mode='before')
@@ -382,26 +345,27 @@ class CompletionRequest(OpenAIBaseModel):
     user: Optional[str] = None

     # doc: begin-completion-sampling-params
-    use_beam_search:
-    top_k:
-    min_p:
-    repetition_penalty:
-    length_penalty:
-    early_stopping:
+    use_beam_search: bool = False
+    top_k: int = -1
+    min_p: float = 0.0
+    repetition_penalty: float = 1.0
+    length_penalty: float = 1.0
+    early_stopping: bool = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
-
-
-
-
+    include_stop_str_in_output: bool = False
+    ignore_eos: bool = False
+    min_tokens: int = 0
+    skip_special_tokens: bool = True
+    spaces_between_special_tokens: bool = True
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     # doc: end-completion-sampling-params

     # doc: begin-completion-extra-params
-
-        default=
+    add_special_tokens: bool = Field(
+        default=True,
         description=(
-            "
-            "
+            "If true (the default), special tokens (e.g. BOS) will be added to "
+            "the prompt."),
     )
     response_format: Optional[ResponseFormat] = Field(
         default=None,
@@ -481,15 +445,15 @@ class CompletionRequest(OpenAIBaseModel):
             seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
+            logprobs=self.logprobs,
             ignore_eos=self.ignore_eos,
             max_tokens=self.max_tokens if not echo_without_generation else 1,
             min_tokens=self.min_tokens,
-            logprobs=self.logprobs,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
             prompt_logprobs=self.logprobs if self.echo else None,
             skip_special_tokens=self.skip_special_tokens,
-            spaces_between_special_tokens=
+            spaces_between_special_tokens=self.spaces_between_special_tokens,
             include_stop_str_in_output=self.include_stop_str_in_output,
             length_penalty=self.length_penalty,
             logits_processors=logits_processors,
@@ -523,11 +487,11 @@ class CompletionRequest(OpenAIBaseModel):
     def validate_stream_options(cls, data):
         if data.get("stream_options") and not data.get("stream"):
             raise ValueError(
-                "Stream options can only be defined when stream is
+                "Stream options can only be defined when stream is true.")
         return data


-class EmbeddingRequest(
+class EmbeddingRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/embeddings
     model: str
@@ -599,13 +563,13 @@ class CompletionStreamResponse(OpenAIBaseModel):
     usage: Optional[UsageInfo] = Field(default=None)


-class EmbeddingResponseData(
+class EmbeddingResponseData(OpenAIBaseModel):
     index: int
     object: str = "embedding"
     embedding: Union[List[float], str]


-class EmbeddingResponse(
+class EmbeddingResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
     object: str = "list"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -704,8 +668,8 @@ class BatchRequestInput(OpenAIBaseModel):
     # /v1/chat/completions is supported.
     url: str

-    # The
-    body:
+    # The parameters of the request.
+    body: ChatCompletionRequest


 class BatchResponseData(OpenAIBaseModel):
@@ -716,7 +680,7 @@ class BatchResponseData(OpenAIBaseModel):
     request_id: str

     # The body of the response.
-    body:
+    body: Optional[ChatCompletionResponse] = None


 class BatchRequestOutput(OpenAIBaseModel):
@@ -737,16 +701,28 @@ class BatchRequestOutput(OpenAIBaseModel):
     error: Optional[Any]


-class
+class TokenizeCompletionRequest(OpenAIBaseModel):
     model: str
     prompt: str
+
     add_special_tokens: bool = Field(default=True)


+class TokenizeChatRequest(OpenAIBaseModel):
+    model: str
+    messages: List[ChatCompletionMessageParam]
+
+    add_generation_prompt: bool = Field(default=True)
+    add_special_tokens: bool = Field(default=False)
+
+
+TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
+
+
 class TokenizeResponse(OpenAIBaseModel):
-    tokens: List[int]
     count: int
     max_model_len: int
+    tokens: List[int]


 class DetokenizeRequest(OpenAIBaseModel):
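TokenizeRequest is now a union of a completion-style and a chat-style payload with different add_special_tokens defaults. A small sketch of the two shapes as they validate under pydantic; the model name is a placeholder.

from vllm.entrypoints.openai.protocol import (TokenizeChatRequest,
                                              TokenizeCompletionRequest)

# Plain-prompt form: special tokens (e.g. BOS) are added by default.
completion_req = TokenizeCompletionRequest(model="my-model",
                                           prompt="Hello world")

# Chat form: the chat template supplies special tokens, so the default flips.
chat_req = TokenizeChatRequest(
    model="my-model",
    messages=[{"role": "user", "content": "Hello world"}],
)
assert completion_req.add_special_tokens is True
assert chat_req.add_special_tokens is False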
serving_chat.py
CHANGED
@@ -1,34 +1,33 @@
|
|
1 |
-
import codecs
|
2 |
import time
|
3 |
-
from
|
4 |
-
|
5 |
-
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable,
|
6 |
-
List, Optional)
|
7 |
from typing import Sequence as GenericSequence
|
8 |
-
from typing import
|
9 |
|
10 |
from fastapi import Request
|
11 |
-
from
|
12 |
-
ChatCompletionContentPartTextParam)
|
13 |
|
14 |
from vllm.config import ModelConfig
|
15 |
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
|
|
|
|
|
|
|
|
16 |
from vllm.entrypoints.openai.protocol import (
|
17 |
-
|
18 |
-
|
19 |
-
ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam,
|
20 |
ChatCompletionRequest, ChatCompletionResponse,
|
21 |
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
22 |
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
|
23 |
FunctionCall, ToolCall, UsageInfo)
|
24 |
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
25 |
-
OpenAIServing
|
|
|
26 |
from vllm.inputs import PromptInputs
|
27 |
from vllm.logger import init_logger
|
28 |
from vllm.model_executor.guided_decoding import (
|
29 |
get_guided_decoding_logits_processor)
|
30 |
from vllm.multimodal import MultiModalDataDict
|
31 |
-
from vllm.multimodal.utils import async_get_and_parse_image
|
32 |
from vllm.outputs import RequestOutput
|
33 |
from vllm.sequence import Logprob
|
34 |
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
@@ -38,159 +37,31 @@ from vllm.utils import random_uuid
|
|
38 |
logger = init_logger(__name__)
|
39 |
|
40 |
|
41 |
-
@final # So that it should be compatible with Dict[str, str]
|
42 |
-
class ConversationMessage(TypedDict):
|
43 |
-
role: str
|
44 |
-
content: str
|
45 |
-
|
46 |
-
|
47 |
-
@dataclass(frozen=True)
|
48 |
-
class ChatMessageParseResult:
|
49 |
-
messages: List[ConversationMessage]
|
50 |
-
mm_futures: List[Awaitable[MultiModalDataDict]] = field(
|
51 |
-
default_factory=list)
|
52 |
-
|
53 |
-
|
54 |
class OpenAIServingChat(OpenAIServing):
|
55 |
|
56 |
-
def __init__(
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
63 |
super().__init__(engine=engine,
|
64 |
model_config=model_config,
|
65 |
served_model_names=served_model_names,
|
66 |
-
lora_modules=lora_modules
|
|
|
|
|
67 |
|
68 |
self.response_role = response_role
|
69 |
-
self._load_chat_template(chat_template)
|
70 |
-
|
71 |
-
def _load_chat_template(self, chat_template: Optional[str]):
|
72 |
-
tokenizer = self.tokenizer
|
73 |
-
|
74 |
-
if chat_template is not None:
|
75 |
-
try:
|
76 |
-
with open(chat_template, "r") as f:
|
77 |
-
tokenizer.chat_template = f.read()
|
78 |
-
except OSError as e:
|
79 |
-
JINJA_CHARS = "{}\n"
|
80 |
-
if not any(c in chat_template for c in JINJA_CHARS):
|
81 |
-
msg = (f"The supplied chat template ({chat_template}) "
|
82 |
-
f"looks like a file path, but it failed to be "
|
83 |
-
f"opened. Reason: {e}")
|
84 |
-
raise ValueError(msg) from e
|
85 |
-
|
86 |
-
# If opening a file fails, set chat template to be args to
|
87 |
-
# ensure we decode so our escape are interpreted correctly
|
88 |
-
tokenizer.chat_template = codecs.decode(
|
89 |
-
chat_template, "unicode_escape")
|
90 |
-
|
91 |
-
logger.info("Using supplied chat template:\n%s",
|
92 |
-
tokenizer.chat_template)
|
93 |
-
elif tokenizer.chat_template is not None:
|
94 |
-
logger.info("Using default chat template:\n%s",
|
95 |
-
tokenizer.chat_template)
|
96 |
-
else:
|
97 |
-
logger.warning(
|
98 |
-
"No chat template provided. Chat API will not work.")
|
99 |
-
|
100 |
-
@cached_property
|
101 |
-
def image_token_str(self) -> Optional[str]:
|
102 |
-
# TODO: Let user specify how to insert image tokens into prompt
|
103 |
-
# (similar to chat template)
|
104 |
-
model_type = self.model_config.hf_config.model_type
|
105 |
-
if model_type == "phi3_v":
|
106 |
-
# Workaround since this token is not defined in the tokenizer
|
107 |
-
return "<|image_1|>"
|
108 |
-
if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv",
|
109 |
-
"paligemma"):
|
110 |
-
# These models do not use image tokens in the prompt
|
111 |
-
return None
|
112 |
-
if model_type.startswith("llava"):
|
113 |
-
return self.tokenizer.decode(
|
114 |
-
self.model_config.hf_config.image_token_index)
|
115 |
-
|
116 |
-
else:
|
117 |
-
raise TypeError("Unknown model type: {model_type}")
|
118 |
-
|
119 |
-
# TODO: Let user specify how to insert image tokens into prompt
|
120 |
-
# (similar to chat template)
|
121 |
-
def _get_full_image_text_prompt(self, image_token_str: str,
|
122 |
-
text_prompt: str) -> str:
|
123 |
-
"""Combine image and text prompts for vision language model"""
|
124 |
-
|
125 |
-
# NOTE: For now we assume all model architectures use the same
|
126 |
-
# image + text prompt format. This may change in the future.
|
127 |
-
return f"{image_token_str}\n{text_prompt}"
|
128 |
-
|
129 |
-
def _parse_chat_message_content_parts(
|
130 |
-
self,
|
131 |
-
role: str,
|
132 |
-
parts: Iterable[ChatCompletionContentPartParam],
|
133 |
-
) -> ChatMessageParseResult:
|
134 |
-
texts: List[str] = []
|
135 |
-
mm_futures: List[Awaitable[MultiModalDataDict]] = []
|
136 |
-
|
137 |
-
for part in parts:
|
138 |
-
part_type = part["type"]
|
139 |
-
if part_type == "text":
|
140 |
-
text = cast(ChatCompletionContentPartTextParam, part)["text"]
|
141 |
-
texts.append(text)
|
142 |
-
elif part_type == "image_url":
|
143 |
-
if len(mm_futures) > 0:
|
144 |
-
raise NotImplementedError(
|
145 |
-
"Multiple 'image_url' input is currently not supported."
|
146 |
-
)
|
147 |
-
|
148 |
-
image_url = cast(ChatCompletionContentPartImageParam,
|
149 |
-
part)["image_url"]
|
150 |
-
|
151 |
-
if image_url.get("detail", "auto") != "auto":
|
152 |
-
logger.warning(
|
153 |
-
"'image_url.detail' is currently not supported and "
|
154 |
-
"will be ignored.")
|
155 |
-
|
156 |
-
image_future = async_get_and_parse_image(image_url["url"])
|
157 |
-
mm_futures.append(image_future)
|
158 |
-
else:
|
159 |
-
raise NotImplementedError(f"Unknown part type: {part_type}")
|
160 |
-
|
161 |
-
text_prompt = "\n".join(texts)
|
162 |
-
|
163 |
-
if mm_futures:
|
164 |
-
image_token_str = self.image_token_str
|
165 |
-
if image_token_str is not None:
|
166 |
-
if image_token_str in text_prompt:
|
167 |
-
logger.warning(
|
168 |
-
"Detected image token string in the text prompt. "
|
169 |
-
"Skipping prompt formatting.")
|
170 |
-
else:
|
171 |
-
text_prompt = self._get_full_image_text_prompt(
|
172 |
-
image_token_str=image_token_str,
|
173 |
-
text_prompt=text_prompt,
|
174 |
-
)
|
175 |
|
176 |
-
|
177 |
-
|
178 |
-
return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
|
179 |
-
|
180 |
-
def _parse_chat_message_content(
|
181 |
-
self,
|
182 |
-
message: ChatCompletionMessageParam,
|
183 |
-
) -> ChatMessageParseResult:
|
184 |
-
role = message["role"]
|
185 |
-
content = message.get("content")
|
186 |
-
|
187 |
-
if content is None:
|
188 |
-
return ChatMessageParseResult(messages=[], mm_futures=[])
|
189 |
-
if isinstance(content, str):
|
190 |
-
messages = [ConversationMessage(role=role, content=content)]
|
191 |
-
return ChatMessageParseResult(messages=messages, mm_futures=[])
|
192 |
-
|
193 |
-
return self._parse_chat_message_content_parts(role, content)
|
194 |
|
195 |
async def create_chat_completion(
|
196 |
self,
|
@@ -212,11 +83,20 @@ class OpenAIServingChat(OpenAIServing):
|
|
212 |
return error_check_ret
|
213 |
|
214 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
conversation: List[ConversationMessage] = []
|
216 |
mm_futures: List[Awaitable[MultiModalDataDict]] = []
|
217 |
|
218 |
for msg in request.messages:
|
219 |
-
chat_parsed_result =
|
|
|
220 |
|
221 |
conversation.extend(chat_parsed_result.messages)
|
222 |
mm_futures.extend(chat_parsed_result.mm_futures)
|
@@ -225,13 +105,13 @@ class OpenAIServingChat(OpenAIServing):
|
|
225 |
tool.model_dump() for tool in request.tools
|
226 |
]
|
227 |
|
228 |
-
prompt =
|
229 |
conversation=conversation,
|
230 |
tokenize=False,
|
231 |
add_generation_prompt=request.add_generation_prompt,
|
232 |
tools=tool_dicts,
|
233 |
documents=request.documents,
|
234 |
-
chat_template=request.chat_template,
|
235 |
**(request.chat_template_kwargs or {}),
|
236 |
)
|
237 |
except Exception as e:
|
@@ -250,61 +130,71 @@ class OpenAIServingChat(OpenAIServing):
|
|
250 |
logger.error("Error in loading multi-modal data: %s", e)
|
251 |
return self.create_error_response(str(e))
|
252 |
|
253 |
-
request_id = f"
|
254 |
try:
|
255 |
-
# Tokenize/detokenize depending on prompt format (string/token list)
|
256 |
-
prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
|
257 |
-
request,
|
258 |
-
prompt=prompt,
|
259 |
-
add_special_tokens=request.add_special_tokens)
|
260 |
sampling_params = request.to_sampling_params()
|
261 |
-
lora_request = self._maybe_get_lora(request)
|
262 |
decoding_config = await self.engine.get_decoding_config()
|
263 |
guided_decoding_backend = request.guided_decoding_backend \
|
264 |
or decoding_config.guided_decoding_backend
|
265 |
guided_decode_logits_processor = (
|
266 |
-
await
|
267 |
-
|
268 |
-
|
269 |
if guided_decode_logits_processor:
|
270 |
if sampling_params.logits_processors is None:
|
271 |
sampling_params.logits_processors = []
|
272 |
sampling_params.logits_processors.append(
|
273 |
guided_decode_logits_processor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
except ValueError as e:
|
|
|
275 |
return self.create_error_response(str(e))
|
276 |
|
277 |
-
inputs: PromptInputs = {
|
278 |
-
"prompt": prompt_text,
|
279 |
-
"prompt_token_ids": prompt_ids,
|
280 |
-
}
|
281 |
-
if mm_data:
|
282 |
-
inputs["multi_modal_data"] = mm_data
|
283 |
-
|
284 |
-
is_tracing_enabled = await self.engine.is_tracing_enabled()
|
285 |
-
trace_headers = None
|
286 |
-
if is_tracing_enabled and raw_request:
|
287 |
-
trace_headers = extract_trace_headers(raw_request.headers)
|
288 |
-
if not is_tracing_enabled and raw_request and contains_trace_headers(
|
289 |
-
raw_request.headers):
|
290 |
-
log_tracing_disabled_warning()
|
291 |
-
|
292 |
-
result_generator = self.engine.generate(
|
293 |
-
inputs,
|
294 |
-
sampling_params,
|
295 |
-
request_id,
|
296 |
-
lora_request,
|
297 |
-
trace_headers=trace_headers,
|
298 |
-
)
|
299 |
# Streaming response
|
300 |
if request.stream:
|
301 |
return self.chat_completion_stream_generator(
|
302 |
-
request, result_generator, request_id, conversation)
|
303 |
else:
|
304 |
try:
|
305 |
return await self.chat_completion_full_generator(
|
306 |
request, raw_request, result_generator, request_id,
|
307 |
-
conversation)
|
308 |
except ValueError as e:
|
309 |
# TODO: Use a vllm-specific Validation Error
|
310 |
return self.create_error_response(str(e))
|
@@ -316,9 +206,12 @@ class OpenAIServingChat(OpenAIServing):
|
|
316 |
return request.messages[-1]["role"]
|
317 |
|
318 |
async def chat_completion_stream_generator(
|
319 |
-
|
320 |
-
|
321 |
-
|
|
|
|
|
|
|
322 |
) -> AsyncGenerator[str, None]:
|
323 |
model_name = self.served_model_names[0]
|
324 |
created_time = int(time.time())
|
@@ -326,10 +219,11 @@ class OpenAIServingChat(OpenAIServing):
|
|
326 |
first_iteration = True
|
327 |
|
328 |
# Send response for each token for each request.n (index)
|
329 |
-
|
330 |
-
previous_texts = [""] *
|
331 |
-
previous_num_tokens = [0] *
|
332 |
-
finish_reason_sent = [False] *
|
|
|
333 |
try:
|
334 |
async for res in result_generator:
|
335 |
# We need to do it here, because if there are exceptions in
|
@@ -339,7 +233,7 @@ class OpenAIServingChat(OpenAIServing):
|
|
339 |
# Send first response for each request.n (index) with
|
340 |
# the role
|
341 |
role = self.get_chat_request_role(request)
|
342 |
-
for i in range(
|
343 |
choice_data = ChatCompletionResponseStreamChoice(
|
344 |
index=i,
|
345 |
delta=DeltaMessage(role=role),
|
@@ -367,19 +261,19 @@ class OpenAIServingChat(OpenAIServing):
|
|
367 |
last_msg_content = conversation[-1]["content"]
|
368 |
|
369 |
if last_msg_content:
|
370 |
-
for i in range(
|
371 |
choice_data = (
|
372 |
ChatCompletionResponseStreamChoice(
|
373 |
index=i,
|
374 |
delta=DeltaMessage(
|
375 |
content=last_msg_content),
|
|
|
376 |
finish_reason=None))
|
377 |
chunk = ChatCompletionStreamResponse(
|
378 |
id=request_id,
|
379 |
object=chunk_object_type,
|
380 |
created=created_time,
|
381 |
choices=[choice_data],
|
382 |
-
logprobs=None,
|
383 |
model=model_name)
|
384 |
if (request.stream_options and
|
385 |
request.stream_options.include_usage):
|
@@ -405,6 +299,7 @@ class OpenAIServingChat(OpenAIServing):
|
|
405 |
logprobs = self._create_chat_logprobs(
|
406 |
token_ids=delta_token_ids,
|
407 |
top_logprobs=out_logprobs,
|
|
|
408 |
num_output_top_logprobs=request.top_logprobs,
|
409 |
)
|
410 |
else:
|
@@ -493,9 +388,13 @@ class OpenAIServingChat(OpenAIServing):
|
|
493 |
yield "data: [DONE]\n\n"
|
494 |
|
495 |
async def chat_completion_full_generator(
|
496 |
-
self,
|
497 |
-
|
498 |
-
|
|
|
|
|
|
|
|
|
499 |
) -> Union[ErrorResponse, ChatCompletionResponse]:
|
500 |
|
501 |
model_name = self.served_model_names[0]
|
@@ -523,6 +422,7 @@ class OpenAIServingChat(OpenAIServing):
|
|
523 |
token_ids=token_ids,
|
524 |
top_logprobs=out_logprobs,
|
525 |
num_output_top_logprobs=request.top_logprobs,
|
|
|
526 |
)
|
527 |
else:
|
528 |
logprobs = None
|
@@ -577,16 +477,14 @@ class OpenAIServingChat(OpenAIServing):
|
|
577 |
return response
|
578 |
|
579 |
def _get_top_logprobs(
|
580 |
-
self, logprobs: Dict[int, Logprob],
|
581 |
-
|
582 |
return [
|
583 |
ChatCompletionLogProb(
|
584 |
-
token=self._get_decoded_token(p[1], p[0]
|
|
|
585 |
logprob=max(p[1].logprob, -9999.0),
|
586 |
-
bytes=list(
|
587 |
-
self._get_decoded_token(p[1],
|
588 |
-
p[0]).encode("utf-8",
|
589 |
-
errors="replace")))
|
590 |
for i, p in enumerate(logprobs.items())
|
591 |
if top_logprobs and i < top_logprobs
|
592 |
]
|
@@ -595,6 +493,7 @@ class OpenAIServingChat(OpenAIServing):
|
|
595 |
self,
|
596 |
token_ids: GenericSequence[int],
|
597 |
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
|
|
|
598 |
num_output_top_logprobs: Optional[int] = None,
|
599 |
) -> ChatCompletionLogProbs:
|
600 |
"""Create OpenAI-style logprobs."""
|
@@ -604,12 +503,11 @@ class OpenAIServingChat(OpenAIServing):
|
|
604 |
for i, token_id in enumerate(token_ids):
|
605 |
step_top_logprobs = top_logprobs[i]
|
606 |
if step_top_logprobs is None:
|
|
|
607 |
logprobs_content.append(
|
608 |
ChatCompletionLogProbsContent(
|
609 |
-
token=
|
610 |
-
bytes=list(
|
611 |
-
self.tokenizer.decode(token_id).encode(
|
612 |
-
"utf-8", errors="replace"))))
|
613 |
else:
|
614 |
logprobs_content.append(
|
615 |
ChatCompletionLogProbsContent(
|
@@ -620,6 +518,7 @@ class OpenAIServingChat(OpenAIServing):
|
|
620 |
step_top_logprobs[token_id].decoded_token.encode(
|
621 |
"utf-8", errors="replace")),
|
622 |
top_logprobs=self._get_top_logprobs(
|
623 |
-
step_top_logprobs, num_output_top_logprobs
|
|
|
624 |
|
625 |
return ChatCompletionLogProbs(content=logprobs_content)
|
|
|
|
|
1 |
import time
|
2 |
+
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List,
|
3 |
+
Optional)
|
|
|
|
|
4 |
from typing import Sequence as GenericSequence
|
5 |
+
from typing import Union
|
6 |
|
7 |
from fastapi import Request
|
8 |
+
from transformers import PreTrainedTokenizer
|
|
|
9 |
|
10 |
from vllm.config import ModelConfig
|
11 |
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
12 |
+
from vllm.entrypoints.chat_utils import (ConversationMessage,
|
13 |
+
load_chat_template,
|
14 |
+
parse_chat_message_content)
|
15 |
+
from vllm.entrypoints.logger import RequestLogger
|
16 |
from vllm.entrypoints.openai.protocol import (
|
17 |
+
ChatCompletionLogProb, ChatCompletionLogProbs,
|
18 |
+
ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
|
|
|
19 |
ChatCompletionRequest, ChatCompletionResponse,
|
20 |
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
21 |
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
|
22 |
FunctionCall, ToolCall, UsageInfo)
|
23 |
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
24 |
+
OpenAIServing,
|
25 |
+
PromptAdapterPath)
|
26 |
from vllm.inputs import PromptInputs
|
27 |
from vllm.logger import init_logger
|
28 |
from vllm.model_executor.guided_decoding import (
|
29 |
get_guided_decoding_logits_processor)
|
30 |
from vllm.multimodal import MultiModalDataDict
|
|
|
31 |
from vllm.outputs import RequestOutput
|
32 |
from vllm.sequence import Logprob
|
33 |
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
|
|
37 |
logger = init_logger(__name__)
|
38 |
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
class OpenAIServingChat(OpenAIServing):
|
41 |
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
engine: AsyncLLMEngine,
|
45 |
+
model_config: ModelConfig,
|
46 |
+
served_model_names: List[str],
|
47 |
+
response_role: str,
|
48 |
+
*,
|
49 |
+
lora_modules: Optional[List[LoRAModulePath]],
|
50 |
+
prompt_adapters: Optional[List[PromptAdapterPath]],
|
51 |
+
request_logger: Optional[RequestLogger],
|
52 |
+
chat_template: Optional[str],
|
53 |
+
):
|
54 |
super().__init__(engine=engine,
|
55 |
model_config=model_config,
|
56 |
served_model_names=served_model_names,
|
57 |
+
lora_modules=lora_modules,
|
58 |
+
prompt_adapters=prompt_adapters,
|
59 |
+
request_logger=request_logger)
|
60 |
|
61 |
self.response_role = response_role
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
+
# If this is None we use the tokenizer's default chat template
|
64 |
+
self.chat_template = load_chat_template(chat_template)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
async def create_chat_completion(
|
67 |
self,
|
|
|
83 |
return error_check_ret
|
84 |
|
85 |
try:
|
86 |
+
(
|
87 |
+
lora_request,
|
88 |
+
prompt_adapter_request,
|
89 |
+
) = self._maybe_get_adapters(request)
|
90 |
+
|
91 |
+
model_config = self.model_config
|
92 |
+
tokenizer = await self.engine.get_tokenizer(lora_request)
|
93 |
+
|
94 |
conversation: List[ConversationMessage] = []
|
95 |
mm_futures: List[Awaitable[MultiModalDataDict]] = []
|
96 |
|
97 |
for msg in request.messages:
|
98 |
+
chat_parsed_result = parse_chat_message_content(
|
99 |
+
msg, model_config, tokenizer)
|
100 |
|
101 |
conversation.extend(chat_parsed_result.messages)
|
102 |
mm_futures.extend(chat_parsed_result.mm_futures)
|
|
|
105 |
tool.model_dump() for tool in request.tools
|
106 |
]
|
107 |
|
108 |
+
prompt = tokenizer.apply_chat_template(
|
109 |
conversation=conversation,
|
110 |
tokenize=False,
|
111 |
add_generation_prompt=request.add_generation_prompt,
|
112 |
tools=tool_dicts,
|
113 |
documents=request.documents,
|
114 |
+
chat_template=request.chat_template or self.chat_template,
|
115 |
**(request.chat_template_kwargs or {}),
|
116 |
)
|
117 |
except Exception as e:
|
|
|
130 |
logger.error("Error in loading multi-modal data: %s", e)
|
131 |
return self.create_error_response(str(e))
|
132 |
|
133 |
+
request_id = f"chat-{random_uuid()}"
|
134 |
try:
|
|
|
|
|
|
|
|
|
|
|
135 |
sampling_params = request.to_sampling_params()
|
|
|
136 |
decoding_config = await self.engine.get_decoding_config()
|
137 |
guided_decoding_backend = request.guided_decoding_backend \
|
138 |
or decoding_config.guided_decoding_backend
|
139 |
guided_decode_logits_processor = (
|
140 |
+
await
|
141 |
+
get_guided_decoding_logits_processor(guided_decoding_backend,
|
142 |
+
request, tokenizer))
|
143 |
if guided_decode_logits_processor:
|
144 |
if sampling_params.logits_processors is None:
|
145 |
sampling_params.logits_processors = []
|
146 |
sampling_params.logits_processors.append(
|
147 |
guided_decode_logits_processor)
|
148 |
+
|
149 |
+
prompt_inputs = self._tokenize_prompt_input(
|
150 |
+
request,
|
151 |
+
tokenizer,
|
152 |
+
prompt,
|
153 |
+
truncate_prompt_tokens=sampling_params.truncate_prompt_tokens,
|
154 |
+
add_special_tokens=request.add_special_tokens,
|
155 |
+
)
|
156 |
+
|
157 |
+
self._log_inputs(request_id,
|
158 |
+
prompt_inputs,
|
159 |
+
params=sampling_params,
|
160 |
+
lora_request=lora_request,
|
161 |
+
prompt_adapter_request=prompt_adapter_request)
|
162 |
+
|
163 |
+
engine_inputs: PromptInputs = {
|
164 |
+
"prompt_token_ids": prompt_inputs["prompt_token_ids"],
|
165 |
+
}
|
166 |
+
if mm_data is not None:
|
167 |
+
engine_inputs["multi_modal_data"] = mm_data
|
168 |
+
|
169 |
+
is_tracing_enabled = await self.engine.is_tracing_enabled()
|
170 |
+
      trace_headers = None
171 + if is_tracing_enabled and raw_request:
172 +     trace_headers = extract_trace_headers(raw_request.headers)
173 + if (not is_tracing_enabled and raw_request
174 +         and contains_trace_headers(raw_request.headers)):
175 +     log_tracing_disabled_warning()
176 +
177 + result_generator = self.engine.generate(
178 +     engine_inputs,
179 +     sampling_params,
180 +     request_id,
181 +     lora_request=lora_request,
182 +     trace_headers=trace_headers,
183 +     prompt_adapter_request=prompt_adapter_request,
184 + )
185   except ValueError as e:
186 +     # TODO: Use a vllm-specific Validation Error
187       return self.create_error_response(str(e))
188
189   # Streaming response
190   if request.stream:
191       return self.chat_completion_stream_generator(
192 +         request, result_generator, request_id, conversation, tokenizer)
193   else:
194       try:
195           return await self.chat_completion_full_generator(
196               request, raw_request, result_generator, request_id,
197 +             conversation, tokenizer)
198       except ValueError as e:
199           # TODO: Use a vllm-specific Validation Error
200           return self.create_error_response(str(e))

206   return request.messages[-1]["role"]
207
208   async def chat_completion_stream_generator(
209 +     self,
210 +     request: ChatCompletionRequest,
211 +     result_generator: AsyncIterator[RequestOutput],
212 +     request_id: str,
213 +     conversation: List[ConversationMessage],
214 +     tokenizer: PreTrainedTokenizer,
215   ) -> AsyncGenerator[str, None]:
216       model_name = self.served_model_names[0]
217       created_time = int(time.time())

219       first_iteration = True
220
221       # Send response for each token for each request.n (index)
222 +     num_choices = 1 if request.n is None else request.n
223 +     previous_texts = [""] * num_choices
224 +     previous_num_tokens = [0] * num_choices
225 +     finish_reason_sent = [False] * num_choices
226 +
227       try:
228           async for res in result_generator:
229               # We need to do it here, because if there are exceptions in

233               # Send first response for each request.n (index) with
234               # the role
235               role = self.get_chat_request_role(request)
236 +             for i in range(num_choices):
237                   choice_data = ChatCompletionResponseStreamChoice(
238                       index=i,
239                       delta=DeltaMessage(role=role),

261               last_msg_content = conversation[-1]["content"]
262
263               if last_msg_content:
264 +                 for i in range(num_choices):
265                       choice_data = (
266                           ChatCompletionResponseStreamChoice(
267                               index=i,
268                               delta=DeltaMessage(
269                                   content=last_msg_content),
270 +                             logprobs=None,
271                               finish_reason=None))
272                       chunk = ChatCompletionStreamResponse(
273                           id=request_id,
274                           object=chunk_object_type,
275                           created=created_time,
276                           choices=[choice_data],
277                           model=model_name)
278                       if (request.stream_options and
279                               request.stream_options.include_usage):

299                   logprobs = self._create_chat_logprobs(
300                       token_ids=delta_token_ids,
301                       top_logprobs=out_logprobs,
302 +                     tokenizer=tokenizer,
303                       num_output_top_logprobs=request.top_logprobs,
304                   )
305               else:

388       yield "data: [DONE]\n\n"
389
390   async def chat_completion_full_generator(
391 +     self,
392 +     request: ChatCompletionRequest,
393 +     raw_request: Optional[Request],
394 +     result_generator: AsyncIterator[RequestOutput],
395 +     request_id: str,
396 +     conversation: List[ConversationMessage],
397 +     tokenizer: PreTrainedTokenizer,
398   ) -> Union[ErrorResponse, ChatCompletionResponse]:
399
400       model_name = self.served_model_names[0]

422               token_ids=token_ids,
423               top_logprobs=out_logprobs,
424               num_output_top_logprobs=request.top_logprobs,
425 +             tokenizer=tokenizer,
426           )
427       else:
428           logprobs = None

477       return response
478
479   def _get_top_logprobs(
480 +         self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
481 +         tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
482       return [
483           ChatCompletionLogProb(
484 +             token=(token := self._get_decoded_token(p[1], p[0],
485 +                                                     tokenizer)),
486               logprob=max(p[1].logprob, -9999.0),
487 +             bytes=list(token.encode("utf-8", errors="replace")))
488           for i, p in enumerate(logprobs.items())
489           if top_logprobs and i < top_logprobs
490       ]

493       self,
494       token_ids: GenericSequence[int],
495       top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
496 +     tokenizer: PreTrainedTokenizer,
497       num_output_top_logprobs: Optional[int] = None,
498   ) -> ChatCompletionLogProbs:
499       """Create OpenAI-style logprobs."""

503       for i, token_id in enumerate(token_ids):
504           step_top_logprobs = top_logprobs[i]
505           if step_top_logprobs is None:
506 +             token = tokenizer.decode(token_id)
507               logprobs_content.append(
508                   ChatCompletionLogProbsContent(
509 +                     token=token,
510 +                     bytes=list(token.encode("utf-8", errors="replace"))))
511           else:
512               logprobs_content.append(
513                   ChatCompletionLogProbsContent(

518                       step_top_logprobs[token_id].decoded_token.encode(
519                           "utf-8", errors="replace")),
520                   top_logprobs=self._get_top_logprobs(
521 +                     step_top_logprobs, num_output_top_logprobs,
522 +                     tokenizer)))
523
524       return ChatCompletionLogProbs(content=logprobs_content)
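The chat changes above thread the per-request tokenizer (which may be LoRA-specific) into the logprob helpers, so undecoded token IDs fall back to tokenizer.decode and each entry carries an OpenAI-style bytes field. The following is an illustrative sketch only, not part of the diff; FakeLogprob and to_openai_logprob are hypothetical names.

# Sketch: turning one logprob entry into an OpenAI-style dict, mirroring
# the _get_decoded_token fallback used above.
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeLogprob:
    logprob: float
    decoded_token: Optional[str] = None

def to_openai_logprob(entry: FakeLogprob, token_id: int, tokenizer) -> dict:
    # Prefer the already-decoded token; otherwise decode with the
    # request-specific tokenizer, as _get_decoded_token does.
    token = (entry.decoded_token if entry.decoded_token is not None
             else tokenizer.decode(token_id))
    return {
        "token": token,
        # Clamp -inf to a JSON-serializable float, as the handlers above do.
        "logprob": max(entry.logprob, -9999.0),
        "bytes": list(token.encode("utf-8", errors="replace")),
    }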
serving_completion.py
CHANGED
@@ -2,12 +2,14 @@ import time
2   from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
3                       Optional)
4   from typing import Sequence as GenericSequence
5 - from typing import Tuple
6
7   from fastapi import Request
8
9   from vllm.config import ModelConfig
10  from vllm.engine.async_llm_engine import AsyncLLMEngine
11  # yapf conflicts with isort for this block
12  # yapf: disable
13  from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
@@ -16,13 +18,11 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
16                                                CompletionResponseChoice,
17                                                CompletionResponseStreamChoice,
18                                                CompletionStreamResponse,
19 -
20 -                                              DetokenizeResponse,
21 -                                              TokenizeRequest,
22 -                                              TokenizeResponse, UsageInfo)
23  # yapf: enable
24  from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
25 -                                                    OpenAIServing
26  from vllm.logger import init_logger
27  from vllm.model_executor.guided_decoding import (
28      get_guided_decoding_logits_processor)
@@ -40,38 +40,24 @@ TypeCreateLogProbsFn = Callable[
40      [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
41
42
43 - def parse_prompt_format(prompt) -> Tuple[bool, list]:
44 -     # get the prompt, openai supports the following
45 -     # "a string, array of strings, array of tokens, or array of token arrays."
46 -     prompt_is_tokens = False
47 -     prompts = [prompt]  # case 1: a string
48 -     if isinstance(prompt, list):
49 -         if len(prompt) == 0:
50 -             raise ValueError("please provide at least one prompt")
51 -         elif isinstance(prompt[0], str):
52 -             prompt_is_tokens = False
53 -             prompts = prompt  # case 2: array of strings
54 -         elif isinstance(prompt[0], int):
55 -             prompt_is_tokens = True
56 -             prompts = [prompt]  # case 3: array of tokens
57 -         elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
58 -             prompt_is_tokens = True
59 -             prompts = prompt  # case 4: array of token arrays
60 -         else:
61 -             raise ValueError("prompt must be a string, array of strings, "
62 -                              "array of tokens, or array of token arrays")
63 -     return prompt_is_tokens, prompts
64 -
65 -
66  class OpenAIServingCompletion(OpenAIServing):
67
68 -     def __init__(
69 -
70 -
71          super().__init__(engine=engine,
72                           model_config=model_config,
73                           served_model_names=served_model_names,
74 -                         lora_modules=lora_modules
75
76      async def create_completion(self, request: CompletionRequest,
77                                  raw_request: Request):
@@ -100,36 +86,45 @@ class OpenAIServingCompletion(OpenAIServing):
100         # Schedule the request and get the result generator.
101         generators: List[AsyncIterator[RequestOutput]] = []
102         try:
103             sampling_params = request.to_sampling_params()
104 -           lora_request = self._maybe_get_lora(request)
105             decoding_config = await self.engine.get_decoding_config()
106             guided_decoding_backend = request.guided_decoding_backend \
107                 or decoding_config.guided_decoding_backend
108             guided_decode_logit_processor = (
109 -               await
110 -
111 -
112             if guided_decode_logit_processor is not None:
113                 if sampling_params.logits_processors is None:
114                     sampling_params.logits_processors = []
115                 sampling_params.logits_processors.append(
116                     guided_decode_logit_processor)
117 -
118 -
119 -
120 -
121 -
122 -
123 -
124 -
125 -
126 -
127 -
128 -
129 -
130 -
131 -
132 -
133
134             is_tracing_enabled = await self.engine.is_tracing_enabled()
135             trace_headers = None
@@ -140,13 +135,11 @@ class OpenAIServingCompletion(OpenAIServing):
140                 log_tracing_disabled_warning()
141
142             generator = self.engine.generate(
143 -               {
144 -                   "prompt": prompt_text,
145 -                   "prompt_token_ids": prompt_ids
146 -               },
147                 sampling_params,
148 -
149                 lora_request=lora_request,
150                 trace_headers=trace_headers,
151             )
152
@@ -173,7 +166,8 @@ class OpenAIServingCompletion(OpenAIServing):
173             request_id,
174             created_time,
175             model_name,
176 -           num_prompts=len(prompts)
177
178         # Non-streaming response
179         final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
@@ -184,8 +178,27 @@ class OpenAIServingCompletion(OpenAIServing):
184                 await self.engine.abort(f"{request_id}-{i}")
185                 return self.create_error_response("Client disconnected")
186             final_res_batch[i] = res
187             response = self.request_output_to_completion_response(
188 -
189         except ValueError as e:
190             # TODO: Use a vllm-specific Validation Error
191             return self.create_error_response(str(e))
@@ -212,11 +225,12 @@ class OpenAIServingCompletion(OpenAIServing):
212         created_time: int,
213         model_name: str,
214         num_prompts: int,
215     ) -> AsyncGenerator[str, None]:
216 -
217 -       previous_texts = [""] *
218 -       previous_num_tokens = [0] *
219 -       has_echoed = [False] *
220
221         try:
222             async for prompt_idx, res in result_generator:
@@ -227,7 +241,7 @@ class OpenAIServingCompletion(OpenAIServing):
227                     raise StopAsyncIteration()
228
229                 for output in res.outputs:
230 -                   i = output.index + prompt_idx *
231                     # TODO(simon): optimize the performance by avoiding full
232                     # text O(n^2) sending.
233
@@ -262,6 +276,7 @@ class OpenAIServingCompletion(OpenAIServing):
262                         token_ids=delta_token_ids,
263                         top_logprobs=out_logprobs,
264                         num_output_top_logprobs=request.logprobs,
265                         initial_text_offset=len(previous_texts[i]),
266                     )
267                 else:
@@ -301,7 +316,7 @@ class OpenAIServingCompletion(OpenAIServing):
301                 else:
302                     chunk.usage = None
303
304 -               response_json = chunk.model_dump_json(exclude_unset=
305                 yield f"data: {response_json}\n\n"
306
307             if (request.stream_options
@@ -314,7 +329,7 @@ class OpenAIServingCompletion(OpenAIServing):
314                     usage=usage,
315                 )
316                 final_usage_data = (final_usage_chunk.model_dump_json(
317 -                   exclude_unset=
318                 yield f"data: {final_usage_data}\n\n"
319
320         except ValueError as e:
@@ -330,12 +345,13 @@ class OpenAIServingCompletion(OpenAIServing):
330         request_id: str,
331         created_time: int,
332         model_name: str,
333     ) -> CompletionResponse:
334         choices: List[CompletionResponseChoice] = []
335         num_prompt_tokens = 0
336         num_generated_tokens = 0
337         for final_res in final_res_batch:
338 -           assert final_res is not None
339             prompt_token_ids = final_res.prompt_token_ids
340             prompt_logprobs = final_res.prompt_logprobs
341             prompt_text = final_res.prompt
@@ -361,6 +377,7 @@ class OpenAIServingCompletion(OpenAIServing):
361                 logprobs = self._create_completion_logprobs(
362                     token_ids=token_ids,
363                     top_logprobs=out_logprobs,
364                     num_output_top_logprobs=request.logprobs,
365                 )
366             else:
@@ -398,6 +415,7 @@ class OpenAIServingCompletion(OpenAIServing):
398         token_ids: GenericSequence[int],
399         top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
400         num_output_top_logprobs: int,
401         initial_text_offset: int = 0,
402     ) -> CompletionLogProbs:
403         """Create logprobs for OpenAI Completion API."""
@@ -411,13 +429,13 @@ class OpenAIServingCompletion(OpenAIServing):
411         for i, token_id in enumerate(token_ids):
412             step_top_logprobs = top_logprobs[i]
413             if step_top_logprobs is None:
414 -               token =
415                 out_tokens.append(token)
416                 out_token_logprobs.append(None)
417                 out_top_logprobs.append(None)
418             else:
419                 token = self._get_decoded_token(step_top_logprobs[token_id],
420 -                                               token_id)
421                 token_logprob = max(step_top_logprobs[token_id].logprob,
422                                     -9999.0)
423                 out_tokens.append(token)
@@ -430,7 +448,7 @@ class OpenAIServingCompletion(OpenAIServing):
430                 out_top_logprobs.append({
431                     # Convert float("-inf") to the
432                     # JSON-serializable float that OpenAI uses
433 -                   self._get_decoded_token(top_lp[1], top_lp[0]):
434                     max(top_lp[1].logprob, -9999.0)
435                     for i, top_lp in enumerate(step_top_logprobs.items())
436                     if num_output_top_logprobs >= i
@@ -447,30 +465,4 @@ class OpenAIServingCompletion(OpenAIServing):
447             token_logprobs=out_token_logprobs,
448             tokens=out_tokens,
449             top_logprobs=out_top_logprobs,
450 -       )
451 -
452 -   async def create_tokenize(self,
453 -                             request: TokenizeRequest) -> TokenizeResponse:
454 -       error_check_ret = await self._check_model(request)
455 -       if error_check_ret is not None:
456 -           return error_check_ret
457 -
458 -       (input_ids, input_text) = self._validate_prompt_and_tokenize(
459 -           request,
460 -           prompt=request.prompt,
461 -           add_special_tokens=request.add_special_tokens)
462 -
463 -       return TokenizeResponse(tokens=input_ids,
464 -                               count=len(input_ids),
465 -                               max_model_len=self.max_model_len)
466 -
467 -   async def create_detokenize(
468 -           self, request: DetokenizeRequest) -> DetokenizeResponse:
469 -       error_check_ret = await self._check_model(request)
470 -       if error_check_ret is not None:
471 -           return error_check_ret
472 -
473 -       (input_ids, input_text) = self._validate_prompt_and_tokenize(
474 -           request, prompt_ids=request.tokens)
475 -
476 -       return DetokenizeResponse(prompt=input_text)

2   from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
3                       Optional)
4   from typing import Sequence as GenericSequence
5 + from typing import Tuple, cast
6
7   from fastapi import Request
8 + from transformers import PreTrainedTokenizer
9
10  from vllm.config import ModelConfig
11  from vllm.engine.async_llm_engine import AsyncLLMEngine
12 + from vllm.entrypoints.logger import RequestLogger
13  # yapf conflicts with isort for this block
14  # yapf: disable
15  from vllm.entrypoints.openai.protocol import (CompletionLogProbs,

18                                                CompletionResponseChoice,
19                                                CompletionResponseStreamChoice,
20                                                CompletionStreamResponse,
21 +                                              UsageInfo)
22  # yapf: enable
23  from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
24 +                                                    OpenAIServing,
25 +                                                    PromptAdapterPath)
26  from vllm.logger import init_logger
27  from vllm.model_executor.guided_decoding import (
28      get_guided_decoding_logits_processor)

40      [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
41
42
43  class OpenAIServingCompletion(OpenAIServing):
44
45 +    def __init__(
46 +        self,
47 +        engine: AsyncLLMEngine,
48 +        model_config: ModelConfig,
49 +        served_model_names: List[str],
50 +        *,
51 +        lora_modules: Optional[List[LoRAModulePath]],
52 +        prompt_adapters: Optional[List[PromptAdapterPath]],
53 +        request_logger: Optional[RequestLogger],
54 +    ):
55          super().__init__(engine=engine,
56                           model_config=model_config,
57                           served_model_names=served_model_names,
58 +                         lora_modules=lora_modules,
59 +                         prompt_adapters=prompt_adapters,
60 +                         request_logger=request_logger)
61
62      async def create_completion(self, request: CompletionRequest,
63                                  raw_request: Request):

86          # Schedule the request and get the result generator.
87          generators: List[AsyncIterator[RequestOutput]] = []
88          try:
89 +            (
90 +                lora_request,
91 +                prompt_adapter_request,
92 +            ) = self._maybe_get_adapters(request)
93 +
94 +            tokenizer = await self.engine.get_tokenizer(lora_request)
95 +
96              sampling_params = request.to_sampling_params()
97              decoding_config = await self.engine.get_decoding_config()
98              guided_decoding_backend = request.guided_decoding_backend \
99                  or decoding_config.guided_decoding_backend
100             guided_decode_logit_processor = (
101 +               await
102 +               get_guided_decoding_logits_processor(guided_decoding_backend,
103 +                                                    request, tokenizer))
104             if guided_decode_logit_processor is not None:
105                 if sampling_params.logits_processors is None:
106                     sampling_params.logits_processors = []
107                 sampling_params.logits_processors.append(
108                     guided_decode_logit_processor)
109 +
110 +           prompts = list(
111 +               self._tokenize_prompt_input_or_inputs(
112 +                   request,
113 +                   tokenizer,
114 +                   request.prompt,
115 +                   truncate_prompt_tokens=sampling_params.
116 +                   truncate_prompt_tokens,
117 +                   add_special_tokens=request.add_special_tokens,
118 +               ))
119 +
120 +           for i, prompt_inputs in enumerate(prompts):
121 +               request_id_item = f"{request_id}-{i}"
122 +
123 +               self._log_inputs(request_id_item,
124 +                                prompt_inputs,
125 +                                params=sampling_params,
126 +                                lora_request=lora_request,
127 +                                prompt_adapter_request=prompt_adapter_request)
128
129                 is_tracing_enabled = await self.engine.is_tracing_enabled()
130                 trace_headers = None

135                     log_tracing_disabled_warning()
136
137                 generator = self.engine.generate(
138 +                   {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
139                     sampling_params,
140 +                   request_id_item,
141                     lora_request=lora_request,
142 +                   prompt_adapter_request=prompt_adapter_request,
143                     trace_headers=trace_headers,
144                 )
145

166             request_id,
167             created_time,
168             model_name,
169 +           num_prompts=len(prompts),
170 +           tokenizer=tokenizer)
171
172         # Non-streaming response
173         final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)

178                 await self.engine.abort(f"{request_id}-{i}")
179                 return self.create_error_response("Client disconnected")
180             final_res_batch[i] = res
181 +
182 +           for i, final_res in enumerate(final_res_batch):
183 +               assert final_res is not None
184 +
185 +               # The output should contain the input text
186 +               # We did not pass it into vLLM engine to avoid being redundant
187 +               # with the inputs token IDs
188 +               if final_res.prompt is None:
189 +                   final_res.prompt = prompts[i]["prompt"]
190 +
191 +           final_res_batch_checked = cast(List[RequestOutput],
192 +                                          final_res_batch)
193 +
194             response = self.request_output_to_completion_response(
195 +               final_res_batch_checked,
196 +               request,
197 +               request_id,
198 +               created_time,
199 +               model_name,
200 +               tokenizer,
201 +           )
202         except ValueError as e:
203             # TODO: Use a vllm-specific Validation Error
204             return self.create_error_response(str(e))

225         created_time: int,
226         model_name: str,
227         num_prompts: int,
228 +       tokenizer: PreTrainedTokenizer,
229     ) -> AsyncGenerator[str, None]:
230 +       num_choices = 1 if request.n is None else request.n
231 +       previous_texts = [""] * num_choices * num_prompts
232 +       previous_num_tokens = [0] * num_choices * num_prompts
233 +       has_echoed = [False] * num_choices * num_prompts
234
235         try:
236             async for prompt_idx, res in result_generator:

241                     raise StopAsyncIteration()
242
243                 for output in res.outputs:
244 +                   i = output.index + prompt_idx * num_choices
245                     # TODO(simon): optimize the performance by avoiding full
246                     # text O(n^2) sending.
247

276                         token_ids=delta_token_ids,
277                         top_logprobs=out_logprobs,
278                         num_output_top_logprobs=request.logprobs,
279 +                       tokenizer=tokenizer,
280                         initial_text_offset=len(previous_texts[i]),
281                     )
282                 else:

316                 else:
317                     chunk.usage = None
318
319 +               response_json = chunk.model_dump_json(exclude_unset=False)
320                 yield f"data: {response_json}\n\n"
321
322             if (request.stream_options

329                     usage=usage,
330                 )
331                 final_usage_data = (final_usage_chunk.model_dump_json(
332 +                   exclude_unset=False, exclude_none=True))
333                 yield f"data: {final_usage_data}\n\n"
334
335         except ValueError as e:

345         request_id: str,
346         created_time: int,
347         model_name: str,
348 +       tokenizer: PreTrainedTokenizer,
349     ) -> CompletionResponse:
350         choices: List[CompletionResponseChoice] = []
351         num_prompt_tokens = 0
352         num_generated_tokens = 0
353 +
354         for final_res in final_res_batch:
355             prompt_token_ids = final_res.prompt_token_ids
356             prompt_logprobs = final_res.prompt_logprobs
357             prompt_text = final_res.prompt

377             logprobs = self._create_completion_logprobs(
378                 token_ids=token_ids,
379                 top_logprobs=out_logprobs,
380 +               tokenizer=tokenizer,
381                 num_output_top_logprobs=request.logprobs,
382             )
383         else:

415         token_ids: GenericSequence[int],
416         top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
417         num_output_top_logprobs: int,
418 +       tokenizer: PreTrainedTokenizer,
419         initial_text_offset: int = 0,
420     ) -> CompletionLogProbs:
421         """Create logprobs for OpenAI Completion API."""

429         for i, token_id in enumerate(token_ids):
430             step_top_logprobs = top_logprobs[i]
431             if step_top_logprobs is None:
432 +               token = tokenizer.decode(token_id)
433                 out_tokens.append(token)
434                 out_token_logprobs.append(None)
435                 out_top_logprobs.append(None)
436             else:
437                 token = self._get_decoded_token(step_top_logprobs[token_id],
438 +                                               token_id, tokenizer)
439                 token_logprob = max(step_top_logprobs[token_id].logprob,
440                                     -9999.0)
441                 out_tokens.append(token)

448                 out_top_logprobs.append({
449                     # Convert float("-inf") to the
450                     # JSON-serializable float that OpenAI uses
451 +                   self._get_decoded_token(top_lp[1], top_lp[0], tokenizer):
452                     max(top_lp[1].logprob, -9999.0)
453                     for i, top_lp in enumerate(step_top_logprobs.items())
454                     if num_output_top_logprobs >= i

465             token_logprobs=out_token_logprobs,
466             tokens=out_tokens,
467             top_logprobs=out_top_logprobs,
468 +       )
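The removed parse_prompt_format is replaced by the new tokenization helpers in serving_engine.py, which accept the same four OpenAI prompt shapes and turn each one into a separately scheduled sub-request. The sketch below is a standalone stand-in (split_prompts is a hypothetical name, not vLLM's API) showing the same dispatch under those assumptions.

# Hypothetical helper mirroring the removed parse_prompt_format logic:
# a prompt may be a string, a list of strings, a list of token IDs,
# or a list of token-ID lists.
from typing import List, Union

Prompt = Union[str, List[str], List[int], List[List[int]]]

def split_prompts(prompt: Prompt) -> List[Union[str, List[int]]]:
    if isinstance(prompt, str):
        return [prompt]                      # case 1: a single string
    if not prompt:
        raise ValueError("please provide at least one prompt")
    if isinstance(prompt[0], str):
        return list(prompt)                  # case 2: array of strings
    if isinstance(prompt[0], int):
        return [list(prompt)]                # case 3: one array of tokens
    if isinstance(prompt[0], list):
        return [list(p) for p in prompt]     # case 4: array of token arrays
    raise ValueError("prompt must be a string, array of strings, "
                     "array of tokens, or array of token arrays")

# Each returned element is tokenized and validated on its own and sent to
# the engine under its own sub-request id, e.g. f"{request_id}-{i}".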
|
|
|
|
serving_embedding.py
CHANGED
@@ -1,16 +1,16 @@
1   import base64
2   import time
3 - from typing import AsyncIterator, List, Optional, Tuple
4
5   import numpy as np
6   from fastapi import Request
7
8   from vllm.config import ModelConfig
9   from vllm.engine.async_llm_engine import AsyncLLMEngine
10  from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
11                                                EmbeddingResponse,
12                                                EmbeddingResponseData, UsageInfo)
13 - from vllm.entrypoints.openai.serving_completion import parse_prompt_format
14  from vllm.entrypoints.openai.serving_engine import OpenAIServing
15  from vllm.logger import init_logger
16  from vllm.outputs import EmbeddingRequestOutput
@@ -28,11 +28,11 @@ def request_output_to_embedding_response(
28      data: List[EmbeddingResponseData] = []
29      num_prompt_tokens = 0
30      for idx, final_res in enumerate(final_res_batch):
31 -        assert final_res is not None
32          prompt_token_ids = final_res.prompt_token_ids
33          embedding = final_res.outputs.embedding
34          if encoding_format == "base64":
35 -
36          embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
37          data.append(embedding_data)
38
@@ -54,12 +54,20 @@ def request_output_to_embedding_response(
54
55  class OpenAIServingEmbedding(OpenAIServing):
56
57 -     def __init__(
58 -
59          super().__init__(engine=engine,
60                           model_config=model_config,
61                           served_model_names=served_model_names,
62 -                         lora_modules=None
63          self._check_embedding_mode(model_config.embedding_mode)
64
65      async def create_embedding(self, request: EmbeddingRequest,
@@ -80,32 +88,47 @@ class OpenAIServingEmbedding(OpenAIServing):
80              "dimensions is currently not supported")
81
82          model_name = request.model
83 -        request_id = f"
84          created_time = int(time.monotonic())
85
86          # Schedule the request and get the result generator.
87 -        generators = []
88          try:
89 -
90              pooling_params = request.to_pooling_params()
91
92 -
93 -
94 -
95 -
96 -
97 -
98 -
99
100 -
101
102             generator = self.engine.encode(
103 -               {
104 -                   "prompt": prompt_text,
105 -                   "prompt_token_ids": prompt_ids
106 -               },
107                 pooling_params,
108 -
109             )
110
111             generators.append(generator)
@@ -124,11 +147,17 @@ class OpenAIServingEmbedding(OpenAIServing):
124             if await raw_request.is_disconnected():
125                 # Abort the request if the client disconnects.
126                 await self.engine.abort(f"{request_id}-{i}")
127 -               # TODO: Use a vllm-specific Validation Error
128                 return self.create_error_response("Client disconnected")
129             final_res_batch[i] = res
130             response = request_output_to_embedding_response(
131 -
132                 encoding_format)
133         except ValueError as e:
134             # TODO: Use a vllm-specific Validation Error

1   import base64
2   import time
3 + from typing import AsyncIterator, List, Optional, Tuple, cast
4
5   import numpy as np
6   from fastapi import Request
7
8   from vllm.config import ModelConfig
9   from vllm.engine.async_llm_engine import AsyncLLMEngine
10 + from vllm.entrypoints.logger import RequestLogger
11  from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
12                                                EmbeddingResponse,
13                                                EmbeddingResponseData, UsageInfo)
14  from vllm.entrypoints.openai.serving_engine import OpenAIServing
15  from vllm.logger import init_logger
16  from vllm.outputs import EmbeddingRequestOutput

28      data: List[EmbeddingResponseData] = []
29      num_prompt_tokens = 0
30      for idx, final_res in enumerate(final_res_batch):
31          prompt_token_ids = final_res.prompt_token_ids
32          embedding = final_res.outputs.embedding
33          if encoding_format == "base64":
34 +            embedding_bytes = np.array(embedding).tobytes()
35 +            embedding = base64.b64encode(embedding_bytes).decode("utf-8")
36          embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
37          data.append(embedding_data)
38

54
55  class OpenAIServingEmbedding(OpenAIServing):
56
57 +    def __init__(
58 +        self,
59 +        engine: AsyncLLMEngine,
60 +        model_config: ModelConfig,
61 +        served_model_names: List[str],
62 +        *,
63 +        request_logger: Optional[RequestLogger],
64 +    ):
65          super().__init__(engine=engine,
66                           model_config=model_config,
67                           served_model_names=served_model_names,
68 +                         lora_modules=None,
69 +                         prompt_adapters=None,
70 +                         request_logger=request_logger)
71          self._check_embedding_mode(model_config.embedding_mode)
72
73      async def create_embedding(self, request: EmbeddingRequest,

88              "dimensions is currently not supported")
89
90          model_name = request.model
91 +        request_id = f"embd-{random_uuid()}"
92          created_time = int(time.monotonic())
93
94          # Schedule the request and get the result generator.
95 +        generators: List[AsyncIterator[EmbeddingRequestOutput]] = []
96          try:
97 +            (
98 +                lora_request,
99 +                prompt_adapter_request,
100 +            ) = self._maybe_get_adapters(request)
101 +
102 +            tokenizer = await self.engine.get_tokenizer(lora_request)
103 +
104              pooling_params = request.to_pooling_params()
105
106 +            prompts = list(
107 +                self._tokenize_prompt_input_or_inputs(
108 +                    request,
109 +                    tokenizer,
110 +                    request.input,
111 +                ))
112 +
113 +            for i, prompt_inputs in enumerate(prompts):
114 +                request_id_item = f"{request_id}-{i}"
115 +
116 +                self._log_inputs(request_id_item,
117 +                                 prompt_inputs,
118 +                                 params=pooling_params,
119 +                                 lora_request=lora_request,
120 +                                 prompt_adapter_request=prompt_adapter_request)
121
122 +                if prompt_adapter_request is not None:
123 +                    raise NotImplementedError(
124 +                        "Prompt adapter is not supported "
125 +                        "for embedding models")
126
127                 generator = self.engine.encode(
128 +                   {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
129                     pooling_params,
130 +                   request_id_item,
131 +                   lora_request=lora_request,
132                 )
133
134                 generators.append(generator)

147             if await raw_request.is_disconnected():
148                 # Abort the request if the client disconnects.
149                 await self.engine.abort(f"{request_id}-{i}")
150                 return self.create_error_response("Client disconnected")
151             final_res_batch[i] = res
152 +
153 +           for final_res in final_res_batch:
154 +               assert final_res is not None
155 +
156 +           final_res_batch_checked = cast(List[EmbeddingRequestOutput],
157 +                                          final_res_batch)
158 +
159             response = request_output_to_embedding_response(
160 +               final_res_batch_checked, request_id, created_time, model_name,
161                 encoding_format)
162         except ValueError as e:
163 |
# TODO: Use a vllm-specific Validation Error
|
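The embedding handler above now emits base64 output by dumping the raw embedding bytes through numpy and base64-encoding them. The round-trip sketch below is illustrative only; the client-side dtype is an assumption (np.array of Python floats defaults to float64).

# Sketch of the base64 round trip for encoding_format="base64".
import base64
import numpy as np

def encode_embedding(embedding: list) -> str:
    # Mirrors the server: raw bytes in native order, then base64 text.
    return base64.b64encode(np.array(embedding).tobytes()).decode("utf-8")

def decode_embedding(data: str, dtype=np.float64) -> np.ndarray:
    # Client-side inverse; pass the dtype the server actually used.
    return np.frombuffer(base64.b64decode(data), dtype=dtype)

vec = [0.25, -1.5, 3.0]
assert np.allclose(decode_embedding(encode_embedding(vec)), vec)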
serving_engine.py
CHANGED
@@ -1,65 +1,108 @@
1   import json
2   from dataclasses import dataclass
3   from http import HTTPStatus
4 - from typing import
5
6   from pydantic import Field
7   from typing_extensions import Annotated
8
9   from vllm.config import ModelConfig
10  from vllm.engine.async_llm_engine import AsyncLLMEngine
11  from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
12                                                CompletionRequest,
13                                                DetokenizeRequest,
14                                                EmbeddingRequest, ErrorResponse,
15                                                ModelCard, ModelList,
16 -                                              ModelPermission,
17  from vllm.logger import init_logger
18  from vllm.lora.request import LoRARequest
19  from vllm.sequence import Logprob
20 - from vllm.transformers_utils.tokenizer import get_tokenizer
21
22  logger = init_logger(__name__)
23
24
25  @dataclass
26 - class
27      name: str
28      local_path: str
29
30
31  class OpenAIServing:
32
33 -     def __init__(
34 -
35 -
36          super().__init__()
37
38          self.engine = engine
39          self.model_config = model_config
40          self.max_model_len = model_config.max_model_len
41
42 -        # A separate tokenizer to map token IDs to strings.
43 -        self.tokenizer = get_tokenizer(
44 -            model_config.tokenizer,
45 -            tokenizer_mode=model_config.tokenizer_mode,
46 -            tokenizer_revision=model_config.tokenizer_revision,
47 -            trust_remote_code=model_config.trust_remote_code,
48 -            truncation_side="left")
49 -
50          self.served_model_names = served_model_names
51
52 -
53 -
54 -        else:
55              self.lora_requests = [
56                  LoRARequest(
57                      lora_name=lora.name,
58                      lora_int_id=i,
59 -
60                  ) for i, lora in enumerate(lora_modules, start=1)
61              ]
62
63      async def show_available_models(self) -> ModelList:
64          """Show available models. Right now we only have one model."""
65          model_cards = [
@@ -75,7 +118,14 @@ class OpenAIServing:
75              permission=[ModelPermission()])
76              for lora in self.lora_requests
77          ]
78          model_cards.extend(lora_cards)
79          return ModelList(data=model_cards)
80
81      def create_error_response(
@@ -101,71 +151,82 @@ class OpenAIServing:
101         return json_str
102
103     async def _check_model(
104 -       self,
105 -
106 -       TokenizeRequest]
107     ) -> Optional[ErrorResponse]:
108         if request.model in self.served_model_names:
109             return None
110         if request.model in [lora.lora_name for lora in self.lora_requests]:
111             return None
112         return self.create_error_response(
113             message=f"The model `{request.model}` does not exist.",
114             err_type="NotFoundError",
115             status_code=HTTPStatus.NOT_FOUND)
116
117 -   def
118 -       self, request:
119 -
120 -
121         if request.model in self.served_model_names:
122 -           return None
123         for lora in self.lora_requests:
124             if request.model == lora.lora_name:
125 -               return lora
126         # if _check_model has been called earlier, this will be unreachable
127         raise ValueError(f"The model `{request.model}` does not exist.")
128
129 -   def
130 -
131 -
132 -
133 -
134 -
135 -
136 -
137 -
138 -
139 -   ) -> Tuple[List[int], str]:
140 -       if not (prompt or prompt_ids):
141 -           raise ValueError("Either prompt or prompt_ids should be provided.")
142 -       if (prompt and prompt_ids):
143 -           raise ValueError(
144 -               "Only one of prompt or prompt_ids should be provided.")
145 -
146 -       if prompt_ids is None:
147 -           # When using OpenAIServingChat for chat completions, for
148 -           # most models the special tokens (e.g., BOS) have already
149 -           # been added by the chat template. Therefore, we do not
150 -           # need to add them again.
151 -           # Set add_special_tokens to False (by default) to avoid
152 -           # adding the BOS tokens again.
153 -           tokenizer_kwargs: Dict[str, Any] = {
154 -               "add_special_tokens": add_special_tokens
155 -           }
156 -           if truncate_prompt_tokens is not None:
157 -               tokenizer_kwargs.update({
158 -                   "truncation": True,
159 -                   "max_length": truncate_prompt_tokens,
160 -               })
161 -           input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
162 -       elif truncate_prompt_tokens is not None:
163 -           input_ids = prompt_ids[-truncate_prompt_tokens:]
164         else:
165             input_ids = prompt_ids
166
167 -
168 -
169         token_num = len(input_ids)
170
171         # Note: EmbeddingRequest doesn't have max_tokens
@@ -175,13 +236,16 @@ class OpenAIServing:
175             f"This model's maximum context length is "
176             f"{self.max_model_len} tokens. However, you requested "
177             f"{token_num} tokens in the input for embedding "
178 -           f"generation. Please reduce the length of the input."
179 -       return
180
181         # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
182         # and does not require model context length validation
183 -       if isinstance(request, (
184 -
185
186         if request.max_tokens is None:
187             if token_num >= self.max_model_len:
@@ -189,7 +253,7 @@ class OpenAIServing:
189                 f"This model's maximum context length is "
190                 f"{self.max_model_len} tokens. However, you requested "
191                 f"{token_num} tokens in the messages, "
192 -               f"Please reduce the length of the messages."
193             request.max_tokens = self.max_model_len - token_num
194
195         if token_num + request.max_tokens > self.max_model_len:
@@ -199,11 +263,132 @@ class OpenAIServing:
199             f"{request.max_tokens + token_num} tokens "
200             f"({token_num} in the messages, "
201             f"{request.max_tokens} in the completion). "
202 -           f"Please reduce the length of the messages or completion."
203         else:
204 -
205
206 -
207         if logprob.decoded_token is not None:
208             return logprob.decoded_token
209 -       return

1   import json
2 + import pathlib
3   from dataclasses import dataclass
4   from http import HTTPStatus
5 + from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
6
7   from pydantic import Field
8 + from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
9   from typing_extensions import Annotated
10
11  from vllm.config import ModelConfig
12  from vllm.engine.async_llm_engine import AsyncLLMEngine
13 + from vllm.entrypoints.logger import RequestLogger
14 + # yapf conflicts with isort for this block
15 + # yapf: disable
16  from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
17                                                CompletionRequest,
18                                                DetokenizeRequest,
19                                                EmbeddingRequest, ErrorResponse,
20                                                ModelCard, ModelList,
21 +                                              ModelPermission,
22 +                                              TokenizeChatRequest,
23 +                                              TokenizeCompletionRequest,
24 +                                              TokenizeRequest)
25 + # yapf: enable
26 + from vllm.inputs import parse_and_batch_prompt
27  from vllm.logger import init_logger
28  from vllm.lora.request import LoRARequest
29 + from vllm.pooling_params import PoolingParams
30 + from vllm.prompt_adapter.request import PromptAdapterRequest
31 + from vllm.sampling_params import SamplingParams
32  from vllm.sequence import Logprob
33
34  logger = init_logger(__name__)
35
36
37  @dataclass
38 + class PromptAdapterPath:
39      name: str
40      local_path: str
41
42
43 + @dataclass
44 + class LoRAModulePath:
45 +     name: str
46 +     path: str
47 +
48 +
49 + AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
50 +                    EmbeddingRequest, TokenizeRequest]
51 +
52 + AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
53 +
54 +
55 + class TextTokensPrompt(TypedDict):
56 +     prompt: str
57 +     prompt_token_ids: List[int]
58 +
59 +
60  class OpenAIServing:
61
62 +    def __init__(
63 +        self,
64 +        engine: AsyncLLMEngine,
65 +        model_config: ModelConfig,
66 +        served_model_names: List[str],
67 +        *,
68 +        lora_modules: Optional[List[LoRAModulePath]],
69 +        prompt_adapters: Optional[List[PromptAdapterPath]],
70 +        request_logger: Optional[RequestLogger],
71 +    ):
72          super().__init__()
73
74          self.engine = engine
75          self.model_config = model_config
76          self.max_model_len = model_config.max_model_len
77
78          self.served_model_names = served_model_names
79
80 +        self.lora_requests = []
81 +        if lora_modules is not None:
82              self.lora_requests = [
83                  LoRARequest(
84                      lora_name=lora.name,
85                      lora_int_id=i,
86 +                    lora_path=lora.path,
87                  ) for i, lora in enumerate(lora_modules, start=1)
88              ]
89
90 +        self.prompt_adapter_requests = []
91 +        if prompt_adapters is not None:
92 +            for i, prompt_adapter in enumerate(prompt_adapters, start=1):
93 +                with pathlib.Path(prompt_adapter.local_path,
94 +                                  "adapter_config.json").open() as f:
95 +                    adapter_config = json.load(f)
96 +                    num_virtual_tokens = adapter_config["num_virtual_tokens"]
97 +                    self.prompt_adapter_requests.append(
98 +                        PromptAdapterRequest(
99 +                            prompt_adapter_name=prompt_adapter.name,
100 +                           prompt_adapter_id=i,
101 +                           prompt_adapter_local_path=prompt_adapter.local_path,
102 +                           prompt_adapter_num_virtual_tokens=num_virtual_tokens))
103 +
104 +       self.request_logger = request_logger
105 +
106     async def show_available_models(self) -> ModelList:
107         """Show available models. Right now we only have one model."""
108         model_cards = [

118             permission=[ModelPermission()])
119             for lora in self.lora_requests
120         ]
121 +       prompt_adapter_cards = [
122 +           ModelCard(id=prompt_adapter.prompt_adapter_name,
123 +                     root=self.served_model_names[0],
124 +                     permission=[ModelPermission()])
125 +           for prompt_adapter in self.prompt_adapter_requests
126 +       ]
127         model_cards.extend(lora_cards)
128 +       model_cards.extend(prompt_adapter_cards)
129         return ModelList(data=model_cards)
130
131     def create_error_response(

151         return json_str
152
153     async def _check_model(
154 +       self,
155 +       request: AnyRequest,
156     ) -> Optional[ErrorResponse]:
157         if request.model in self.served_model_names:
158             return None
159         if request.model in [lora.lora_name for lora in self.lora_requests]:
160             return None
161 +       if request.model in [
162 +               prompt_adapter.prompt_adapter_name
163 +               for prompt_adapter in self.prompt_adapter_requests
164 +       ]:
165 +           return None
166         return self.create_error_response(
167             message=f"The model `{request.model}` does not exist.",
168             err_type="NotFoundError",
169             status_code=HTTPStatus.NOT_FOUND)
170
171 +   def _maybe_get_adapters(
172 +       self, request: AnyRequest
173 +   ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
174 +           None, PromptAdapterRequest]]:
175         if request.model in self.served_model_names:
176 +           return None, None
177         for lora in self.lora_requests:
178             if request.model == lora.lora_name:
179 +               return lora, None
180 +       for prompt_adapter in self.prompt_adapter_requests:
181 +           if request.model == prompt_adapter.prompt_adapter_name:
182 +               return None, prompt_adapter
183         # if _check_model has been called earlier, this will be unreachable
184         raise ValueError(f"The model `{request.model}` does not exist.")
185
186 +   def _normalize_prompt_text_to_input(
187 +       self,
188 +       request: AnyRequest,
189 +       tokenizer: AnyTokenizer,
190 +       prompt: str,
191 +       truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
192 +       add_special_tokens: bool,
193 +   ) -> TextTokensPrompt:
194 +       if truncate_prompt_tokens is None:
195 +           encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
196         else:
197 +           encoded = tokenizer(prompt,
198 +                               add_special_tokens=add_special_tokens,
199 +                               truncation=True,
200 +                               max_length=truncate_prompt_tokens)
201 +
202 +       input_ids = encoded.input_ids
203 +
204 +       input_text = prompt
205 +
206 +       return self._validate_input(request, input_ids, input_text)
207 +
208 +   def _normalize_prompt_tokens_to_input(
209 +       self,
210 +       request: AnyRequest,
211 +       tokenizer: AnyTokenizer,
212 +       prompt_ids: List[int],
213 +       truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
214 +   ) -> TextTokensPrompt:
215 +       if truncate_prompt_tokens is None:
216             input_ids = prompt_ids
217 +       else:
218 +           input_ids = prompt_ids[-truncate_prompt_tokens:]
219 +
220 +       input_text = tokenizer.decode(input_ids)
221
222 +       return self._validate_input(request, input_ids, input_text)
223 +
224 +   def _validate_input(
225 +       self,
226 +       request: AnyRequest,
227 +       input_ids: List[int],
228 +       input_text: str,
229 +   ) -> TextTokensPrompt:
230         token_num = len(input_ids)
231
232         # Note: EmbeddingRequest doesn't have max_tokens

236             f"This model's maximum context length is "
237             f"{self.max_model_len} tokens. However, you requested "
238             f"{token_num} tokens in the input for embedding "
239 +           f"generation. Please reduce the length of the input.")
240 +       return TextTokensPrompt(prompt=input_text,
241 +                               prompt_token_ids=input_ids)
242
243         # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
244         # and does not require model context length validation
245 +       if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
246 +                               DetokenizeRequest)):
247 +           return TextTokensPrompt(prompt=input_text,
248 +                                   prompt_token_ids=input_ids)
249
250         if request.max_tokens is None:
251             if token_num >= self.max_model_len:

253                 f"This model's maximum context length is "
254                 f"{self.max_model_len} tokens. However, you requested "
255                 f"{token_num} tokens in the messages, "
256 +               f"Please reduce the length of the messages.")
257             request.max_tokens = self.max_model_len - token_num
258
259         if token_num + request.max_tokens > self.max_model_len:

263             f"{request.max_tokens + token_num} tokens "
264             f"({token_num} in the messages, "
265             f"{request.max_tokens} in the completion). "
266 +           f"Please reduce the length of the messages or completion.")
267 +
268 +       return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
269 +
270 +   def _tokenize_prompt_input(
271 +       self,
272 +       request: AnyRequest,
273 +       tokenizer: AnyTokenizer,
274 +       prompt_input: Union[str, List[int]],
275 +       truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
276 +       add_special_tokens: bool = True,
277 +   ) -> TextTokensPrompt:
278 +       """
279 +       A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
280 +       that assumes single input.
281 +       """
282 +       return next(
283 +           self._tokenize_prompt_inputs(
284 +               request,
285 +               tokenizer,
286 +               [prompt_input],
287 +               truncate_prompt_tokens=truncate_prompt_tokens,
288 +               add_special_tokens=add_special_tokens,
289 +           ))
290 +
291 +   def _tokenize_prompt_inputs(
292 +       self,
293 +       request: AnyRequest,
294 +       tokenizer: AnyTokenizer,
295 +       prompt_inputs: Iterable[Union[str, List[int]]],
296 +       truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
297 +       add_special_tokens: bool = True,
298 +   ) -> Iterator[TextTokensPrompt]:
299 +       """
300 +       A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
301 +       that assumes multiple inputs.
302 +       """
303 +       for text in prompt_inputs:
304 +           if isinstance(text, str):
305 +               yield self._normalize_prompt_text_to_input(
306 +                   request,
307 +                   tokenizer,
308 +                   prompt=text,
309 +                   truncate_prompt_tokens=truncate_prompt_tokens,
310 +                   add_special_tokens=add_special_tokens,
311 +               )
312 +           else:
313 +               yield self._normalize_prompt_tokens_to_input(
314 +                   request,
315 +                   tokenizer,
316 +                   prompt_ids=text,
317 +                   truncate_prompt_tokens=truncate_prompt_tokens,
318 +               )
319 +
320 +   def _tokenize_prompt_input_or_inputs(
321 +       self,
322 +       request: AnyRequest,
323 +       tokenizer: AnyTokenizer,
324 +       input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
325 +       truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
326 +       add_special_tokens: bool = True,
327 +   ) -> Iterator[TextTokensPrompt]:
328 +       """
329 +       Tokenize/detokenize depending on the input format.
330 +
331 +       According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
332 +       , each input can be a string or array of tokens. Note that each request
333 +       can pass one or more inputs.
334 +       """
335 +       for prompt_input in parse_and_batch_prompt(input_or_inputs):
336 +           # Although our type checking is based on mypy,
337 +           # VSCode Pyright extension should still work properly
338 +           # "is True" is required for Pyright to perform type narrowing
339 +           # See: https://github.com/microsoft/pyright/issues/7672
340 +           if prompt_input["is_tokens"] is False:
341 +               yield self._normalize_prompt_text_to_input(
342 +                   request,
343 +                   tokenizer,
344 +                   prompt=prompt_input["content"],
345 +                   truncate_prompt_tokens=truncate_prompt_tokens,
346 +                   add_special_tokens=add_special_tokens,
347 +               )
348 +           else:
349 +               yield self._normalize_prompt_tokens_to_input(
350 +                   request,
351 +                   tokenizer,
352 +                   prompt_ids=prompt_input["content"],
353 +                   truncate_prompt_tokens=truncate_prompt_tokens,
354 +               )
355 +
356 +   def _log_inputs(
357 +       self,
358 +       request_id: str,
359 +       inputs: Union[str, List[int], TextTokensPrompt],
360 +       params: Optional[Union[SamplingParams, PoolingParams]],
361 +       lora_request: Optional[LoRARequest],
362 +       prompt_adapter_request: Optional[PromptAdapterRequest],
363 +   ) -> None:
364 +       if self.request_logger is None:
365 +           return
366 +
367 +       if isinstance(inputs, str):
368 +           prompt = inputs
369 +           prompt_token_ids = None
370 +       elif isinstance(inputs, list):
371 +           prompt = None
372 +           prompt_token_ids = inputs
373         else:
374 +           prompt = inputs["prompt"]
375 +           prompt_token_ids = inputs["prompt_token_ids"]
376 +
377 +       self.request_logger.log_inputs(
378 +           request_id,
379 +           prompt,
380 +           prompt_token_ids,
381 +           params=params,
382 +           lora_request=lora_request,
383 +           prompt_adapter_request=prompt_adapter_request,
384 +       )
385
386 +   @staticmethod
387 +   def _get_decoded_token(
388 +       logprob: Logprob,
389 +       token_id: int,
390 +       tokenizer: AnyTokenizer,
391 +   ) -> str:
392         if logprob.decoded_token is not None:
393             return logprob.decoded_token
394 +       return tokenizer.decode(token_id)
|
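The new _normalize_prompt_* helpers above implement truncate_prompt_tokens in two ways: text prompts are truncated by the tokenizer while encoding, token-ID prompts keep only the last N IDs. The sketch below condenses that into one hypothetical function (truncate_prompt is not part of the diff); the tokenizer argument is assumed to follow the HuggingFace call convention used above.

# Minimal sketch of the two truncation paths, under the assumptions above.
from typing import List, Optional, Union

def truncate_prompt(prompt: Union[str, List[int]],
                    tokenizer,
                    truncate_prompt_tokens: Optional[int],
                    add_special_tokens: bool = True) -> List[int]:
    if isinstance(prompt, str):
        # Text input: let the tokenizer truncate while it encodes,
        # as _normalize_prompt_text_to_input does.
        if truncate_prompt_tokens is None:
            return tokenizer(prompt,
                             add_special_tokens=add_special_tokens).input_ids
        return tokenizer(prompt,
                         add_special_tokens=add_special_tokens,
                         truncation=True,
                         max_length=truncate_prompt_tokens).input_ids
    # Token-ID input: keep only the last N IDs,
    # as _normalize_prompt_tokens_to_input does.
    if truncate_prompt_tokens is None:
        return list(prompt)
    return list(prompt)[-truncate_prompt_tokens:]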
serving_tokenization.py
ADDED
@@ -0,0 +1,135 @@
1 + from typing import List, Optional, Union
2 +
3 + from vllm.config import ModelConfig
4 + from vllm.engine.async_llm_engine import AsyncLLMEngine
5 + # yapf conflicts with isort for this block
6 + # yapf: disable
7 + from vllm.entrypoints.chat_utils import (ConversationMessage,
8 +                                          load_chat_template,
9 +                                          parse_chat_message_content)
10 + from vllm.entrypoints.logger import RequestLogger
11 + from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
12 +                                               DetokenizeResponse,
13 +                                               ErrorResponse,
14 +                                               TokenizeChatRequest,
15 +                                               TokenizeRequest,
16 +                                               TokenizeResponse)
17 + # yapf: enable
18 + from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
19 +                                                     OpenAIServing)
20 + from vllm.utils import random_uuid
21 +
22 +
23 + class OpenAIServingTokenization(OpenAIServing):
24 +
25 +     def __init__(
26 +         self,
27 +         engine: AsyncLLMEngine,
28 +         model_config: ModelConfig,
29 +         served_model_names: List[str],
30 +         *,
31 +         lora_modules: Optional[List[LoRAModulePath]],
32 +         request_logger: Optional[RequestLogger],
33 +         chat_template: Optional[str],
34 +     ):
35 +         super().__init__(engine=engine,
36 +                          model_config=model_config,
37 +                          served_model_names=served_model_names,
38 +                          lora_modules=lora_modules,
39 +                          prompt_adapters=None,
40 +                          request_logger=request_logger)
41 +
42 +         # If this is None we use the tokenizer's default chat template
43 +         self.chat_template = load_chat_template(chat_template)
44 +
45 +     async def create_tokenize(
46 +         self,
47 +         request: TokenizeRequest,
48 +     ) -> Union[TokenizeResponse, ErrorResponse]:
49 +         error_check_ret = await self._check_model(request)
50 +         if error_check_ret is not None:
51 +             return error_check_ret
52 +
53 +         request_id = f"tokn-{random_uuid()}"
54 +
55 +         (
56 +             lora_request,
57 +             prompt_adapter_request,
58 +         ) = self._maybe_get_adapters(request)
59 +
60 +         tokenizer = await self.engine.get_tokenizer(lora_request)
61 +
62 +         if isinstance(request, TokenizeChatRequest):
63 +             model_config = self.model_config
64 +
65 +             conversation: List[ConversationMessage] = []
66 +
67 +             for message in request.messages:
68 +                 result = parse_chat_message_content(message, model_config,
69 +                                                     tokenizer)
70 +                 conversation.extend(result.messages)
71 +
72 +             prompt = tokenizer.apply_chat_template(
73 +                 add_generation_prompt=request.add_generation_prompt,
74 +                 conversation=conversation,
75 +                 tokenize=False,
76 +                 chat_template=self.chat_template)
77 +             assert isinstance(prompt, str)
78 +         else:
79 +             prompt = request.prompt
80 +
81 +         self._log_inputs(request_id,
82 +                          prompt,
83 +                          params=None,
84 +                          lora_request=lora_request,
85 +                          prompt_adapter_request=prompt_adapter_request)
86 +
87 +         # Silently ignore prompt adapter since it does not affect tokenization
88 +
89 +         prompt_input = self._tokenize_prompt_input(
90 +             request,
91 +             tokenizer,
92 +             prompt,
93 +             add_special_tokens=request.add_special_tokens,
94 +         )
95 +         input_ids = prompt_input["prompt_token_ids"]
96 +
97 +         return TokenizeResponse(tokens=input_ids,
98 +                                 count=len(input_ids),
99 +                                 max_model_len=self.max_model_len)
100 +
101 +    async def create_detokenize(
102 +        self,
103 +        request: DetokenizeRequest,
104 +    ) -> Union[DetokenizeResponse, ErrorResponse]:
105 +        error_check_ret = await self._check_model(request)
106 +        if error_check_ret is not None:
107 +            return error_check_ret
108 +
109 +        request_id = f"tokn-{random_uuid()}"
110 +
111 +        (
112 +            lora_request,
113 +            prompt_adapter_request,
114 +        ) = self._maybe_get_adapters(request)
115 +
116 +        tokenizer = await self.engine.get_tokenizer(lora_request)
117 +
118 +        self._log_inputs(request_id,
119 +                         request.tokens,
120 +                         params=None,
121 +                         lora_request=lora_request,
122 +                         prompt_adapter_request=prompt_adapter_request)
123 +
124 +        if prompt_adapter_request is not None:
125 +            raise NotImplementedError("Prompt adapter is not supported "
126 +                                      "for tokenization")
127 +
128 +        prompt_input = self._tokenize_prompt_input(
129 +            request,
130 +            tokenizer,
131 +            request.tokens,
132 +        )
133 +        input_text = prompt_input["prompt"]
134 +
135 +        return DetokenizeResponse(prompt=input_text)
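The new serving_tokenization.py backs the tokenize/detokenize endpoints registered in api_server.py. The client sketch below is hypothetical: the /tokenize and /detokenize routes and the localhost URL follow upstream vLLM conventions, and the model name is only a placeholder; the request and response fields match the TokenizeRequest/TokenizeResponse and DetokenizeRequest/DetokenizeResponse shapes shown above.

# Hypothetical client-side usage of the tokenization service.
import requests

base_url = "http://localhost:8000"  # assumed server address
model = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # placeholder model name

tok = requests.post(f"{base_url}/tokenize",
                    json={
                        "model": model,
                        "prompt": "Hello, world!",
                        "add_special_tokens": True,
                    }).json()
print(tok["tokens"], tok["count"], tok["max_model_len"])

detok = requests.post(f"{base_url}/detokenize",
                      json={
                          "model": model,
                          "tokens": tok["tokens"],
                      }).json()
print(detok["prompt"])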