sofianhw committed on
Commit
5491dc5
1 Parent(s): 8f99309

Change to vLLM 0.5.3 to support Llama 3.1

Files changed (8)
  1. Dockerfile +1 -1
  2. api_server.py +109 -43
  3. protocol.py +61 -85
  4. serving_chat.py +119 -220
  5. serving_completion.py +90 -98
  6. serving_embedding.py +54 -25
  7. serving_engine.py +259 -74
  8. serving_tokenization.py +135 -0
Dockerfile CHANGED
@@ -14,7 +14,7 @@ RUN pip3 install "torch==2.1.1"
 # This build is slow but NVIDIA does not provide binaries. Increase MAX_JOBS as needed.
 # RUN pip3 install "git+https://github.com/stanford-futuredata/megablocks.git"
 RUN pip3 install -U openai
-RUN pip3 install vllm==0.5.1
+RUN pip3 install vllm==0.5.3
 RUN pip3 install -U pydantic
 RUN pip3 install -U aioprometheus
 
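The only change here is the pin bump from vllm==0.5.1 to vllm==0.5.3, the release the commit message calls out for Llama 3.1 support. A minimal sanity check to run inside the built image (a hedged sketch; it assumes only that the pip install above succeeded):

# Hedged sketch: confirm the image picked up the intended vLLM release.
import vllm

assert vllm.__version__ == "0.5.3", f"unexpected vLLM version: {vllm.__version__}"
print("vLLM", vllm.__version__)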
api_server.py CHANGED
@@ -8,7 +8,7 @@ from typing import Optional, Set
 
 import fastapi
 import uvicorn
-from fastapi import Request
+from fastapi import APIRouter, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -18,6 +18,7 @@ from starlette.routing import Mount
 import vllm.envs as envs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import make_arg_parser
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -33,15 +34,21 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_tokenization import (
+    OpenAIServingTokenization)
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
+from vllm.utils import FlexibleArgumentParser
 from vllm.version import __version__ as VLLM_VERSION
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds
 
+engine: AsyncLLMEngine
+engine_args: AsyncEngineArgs
 openai_serving_chat: OpenAIServingChat
 openai_serving_completion: OpenAIServingCompletion
 openai_serving_embedding: OpenAIServingEmbedding
+openai_serving_tokenization: OpenAIServingTokenization
 
 logger = init_logger('vllm.entrypoints.openai.api_server')
 
@@ -64,37 +71,27 @@ async def lifespan(app: fastapi.FastAPI):
     yield
 
 
-app = fastapi.FastAPI(lifespan=lifespan)
+router = APIRouter()
 
 
-def parse_args():
-    parser = make_arg_parser()
-    return parser.parse_args()
+def mount_metrics(app: fastapi.FastAPI):
+    # Add prometheus asgi middleware to route /metrics requests
+    metrics_route = Mount("/metrics", make_asgi_app())
+    # Workaround for 307 Redirect for /metrics
+    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
+    app.routes.append(metrics_route)
 
 
-# Add prometheus asgi middleware to route /metrics requests
-route = Mount("/metrics", make_asgi_app())
-# Workaround for 307 Redirect for /metrics
-route.path_regex = re.compile('^/metrics(?P<path>.*)$')
-app.routes.append(route)
-
-
-@app.exception_handler(RequestValidationError)
-async def validation_exception_handler(_, exc):
-    err = openai_serving_chat.create_error_response(message=str(exc))
-    return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
-
-
-@app.get("/health")
+@router.get("/health")
 async def health() -> Response:
     """Health check."""
     await openai_serving_chat.engine.check_health()
     return Response(status_code=200)
 
 
-@app.post("/tokenize")
+@router.post("/tokenize")
 async def tokenize(request: TokenizeRequest):
-    generator = await openai_serving_completion.create_tokenize(request)
+    generator = await openai_serving_tokenization.create_tokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -103,9 +100,9 @@ async def tokenize(request: TokenizeRequest):
     return JSONResponse(content=generator.model_dump())
 
 
-@app.post("/detokenize")
+@router.post("/detokenize")
 async def detokenize(request: DetokenizeRequest):
-    generator = await openai_serving_completion.create_detokenize(request)
+    generator = await openai_serving_tokenization.create_detokenize(request)
    if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -114,19 +111,19 @@ async def detokenize(request: DetokenizeRequest):
     return JSONResponse(content=generator.model_dump())
 
 
-@app.get("/v1/models")
+@router.get("/v1/models")
 async def show_available_models():
-    models = await openai_serving_chat.show_available_models()
+    models = await openai_serving_completion.show_available_models()
     return JSONResponse(content=models.model_dump())
 
 
-@app.get("/version")
+@router.get("/version")
 async def show_version():
     ver = {"version": VLLM_VERSION}
     return JSONResponse(content=ver)
 
 
-@app.post("/v1/chat/completions")
+@router.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
     generator = await openai_serving_chat.create_chat_completion(
@@ -142,7 +139,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
     return JSONResponse(content=generator.model_dump())
 
 
-@app.post("/v1/completions")
+@router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
     generator = await openai_serving_completion.create_completion(
         request, raw_request)
@@ -156,7 +153,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     return JSONResponse(content=generator.model_dump())
 
 
-@app.post("/v1/embeddings")
+@router.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     generator = await openai_serving_embedding.create_embedding(
         request, raw_request)
@@ -167,8 +164,12 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     return JSONResponse(content=generator.model_dump())
 
 
-if __name__ == "__main__":
-    args = parse_args()
+def build_app(args):
+    app = fastapi.FastAPI(lifespan=lifespan)
+    app.include_router(router)
+    app.root_path = args.root_path
+
+    mount_metrics(app)
 
     app.add_middleware(
         CORSMiddleware,
@@ -178,6 +179,12 @@ if __name__ == "__main__":
         allow_headers=args.allowed_headers,
     )
 
+    @app.exception_handler(RequestValidationError)
+    async def validation_exception_handler(_, exc):
+        err = openai_serving_chat.create_error_response(message=str(exc))
+        return JSONResponse(err.model_dump(),
+                            status_code=HTTPStatus.BAD_REQUEST)
+
     if token := envs.VLLM_API_KEY or args.api_key:
 
         @app.middleware("http")
@@ -203,6 +210,12 @@ if __name__ == "__main__":
             raise ValueError(f"Invalid middleware {middleware}. "
                              f"Must be a function or a class.")
 
+    return app
+
+
+def run_server(args, llm_engine=None):
+    app = build_app(args)
+
     logger.info("vLLM API server version %s", VLLM_VERSION)
     logger.info("args: %s", args)
 
@@ -211,10 +224,12 @@ if __name__ == "__main__":
     else:
         served_model_names = [args.model]
 
-    engine_args = AsyncEngineArgs.from_cli_args(args)
+    global engine, engine_args
 
-    engine = AsyncLLMEngine.from_engine_args(
-        engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = (llm_engine
+              if llm_engine is not None else AsyncLLMEngine.from_engine_args(
+                  engine_args, usage_context=UsageContext.OPENAI_API_SERVER))
 
     event_loop: Optional[asyncio.AbstractEventLoop]
     try:
@@ -230,16 +245,57 @@ if __name__ == "__main__":
         # When using single vLLM without engine_use_ray
         model_config = asyncio.run(engine.get_model_config())
 
-    openai_serving_chat = OpenAIServingChat(engine, model_config,
-                                            served_model_names,
-                                            args.response_role,
-                                            args.lora_modules,
-                                            args.chat_template)
+    if args.disable_log_requests:
+        request_logger = None
+    else:
+        request_logger = RequestLogger(max_log_len=args.max_log_len)
+
+    global openai_serving_chat
+    global openai_serving_completion
+    global openai_serving_embedding
+    global openai_serving_tokenization
+
+    openai_serving_chat = OpenAIServingChat(
+        engine,
+        model_config,
+        served_model_names,
+        args.response_role,
+        lora_modules=args.lora_modules,
+        prompt_adapters=args.prompt_adapters,
+        request_logger=request_logger,
+        chat_template=args.chat_template,
+    )
     openai_serving_completion = OpenAIServingCompletion(
-        engine, model_config, served_model_names, args.lora_modules)
-    openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
-                                                      served_model_names)
+        engine,
+        model_config,
+        served_model_names,
+        lora_modules=args.lora_modules,
+        prompt_adapters=args.prompt_adapters,
+        request_logger=request_logger,
+    )
+    openai_serving_embedding = OpenAIServingEmbedding(
+        engine,
+        model_config,
+        served_model_names,
+        request_logger=request_logger,
+    )
+    openai_serving_tokenization = OpenAIServingTokenization(
+        engine,
+        model_config,
+        served_model_names,
+        lora_modules=args.lora_modules,
+        request_logger=request_logger,
+        chat_template=args.chat_template,
+    )
     app.root_path = args.root_path
+
+    logger.info("Available routes are:")
+    for route in app.routes:
+        if not hasattr(route, 'methods'):
+            continue
+        methods = ', '.join(route.methods)
+        logger.info("Route: %s, Methods: %s", route.path, methods)
+
     uvicorn.run(app,
                 host=args.host,
                 port=args.port,
@@ -248,4 +304,14 @@ if __name__ == "__main__":
                 ssl_keyfile=args.ssl_keyfile,
                 ssl_certfile=args.ssl_certfile,
                 ssl_ca_certs=args.ssl_ca_certs,
-                ssl_cert_reqs=args.ssl_cert_reqs)
+                ssl_cert_reqs=args.ssl_cert_reqs)
+
+
+if __name__ == "__main__":
+    # NOTE(simon):
+    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
+    parser = FlexibleArgumentParser(
+        description="vLLM OpenAI-Compatible RESTful API server.")
+    parser = make_arg_parser(parser)
+    args = parser.parse_args()
+    run_server(args)
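api_server.py now builds the FastAPI app through build_app()/run_server(), registers all endpoints on an APIRouter, and routes /tokenize and /detokenize through the new OpenAIServingTokenization handler with optional per-request logging. A minimal client sketch against the routes registered above; the host, port, and served model name are assumptions, not values from this commit:

# Hedged sketch: exercise the router-based endpoints once the server is running.
# Assumes run_server(args) was started on localhost:8000 serving the model below.
import requests

BASE = "http://localhost:8000"
MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # placeholder model name

# /health returns an empty 200 response while the engine is healthy.
requests.get(f"{BASE}/health").raise_for_status()

# /tokenize is now handled by OpenAIServingTokenization.
tok = requests.post(f"{BASE}/tokenize",
                    json={"model": MODEL, "prompt": "Hello Llama 3.1"}).json()
print(tok["count"], tok["tokens"])

# Standard OpenAI-style chat completion handled by OpenAIServingChat.
chat = requests.post(f"{BASE}/v1/chat/completions",
                     json={
                         "model": MODEL,
                         "messages": [{"role": "user", "content": "Say hi."}],
                         "max_tokens": 32,
                     }).json()
print(chat["choices"][0]["message"]["content"])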
protocol.py CHANGED
@@ -3,50 +3,16 @@
3
  import time
4
  from typing import Any, Dict, List, Literal, Optional, Union
5
 
6
- import openai.types.chat
7
  import torch
8
  from pydantic import BaseModel, ConfigDict, Field, model_validator
9
- # pydantic needs the TypedDict from typing_extensions
10
- from typing_extensions import Annotated, Required, TypedDict
11
 
 
12
  from vllm.pooling_params import PoolingParams
13
  from vllm.sampling_params import SamplingParams
14
  from vllm.utils import random_uuid
15
 
16
 
17
- class CustomChatCompletionContentPartParam(TypedDict, total=False):
18
- __pydantic_config__ = ConfigDict(extra="allow") # type: ignore
19
-
20
- type: Required[str]
21
- """The type of the content part."""
22
-
23
-
24
- ChatCompletionContentPartParam = Union[
25
- openai.types.chat.ChatCompletionContentPartParam,
26
- CustomChatCompletionContentPartParam]
27
-
28
-
29
- class CustomChatCompletionMessageParam(TypedDict, total=False):
30
- """Enables custom roles in the Chat Completion API."""
31
- role: Required[str]
32
- """The role of the message's author."""
33
-
34
- content: Union[str, List[ChatCompletionContentPartParam]]
35
- """The contents of the message."""
36
-
37
- name: str
38
- """An optional name for the participant.
39
-
40
- Provides the model information to differentiate between participants of the
41
- same role.
42
- """
43
-
44
-
45
- ChatCompletionMessageParam = Union[
46
- openai.types.chat.ChatCompletionMessageParam,
47
- CustomChatCompletionMessageParam]
48
-
49
-
50
  class OpenAIBaseModel(BaseModel):
51
  # OpenAI API does not allow extra fields
52
  model_config = ConfigDict(extra="forbid")
@@ -155,40 +121,42 @@ class ChatCompletionRequest(OpenAIBaseModel):
155
 
156
  # doc: begin-chat-completion-sampling-params
157
  best_of: Optional[int] = None
158
- use_beam_search: Optional[bool] = False
159
- top_k: Optional[int] = -1
160
- min_p: Optional[float] = 0.0
161
- repetition_penalty: Optional[float] = 1.0
162
- length_penalty: Optional[float] = 1.0
163
- early_stopping: Optional[bool] = False
164
- ignore_eos: Optional[bool] = False
165
- min_tokens: Optional[int] = 0
166
  stop_token_ids: Optional[List[int]] = Field(default_factory=list)
167
- skip_special_tokens: Optional[bool] = True
168
- spaces_between_special_tokens: Optional[bool] = True
 
 
 
 
169
  # doc: end-chat-completion-sampling-params
170
 
171
  # doc: begin-chat-completion-extra-params
172
- echo: Optional[bool] = Field(
173
  default=False,
174
  description=(
175
  "If true, the new message will be prepended with the last message "
176
  "if they belong to the same role."),
177
  )
178
- add_generation_prompt: Optional[bool] = Field(
179
  default=True,
180
  description=
181
  ("If true, the generation prompt will be added to the chat template. "
182
  "This is a parameter used by chat template in tokenizer config of the "
183
  "model."),
184
  )
185
- add_special_tokens: Optional[bool] = Field(
186
  default=False,
187
  description=(
188
  "If true, special tokens (e.g. BOS) will be added to the prompt "
189
  "on top of what is added by the chat template. "
190
  "For most models, the chat template takes care of adding the "
191
- "special tokens so this should be set to False (as is the "
192
  "default)."),
193
  )
194
  documents: Optional[List[Dict[str, str]]] = Field(
@@ -212,12 +180,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
212
  description=("Additional kwargs to pass to the template renderer. "
213
  "Will be accessible by the chat template."),
214
  )
215
- include_stop_str_in_output: Optional[bool] = Field(
216
- default=False,
217
- description=(
218
- "Whether to include the stop string in the output. "
219
- "This is only applied when the stop or stop_token_ids is set."),
220
- )
221
  guided_json: Optional[Union[str, dict, BaseModel]] = Field(
222
  default=None,
223
  description=("If specified, the output will follow the JSON schema."),
@@ -278,22 +240,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
278
 
279
  return SamplingParams(
280
  n=self.n,
 
281
  presence_penalty=self.presence_penalty,
282
  frequency_penalty=self.frequency_penalty,
283
  repetition_penalty=self.repetition_penalty,
284
  temperature=self.temperature,
285
  top_p=self.top_p,
 
286
  min_p=self.min_p,
287
  seed=self.seed,
288
  stop=self.stop,
289
  stop_token_ids=self.stop_token_ids,
290
- max_tokens=self.max_tokens,
291
- min_tokens=self.min_tokens,
292
  logprobs=self.top_logprobs if self.logprobs else None,
293
  prompt_logprobs=self.top_logprobs if self.echo else None,
294
- best_of=self.best_of,
295
- top_k=self.top_k,
296
  ignore_eos=self.ignore_eos,
 
 
297
  use_beam_search=self.use_beam_search,
298
  early_stopping=self.early_stopping,
299
  skip_special_tokens=self.skip_special_tokens,
@@ -301,6 +263,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
301
  include_stop_str_in_output=self.include_stop_str_in_output,
302
  length_penalty=self.length_penalty,
303
  logits_processors=logits_processors,
 
304
  )
305
 
306
  @model_validator(mode='before')
@@ -382,26 +345,27 @@ class CompletionRequest(OpenAIBaseModel):
382
  user: Optional[str] = None
383
 
384
  # doc: begin-completion-sampling-params
385
- use_beam_search: Optional[bool] = False
386
- top_k: Optional[int] = -1
387
- min_p: Optional[float] = 0.0
388
- repetition_penalty: Optional[float] = 1.0
389
- length_penalty: Optional[float] = 1.0
390
- early_stopping: Optional[bool] = False
391
  stop_token_ids: Optional[List[int]] = Field(default_factory=list)
392
- ignore_eos: Optional[bool] = False
393
- min_tokens: Optional[int] = 0
394
- skip_special_tokens: Optional[bool] = True
395
- spaces_between_special_tokens: Optional[bool] = True
 
396
  truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
397
  # doc: end-completion-sampling-params
398
 
399
  # doc: begin-completion-extra-params
400
- include_stop_str_in_output: Optional[bool] = Field(
401
- default=False,
402
  description=(
403
- "Whether to include the stop string in the output. "
404
- "This is only applied when the stop or stop_token_ids is set."),
405
  )
406
  response_format: Optional[ResponseFormat] = Field(
407
  default=None,
@@ -481,15 +445,15 @@ class CompletionRequest(OpenAIBaseModel):
481
  seed=self.seed,
482
  stop=self.stop,
483
  stop_token_ids=self.stop_token_ids,
 
484
  ignore_eos=self.ignore_eos,
485
  max_tokens=self.max_tokens if not echo_without_generation else 1,
486
  min_tokens=self.min_tokens,
487
- logprobs=self.logprobs,
488
  use_beam_search=self.use_beam_search,
489
  early_stopping=self.early_stopping,
490
  prompt_logprobs=self.logprobs if self.echo else None,
491
  skip_special_tokens=self.skip_special_tokens,
492
- spaces_between_special_tokens=(self.spaces_between_special_tokens),
493
  include_stop_str_in_output=self.include_stop_str_in_output,
494
  length_penalty=self.length_penalty,
495
  logits_processors=logits_processors,
@@ -523,11 +487,11 @@ class CompletionRequest(OpenAIBaseModel):
523
  def validate_stream_options(cls, data):
524
  if data.get("stream_options") and not data.get("stream"):
525
  raise ValueError(
526
- "Stream options can only be defined when stream is True.")
527
  return data
528
 
529
 
530
- class EmbeddingRequest(BaseModel):
531
  # Ordered by official OpenAI API documentation
532
  # https://platform.openai.com/docs/api-reference/embeddings
533
  model: str
@@ -599,13 +563,13 @@ class CompletionStreamResponse(OpenAIBaseModel):
599
  usage: Optional[UsageInfo] = Field(default=None)
600
 
601
 
602
- class EmbeddingResponseData(BaseModel):
603
  index: int
604
  object: str = "embedding"
605
  embedding: Union[List[float], str]
606
 
607
 
608
- class EmbeddingResponse(BaseModel):
609
  id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
610
  object: str = "list"
611
  created: int = Field(default_factory=lambda: int(time.time()))
@@ -704,8 +668,8 @@ class BatchRequestInput(OpenAIBaseModel):
704
  # /v1/chat/completions is supported.
705
  url: str
706
 
707
- # The parameteters of the request.
708
- body: Union[ChatCompletionRequest, ]
709
 
710
 
711
  class BatchResponseData(OpenAIBaseModel):
@@ -716,7 +680,7 @@ class BatchResponseData(OpenAIBaseModel):
716
  request_id: str
717
 
718
  # The body of the response.
719
- body: Union[ChatCompletionResponse, ]
720
 
721
 
722
  class BatchRequestOutput(OpenAIBaseModel):
@@ -737,16 +701,28 @@ class BatchRequestOutput(OpenAIBaseModel):
737
  error: Optional[Any]
738
 
739
 
740
- class TokenizeRequest(OpenAIBaseModel):
741
  model: str
742
  prompt: str
 
743
  add_special_tokens: bool = Field(default=True)
744
 
745
 
 
 
 
746
  class TokenizeResponse(OpenAIBaseModel):
747
- tokens: List[int]
748
  count: int
749
  max_model_len: int
 
750
 
751
 
752
  class DetokenizeRequest(OpenAIBaseModel):
 
3
  import time
4
  from typing import Any, Dict, List, Literal, Optional, Union
5
 
 
6
  import torch
7
  from pydantic import BaseModel, ConfigDict, Field, model_validator
8
+ from typing_extensions import Annotated
 
9
 
10
+ from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
11
  from vllm.pooling_params import PoolingParams
12
  from vllm.sampling_params import SamplingParams
13
  from vllm.utils import random_uuid
14
 
15
 
 
 
 
16
  class OpenAIBaseModel(BaseModel):
17
  # OpenAI API does not allow extra fields
18
  model_config = ConfigDict(extra="forbid")
 
121
 
122
  # doc: begin-chat-completion-sampling-params
123
  best_of: Optional[int] = None
124
+ use_beam_search: bool = False
125
+ top_k: int = -1
126
+ min_p: float = 0.0
127
+ repetition_penalty: float = 1.0
128
+ length_penalty: float = 1.0
129
+ early_stopping: bool = False
 
 
130
  stop_token_ids: Optional[List[int]] = Field(default_factory=list)
131
+ include_stop_str_in_output: bool = False
132
+ ignore_eos: bool = False
133
+ min_tokens: int = 0
134
+ skip_special_tokens: bool = True
135
+ spaces_between_special_tokens: bool = True
136
+ truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
137
  # doc: end-chat-completion-sampling-params
138
 
139
  # doc: begin-chat-completion-extra-params
140
+ echo: bool = Field(
141
  default=False,
142
  description=(
143
  "If true, the new message will be prepended with the last message "
144
  "if they belong to the same role."),
145
  )
146
+ add_generation_prompt: bool = Field(
147
  default=True,
148
  description=
149
  ("If true, the generation prompt will be added to the chat template. "
150
  "This is a parameter used by chat template in tokenizer config of the "
151
  "model."),
152
  )
153
+ add_special_tokens: bool = Field(
154
  default=False,
155
  description=(
156
  "If true, special tokens (e.g. BOS) will be added to the prompt "
157
  "on top of what is added by the chat template. "
158
  "For most models, the chat template takes care of adding the "
159
+ "special tokens so this should be set to false (as is the "
160
  "default)."),
161
  )
162
  documents: Optional[List[Dict[str, str]]] = Field(
 
180
  description=("Additional kwargs to pass to the template renderer. "
181
  "Will be accessible by the chat template."),
182
  )
 
 
 
 
 
 
183
  guided_json: Optional[Union[str, dict, BaseModel]] = Field(
184
  default=None,
185
  description=("If specified, the output will follow the JSON schema."),
 
240
 
241
  return SamplingParams(
242
  n=self.n,
243
+ best_of=self.best_of,
244
  presence_penalty=self.presence_penalty,
245
  frequency_penalty=self.frequency_penalty,
246
  repetition_penalty=self.repetition_penalty,
247
  temperature=self.temperature,
248
  top_p=self.top_p,
249
+ top_k=self.top_k,
250
  min_p=self.min_p,
251
  seed=self.seed,
252
  stop=self.stop,
253
  stop_token_ids=self.stop_token_ids,
 
 
254
  logprobs=self.top_logprobs if self.logprobs else None,
255
  prompt_logprobs=self.top_logprobs if self.echo else None,
 
 
256
  ignore_eos=self.ignore_eos,
257
+ max_tokens=self.max_tokens,
258
+ min_tokens=self.min_tokens,
259
  use_beam_search=self.use_beam_search,
260
  early_stopping=self.early_stopping,
261
  skip_special_tokens=self.skip_special_tokens,
 
263
  include_stop_str_in_output=self.include_stop_str_in_output,
264
  length_penalty=self.length_penalty,
265
  logits_processors=logits_processors,
266
+ truncate_prompt_tokens=self.truncate_prompt_tokens,
267
  )
268
 
269
  @model_validator(mode='before')
 
345
  user: Optional[str] = None
346
 
347
  # doc: begin-completion-sampling-params
348
+ use_beam_search: bool = False
349
+ top_k: int = -1
350
+ min_p: float = 0.0
351
+ repetition_penalty: float = 1.0
352
+ length_penalty: float = 1.0
353
+ early_stopping: bool = False
354
  stop_token_ids: Optional[List[int]] = Field(default_factory=list)
355
+ include_stop_str_in_output: bool = False
356
+ ignore_eos: bool = False
357
+ min_tokens: int = 0
358
+ skip_special_tokens: bool = True
359
+ spaces_between_special_tokens: bool = True
360
  truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
361
  # doc: end-completion-sampling-params
362
 
363
  # doc: begin-completion-extra-params
364
+ add_special_tokens: bool = Field(
365
+ default=True,
366
  description=(
367
+ "If true (the default), special tokens (e.g. BOS) will be added to "
368
+ "the prompt."),
369
  )
370
  response_format: Optional[ResponseFormat] = Field(
371
  default=None,
 
445
  seed=self.seed,
446
  stop=self.stop,
447
  stop_token_ids=self.stop_token_ids,
448
+ logprobs=self.logprobs,
449
  ignore_eos=self.ignore_eos,
450
  max_tokens=self.max_tokens if not echo_without_generation else 1,
451
  min_tokens=self.min_tokens,
 
452
  use_beam_search=self.use_beam_search,
453
  early_stopping=self.early_stopping,
454
  prompt_logprobs=self.logprobs if self.echo else None,
455
  skip_special_tokens=self.skip_special_tokens,
456
+ spaces_between_special_tokens=self.spaces_between_special_tokens,
457
  include_stop_str_in_output=self.include_stop_str_in_output,
458
  length_penalty=self.length_penalty,
459
  logits_processors=logits_processors,
 
487
  def validate_stream_options(cls, data):
488
  if data.get("stream_options") and not data.get("stream"):
489
  raise ValueError(
490
+ "Stream options can only be defined when stream is true.")
491
  return data
492
 
493
 
494
+ class EmbeddingRequest(OpenAIBaseModel):
495
  # Ordered by official OpenAI API documentation
496
  # https://platform.openai.com/docs/api-reference/embeddings
497
  model: str
 
563
  usage: Optional[UsageInfo] = Field(default=None)
564
 
565
 
566
+ class EmbeddingResponseData(OpenAIBaseModel):
567
  index: int
568
  object: str = "embedding"
569
  embedding: Union[List[float], str]
570
 
571
 
572
+ class EmbeddingResponse(OpenAIBaseModel):
573
  id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
574
  object: str = "list"
575
  created: int = Field(default_factory=lambda: int(time.time()))
 
668
  # /v1/chat/completions is supported.
669
  url: str
670
 
671
+ # The parameters of the request.
672
+ body: ChatCompletionRequest
673
 
674
 
675
  class BatchResponseData(OpenAIBaseModel):
 
680
  request_id: str
681
 
682
  # The body of the response.
683
+ body: Optional[ChatCompletionResponse] = None
684
 
685
 
686
  class BatchRequestOutput(OpenAIBaseModel):
 
701
  error: Optional[Any]
702
 
703
 
704
+ class TokenizeCompletionRequest(OpenAIBaseModel):
705
  model: str
706
  prompt: str
707
+
708
  add_special_tokens: bool = Field(default=True)
709
 
710
 
711
+ class TokenizeChatRequest(OpenAIBaseModel):
712
+ model: str
713
+ messages: List[ChatCompletionMessageParam]
714
+
715
+ add_generation_prompt: bool = Field(default=True)
716
+ add_special_tokens: bool = Field(default=False)
717
+
718
+
719
+ TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
720
+
721
+
722
  class TokenizeResponse(OpenAIBaseModel):
 
723
  count: int
724
  max_model_len: int
725
+ tokens: List[int]
726
 
727
 
728
  class DetokenizeRequest(OpenAIBaseModel):
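TokenizeRequest is now a union of two shapes: TokenizeCompletionRequest for a raw prompt and TokenizeChatRequest for a list of chat messages, so /tokenize can count tokens after chat-template rendering. A sketch of both request bodies; the import path follows upstream vLLM's package layout and the model name is a placeholder, so adjust both if this repo keeps the module at the top level:

# Hedged sketch of the two payload shapes accepted by /tokenize after this change.
from vllm.entrypoints.openai.protocol import (TokenizeChatRequest,
                                              TokenizeCompletionRequest)

MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # placeholder model name

# Completion-style: tokenize a raw prompt (special tokens added by default).
plain = TokenizeCompletionRequest(model=MODEL, prompt="Hello Llama 3.1")

# Chat-style: tokenize the rendered chat template instead of a raw string.
chat = TokenizeChatRequest(
    model=MODEL,
    messages=[{"role": "user", "content": "Hello!"}],
    add_generation_prompt=True,
)

print(plain.model_dump_json())
print(chat.model_dump_json())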
serving_chat.py CHANGED
@@ -1,34 +1,33 @@
1
- import codecs
2
  import time
3
- from dataclasses import dataclass, field
4
- from functools import cached_property
5
- from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable,
6
- List, Optional)
7
  from typing import Sequence as GenericSequence
8
- from typing import TypedDict, Union, cast, final
9
 
10
  from fastapi import Request
11
- from openai.types.chat import (ChatCompletionContentPartImageParam,
12
- ChatCompletionContentPartTextParam)
13
 
14
  from vllm.config import ModelConfig
15
  from vllm.engine.async_llm_engine import AsyncLLMEngine
 
 
 
 
16
  from vllm.entrypoints.openai.protocol import (
17
- ChatCompletionContentPartParam, ChatCompletionLogProb,
18
- ChatCompletionLogProbs, ChatCompletionLogProbsContent,
19
- ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam,
20
  ChatCompletionRequest, ChatCompletionResponse,
21
  ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
22
  ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
23
  FunctionCall, ToolCall, UsageInfo)
24
  from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
25
- OpenAIServing)
 
26
  from vllm.inputs import PromptInputs
27
  from vllm.logger import init_logger
28
  from vllm.model_executor.guided_decoding import (
29
  get_guided_decoding_logits_processor)
30
  from vllm.multimodal import MultiModalDataDict
31
- from vllm.multimodal.utils import async_get_and_parse_image
32
  from vllm.outputs import RequestOutput
33
  from vllm.sequence import Logprob
34
  from vllm.tracing import (contains_trace_headers, extract_trace_headers,
@@ -38,159 +37,31 @@ from vllm.utils import random_uuid
38
  logger = init_logger(__name__)
39
 
40
 
41
- @final # So that it should be compatible with Dict[str, str]
42
- class ConversationMessage(TypedDict):
43
- role: str
44
- content: str
45
-
46
-
47
- @dataclass(frozen=True)
48
- class ChatMessageParseResult:
49
- messages: List[ConversationMessage]
50
- mm_futures: List[Awaitable[MultiModalDataDict]] = field(
51
- default_factory=list)
52
-
53
-
54
  class OpenAIServingChat(OpenAIServing):
55
 
56
- def __init__(self,
57
- engine: AsyncLLMEngine,
58
- model_config: ModelConfig,
59
- served_model_names: List[str],
60
- response_role: str,
61
- lora_modules: Optional[List[LoRAModulePath]] = None,
62
- chat_template: Optional[str] = None):
 
 
 
 
 
63
  super().__init__(engine=engine,
64
  model_config=model_config,
65
  served_model_names=served_model_names,
66
- lora_modules=lora_modules)
 
 
67
 
68
  self.response_role = response_role
69
- self._load_chat_template(chat_template)
70
-
71
- def _load_chat_template(self, chat_template: Optional[str]):
72
- tokenizer = self.tokenizer
73
-
74
- if chat_template is not None:
75
- try:
76
- with open(chat_template, "r") as f:
77
- tokenizer.chat_template = f.read()
78
- except OSError as e:
79
- JINJA_CHARS = "{}\n"
80
- if not any(c in chat_template for c in JINJA_CHARS):
81
- msg = (f"The supplied chat template ({chat_template}) "
82
- f"looks like a file path, but it failed to be "
83
- f"opened. Reason: {e}")
84
- raise ValueError(msg) from e
85
-
86
- # If opening a file fails, set chat template to be args to
87
- # ensure we decode so our escape are interpreted correctly
88
- tokenizer.chat_template = codecs.decode(
89
- chat_template, "unicode_escape")
90
-
91
- logger.info("Using supplied chat template:\n%s",
92
- tokenizer.chat_template)
93
- elif tokenizer.chat_template is not None:
94
- logger.info("Using default chat template:\n%s",
95
- tokenizer.chat_template)
96
- else:
97
- logger.warning(
98
- "No chat template provided. Chat API will not work.")
99
-
100
- @cached_property
101
- def image_token_str(self) -> Optional[str]:
102
- # TODO: Let user specify how to insert image tokens into prompt
103
- # (similar to chat template)
104
- model_type = self.model_config.hf_config.model_type
105
- if model_type == "phi3_v":
106
- # Workaround since this token is not defined in the tokenizer
107
- return "<|image_1|>"
108
- if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv",
109
- "paligemma"):
110
- # These models do not use image tokens in the prompt
111
- return None
112
- if model_type.startswith("llava"):
113
- return self.tokenizer.decode(
114
- self.model_config.hf_config.image_token_index)
115
-
116
- else:
117
- raise TypeError("Unknown model type: {model_type}")
118
-
119
- # TODO: Let user specify how to insert image tokens into prompt
120
- # (similar to chat template)
121
- def _get_full_image_text_prompt(self, image_token_str: str,
122
- text_prompt: str) -> str:
123
- """Combine image and text prompts for vision language model"""
124
-
125
- # NOTE: For now we assume all model architectures use the same
126
- # image + text prompt format. This may change in the future.
127
- return f"{image_token_str}\n{text_prompt}"
128
-
129
- def _parse_chat_message_content_parts(
130
- self,
131
- role: str,
132
- parts: Iterable[ChatCompletionContentPartParam],
133
- ) -> ChatMessageParseResult:
134
- texts: List[str] = []
135
- mm_futures: List[Awaitable[MultiModalDataDict]] = []
136
-
137
- for part in parts:
138
- part_type = part["type"]
139
- if part_type == "text":
140
- text = cast(ChatCompletionContentPartTextParam, part)["text"]
141
- texts.append(text)
142
- elif part_type == "image_url":
143
- if len(mm_futures) > 0:
144
- raise NotImplementedError(
145
- "Multiple 'image_url' input is currently not supported."
146
- )
147
-
148
- image_url = cast(ChatCompletionContentPartImageParam,
149
- part)["image_url"]
150
-
151
- if image_url.get("detail", "auto") != "auto":
152
- logger.warning(
153
- "'image_url.detail' is currently not supported and "
154
- "will be ignored.")
155
-
156
- image_future = async_get_and_parse_image(image_url["url"])
157
- mm_futures.append(image_future)
158
- else:
159
- raise NotImplementedError(f"Unknown part type: {part_type}")
160
-
161
- text_prompt = "\n".join(texts)
162
-
163
- if mm_futures:
164
- image_token_str = self.image_token_str
165
- if image_token_str is not None:
166
- if image_token_str in text_prompt:
167
- logger.warning(
168
- "Detected image token string in the text prompt. "
169
- "Skipping prompt formatting.")
170
- else:
171
- text_prompt = self._get_full_image_text_prompt(
172
- image_token_str=image_token_str,
173
- text_prompt=text_prompt,
174
- )
175
 
176
- messages = [ConversationMessage(role=role, content=text_prompt)]
177
-
178
- return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
179
-
180
- def _parse_chat_message_content(
181
- self,
182
- message: ChatCompletionMessageParam,
183
- ) -> ChatMessageParseResult:
184
- role = message["role"]
185
- content = message.get("content")
186
-
187
- if content is None:
188
- return ChatMessageParseResult(messages=[], mm_futures=[])
189
- if isinstance(content, str):
190
- messages = [ConversationMessage(role=role, content=content)]
191
- return ChatMessageParseResult(messages=messages, mm_futures=[])
192
-
193
- return self._parse_chat_message_content_parts(role, content)
194
 
195
  async def create_chat_completion(
196
  self,
@@ -212,11 +83,20 @@ class OpenAIServingChat(OpenAIServing):
212
  return error_check_ret
213
 
214
  try:
 
 
 
 
215
  conversation: List[ConversationMessage] = []
216
  mm_futures: List[Awaitable[MultiModalDataDict]] = []
217
 
218
  for msg in request.messages:
219
- chat_parsed_result = self._parse_chat_message_content(msg)
 
220
 
221
  conversation.extend(chat_parsed_result.messages)
222
  mm_futures.extend(chat_parsed_result.mm_futures)
@@ -225,13 +105,13 @@ class OpenAIServingChat(OpenAIServing):
225
  tool.model_dump() for tool in request.tools
226
  ]
227
 
228
- prompt = self.tokenizer.apply_chat_template(
229
  conversation=conversation,
230
  tokenize=False,
231
  add_generation_prompt=request.add_generation_prompt,
232
  tools=tool_dicts,
233
  documents=request.documents,
234
- chat_template=request.chat_template,
235
  **(request.chat_template_kwargs or {}),
236
  )
237
  except Exception as e:
@@ -250,61 +130,71 @@ class OpenAIServingChat(OpenAIServing):
250
  logger.error("Error in loading multi-modal data: %s", e)
251
  return self.create_error_response(str(e))
252
 
253
- request_id = f"cmpl-{random_uuid()}"
254
  try:
255
- # Tokenize/detokenize depending on prompt format (string/token list)
256
- prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
257
- request,
258
- prompt=prompt,
259
- add_special_tokens=request.add_special_tokens)
260
  sampling_params = request.to_sampling_params()
261
- lora_request = self._maybe_get_lora(request)
262
  decoding_config = await self.engine.get_decoding_config()
263
  guided_decoding_backend = request.guided_decoding_backend \
264
  or decoding_config.guided_decoding_backend
265
  guided_decode_logits_processor = (
266
- await get_guided_decoding_logits_processor(
267
- guided_decoding_backend, request, await
268
- self.engine.get_tokenizer()))
269
  if guided_decode_logits_processor:
270
  if sampling_params.logits_processors is None:
271
  sampling_params.logits_processors = []
272
  sampling_params.logits_processors.append(
273
  guided_decode_logits_processor)
 
 
 
 
274
  except ValueError as e:
 
275
  return self.create_error_response(str(e))
276
 
277
- inputs: PromptInputs = {
278
- "prompt": prompt_text,
279
- "prompt_token_ids": prompt_ids,
280
- }
281
- if mm_data:
282
- inputs["multi_modal_data"] = mm_data
283
-
284
- is_tracing_enabled = await self.engine.is_tracing_enabled()
285
- trace_headers = None
286
- if is_tracing_enabled and raw_request:
287
- trace_headers = extract_trace_headers(raw_request.headers)
288
- if not is_tracing_enabled and raw_request and contains_trace_headers(
289
- raw_request.headers):
290
- log_tracing_disabled_warning()
291
-
292
- result_generator = self.engine.generate(
293
- inputs,
294
- sampling_params,
295
- request_id,
296
- lora_request,
297
- trace_headers=trace_headers,
298
- )
299
  # Streaming response
300
  if request.stream:
301
  return self.chat_completion_stream_generator(
302
- request, result_generator, request_id, conversation)
303
  else:
304
  try:
305
  return await self.chat_completion_full_generator(
306
  request, raw_request, result_generator, request_id,
307
- conversation)
308
  except ValueError as e:
309
  # TODO: Use a vllm-specific Validation Error
310
  return self.create_error_response(str(e))
@@ -316,9 +206,12 @@ class OpenAIServingChat(OpenAIServing):
316
  return request.messages[-1]["role"]
317
 
318
  async def chat_completion_stream_generator(
319
- self, request: ChatCompletionRequest,
320
- result_generator: AsyncIterator[RequestOutput], request_id: str,
321
- conversation: List[ConversationMessage]
 
 
 
322
  ) -> AsyncGenerator[str, None]:
323
  model_name = self.served_model_names[0]
324
  created_time = int(time.time())
@@ -326,10 +219,11 @@ class OpenAIServingChat(OpenAIServing):
326
  first_iteration = True
327
 
328
  # Send response for each token for each request.n (index)
329
- assert request.n is not None
330
- previous_texts = [""] * request.n
331
- previous_num_tokens = [0] * request.n
332
- finish_reason_sent = [False] * request.n
 
333
  try:
334
  async for res in result_generator:
335
  # We need to do it here, because if there are exceptions in
@@ -339,7 +233,7 @@ class OpenAIServingChat(OpenAIServing):
339
  # Send first response for each request.n (index) with
340
  # the role
341
  role = self.get_chat_request_role(request)
342
- for i in range(request.n):
343
  choice_data = ChatCompletionResponseStreamChoice(
344
  index=i,
345
  delta=DeltaMessage(role=role),
@@ -367,19 +261,19 @@ class OpenAIServingChat(OpenAIServing):
367
  last_msg_content = conversation[-1]["content"]
368
 
369
  if last_msg_content:
370
- for i in range(request.n):
371
  choice_data = (
372
  ChatCompletionResponseStreamChoice(
373
  index=i,
374
  delta=DeltaMessage(
375
  content=last_msg_content),
 
376
  finish_reason=None))
377
  chunk = ChatCompletionStreamResponse(
378
  id=request_id,
379
  object=chunk_object_type,
380
  created=created_time,
381
  choices=[choice_data],
382
- logprobs=None,
383
  model=model_name)
384
  if (request.stream_options and
385
  request.stream_options.include_usage):
@@ -405,6 +299,7 @@ class OpenAIServingChat(OpenAIServing):
405
  logprobs = self._create_chat_logprobs(
406
  token_ids=delta_token_ids,
407
  top_logprobs=out_logprobs,
 
408
  num_output_top_logprobs=request.top_logprobs,
409
  )
410
  else:
@@ -493,9 +388,13 @@ class OpenAIServingChat(OpenAIServing):
493
  yield "data: [DONE]\n\n"
494
 
495
  async def chat_completion_full_generator(
496
- self, request: ChatCompletionRequest, raw_request: Optional[Request],
497
- result_generator: AsyncIterator[RequestOutput], request_id: str,
498
- conversation: List[ConversationMessage]
 
 
 
 
499
  ) -> Union[ErrorResponse, ChatCompletionResponse]:
500
 
501
  model_name = self.served_model_names[0]
@@ -523,6 +422,7 @@ class OpenAIServingChat(OpenAIServing):
523
  token_ids=token_ids,
524
  top_logprobs=out_logprobs,
525
  num_output_top_logprobs=request.top_logprobs,
 
526
  )
527
  else:
528
  logprobs = None
@@ -577,16 +477,14 @@ class OpenAIServingChat(OpenAIServing):
577
  return response
578
 
579
  def _get_top_logprobs(
580
- self, logprobs: Dict[int, Logprob],
581
- top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
582
  return [
583
  ChatCompletionLogProb(
584
- token=self._get_decoded_token(p[1], p[0]),
 
585
  logprob=max(p[1].logprob, -9999.0),
586
- bytes=list(
587
- self._get_decoded_token(p[1],
588
- p[0]).encode("utf-8",
589
- errors="replace")))
590
  for i, p in enumerate(logprobs.items())
591
  if top_logprobs and i < top_logprobs
592
  ]
@@ -595,6 +493,7 @@ class OpenAIServingChat(OpenAIServing):
595
  self,
596
  token_ids: GenericSequence[int],
597
  top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
 
598
  num_output_top_logprobs: Optional[int] = None,
599
  ) -> ChatCompletionLogProbs:
600
  """Create OpenAI-style logprobs."""
@@ -604,12 +503,11 @@ class OpenAIServingChat(OpenAIServing):
604
  for i, token_id in enumerate(token_ids):
605
  step_top_logprobs = top_logprobs[i]
606
  if step_top_logprobs is None:
 
607
  logprobs_content.append(
608
  ChatCompletionLogProbsContent(
609
- token=self.tokenizer.decode(token_id),
610
- bytes=list(
611
- self.tokenizer.decode(token_id).encode(
612
- "utf-8", errors="replace"))))
613
  else:
614
  logprobs_content.append(
615
  ChatCompletionLogProbsContent(
@@ -620,6 +518,7 @@ class OpenAIServingChat(OpenAIServing):
620
  step_top_logprobs[token_id].decoded_token.encode(
621
  "utf-8", errors="replace")),
622
  top_logprobs=self._get_top_logprobs(
623
- step_top_logprobs, num_output_top_logprobs)))
 
624
 
625
  return ChatCompletionLogProbs(content=logprobs_content)
 
 
1
  import time
2
+ from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List,
3
+ Optional)
 
 
4
  from typing import Sequence as GenericSequence
5
+ from typing import Union
6
 
7
  from fastapi import Request
8
+ from transformers import PreTrainedTokenizer
 
9
 
10
  from vllm.config import ModelConfig
11
  from vllm.engine.async_llm_engine import AsyncLLMEngine
12
+ from vllm.entrypoints.chat_utils import (ConversationMessage,
13
+ load_chat_template,
14
+ parse_chat_message_content)
15
+ from vllm.entrypoints.logger import RequestLogger
16
  from vllm.entrypoints.openai.protocol import (
17
+ ChatCompletionLogProb, ChatCompletionLogProbs,
18
+ ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
 
19
  ChatCompletionRequest, ChatCompletionResponse,
20
  ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
21
  ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
22
  FunctionCall, ToolCall, UsageInfo)
23
  from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
24
+ OpenAIServing,
25
+ PromptAdapterPath)
26
  from vllm.inputs import PromptInputs
27
  from vllm.logger import init_logger
28
  from vllm.model_executor.guided_decoding import (
29
  get_guided_decoding_logits_processor)
30
  from vllm.multimodal import MultiModalDataDict
 
31
  from vllm.outputs import RequestOutput
32
  from vllm.sequence import Logprob
33
  from vllm.tracing import (contains_trace_headers, extract_trace_headers,
 
37
  logger = init_logger(__name__)
38
 
39
 
 
 
 
40
  class OpenAIServingChat(OpenAIServing):
41
 
42
+ def __init__(
43
+ self,
44
+ engine: AsyncLLMEngine,
45
+ model_config: ModelConfig,
46
+ served_model_names: List[str],
47
+ response_role: str,
48
+ *,
49
+ lora_modules: Optional[List[LoRAModulePath]],
50
+ prompt_adapters: Optional[List[PromptAdapterPath]],
51
+ request_logger: Optional[RequestLogger],
52
+ chat_template: Optional[str],
53
+ ):
54
  super().__init__(engine=engine,
55
  model_config=model_config,
56
  served_model_names=served_model_names,
57
+ lora_modules=lora_modules,
58
+ prompt_adapters=prompt_adapters,
59
+ request_logger=request_logger)
60
 
61
  self.response_role = response_role
 
 
 
 
62
 
63
+ # If this is None we use the tokenizer's default chat template
64
+ self.chat_template = load_chat_template(chat_template)
 
 
 
 
65
 
66
  async def create_chat_completion(
67
  self,
 
83
  return error_check_ret
84
 
85
  try:
86
+ (
87
+ lora_request,
88
+ prompt_adapter_request,
89
+ ) = self._maybe_get_adapters(request)
90
+
91
+ model_config = self.model_config
92
+ tokenizer = await self.engine.get_tokenizer(lora_request)
93
+
94
  conversation: List[ConversationMessage] = []
95
  mm_futures: List[Awaitable[MultiModalDataDict]] = []
96
 
97
  for msg in request.messages:
98
+ chat_parsed_result = parse_chat_message_content(
99
+ msg, model_config, tokenizer)
100
 
101
  conversation.extend(chat_parsed_result.messages)
102
  mm_futures.extend(chat_parsed_result.mm_futures)
 
105
  tool.model_dump() for tool in request.tools
106
  ]
107
 
108
+ prompt = tokenizer.apply_chat_template(
109
  conversation=conversation,
110
  tokenize=False,
111
  add_generation_prompt=request.add_generation_prompt,
112
  tools=tool_dicts,
113
  documents=request.documents,
114
+ chat_template=request.chat_template or self.chat_template,
115
  **(request.chat_template_kwargs or {}),
116
  )
117
  except Exception as e:
 
130
  logger.error("Error in loading multi-modal data: %s", e)
131
  return self.create_error_response(str(e))
132
 
133
+ request_id = f"chat-{random_uuid()}"
134
  try:
 
 
 
 
 
135
  sampling_params = request.to_sampling_params()
 
136
  decoding_config = await self.engine.get_decoding_config()
137
  guided_decoding_backend = request.guided_decoding_backend \
138
  or decoding_config.guided_decoding_backend
139
  guided_decode_logits_processor = (
140
+ await
141
+ get_guided_decoding_logits_processor(guided_decoding_backend,
142
+ request, tokenizer))
143
  if guided_decode_logits_processor:
144
  if sampling_params.logits_processors is None:
145
  sampling_params.logits_processors = []
146
  sampling_params.logits_processors.append(
147
  guided_decode_logits_processor)
148
+
149
+ prompt_inputs = self._tokenize_prompt_input(
150
+ request,
151
+ tokenizer,
152
+ prompt,
153
+ truncate_prompt_tokens=sampling_params.truncate_prompt_tokens,
154
+ add_special_tokens=request.add_special_tokens,
155
+ )
156
+
157
+ self._log_inputs(request_id,
158
+ prompt_inputs,
159
+ params=sampling_params,
160
+ lora_request=lora_request,
161
+ prompt_adapter_request=prompt_adapter_request)
162
+
163
+ engine_inputs: PromptInputs = {
164
+ "prompt_token_ids": prompt_inputs["prompt_token_ids"],
165
+ }
166
+ if mm_data is not None:
167
+ engine_inputs["multi_modal_data"] = mm_data
168
+
169
+ is_tracing_enabled = await self.engine.is_tracing_enabled()
170
+ trace_headers = None
171
+ if is_tracing_enabled and raw_request:
172
+ trace_headers = extract_trace_headers(raw_request.headers)
173
+ if (not is_tracing_enabled and raw_request
174
+ and contains_trace_headers(raw_request.headers)):
175
+ log_tracing_disabled_warning()
176
+
177
+ result_generator = self.engine.generate(
178
+ engine_inputs,
179
+ sampling_params,
180
+ request_id,
181
+ lora_request=lora_request,
182
+ trace_headers=trace_headers,
183
+ prompt_adapter_request=prompt_adapter_request,
184
+ )
185
  except ValueError as e:
186
+ # TODO: Use a vllm-specific Validation Error
187
  return self.create_error_response(str(e))
188
 
 
 
 
189
  # Streaming response
190
  if request.stream:
191
  return self.chat_completion_stream_generator(
192
+ request, result_generator, request_id, conversation, tokenizer)
193
  else:
194
  try:
195
  return await self.chat_completion_full_generator(
196
  request, raw_request, result_generator, request_id,
197
+ conversation, tokenizer)
198
  except ValueError as e:
199
  # TODO: Use a vllm-specific Validation Error
200
  return self.create_error_response(str(e))
 
206
  return request.messages[-1]["role"]
207
 
208
  async def chat_completion_stream_generator(
209
+ self,
210
+ request: ChatCompletionRequest,
211
+ result_generator: AsyncIterator[RequestOutput],
212
+ request_id: str,
213
+ conversation: List[ConversationMessage],
214
+ tokenizer: PreTrainedTokenizer,
215
  ) -> AsyncGenerator[str, None]:
216
  model_name = self.served_model_names[0]
217
  created_time = int(time.time())
 
219
  first_iteration = True
220
 
221
  # Send response for each token for each request.n (index)
222
+ num_choices = 1 if request.n is None else request.n
223
+ previous_texts = [""] * num_choices
224
+ previous_num_tokens = [0] * num_choices
225
+ finish_reason_sent = [False] * num_choices
226
+
227
  try:
228
  async for res in result_generator:
229
  # We need to do it here, because if there are exceptions in
 
233
  # Send first response for each request.n (index) with
234
  # the role
235
  role = self.get_chat_request_role(request)
236
+ for i in range(num_choices):
237
  choice_data = ChatCompletionResponseStreamChoice(
238
  index=i,
239
  delta=DeltaMessage(role=role),
 
261
  last_msg_content = conversation[-1]["content"]
262
 
263
  if last_msg_content:
264
+ for i in range(num_choices):
265
  choice_data = (
266
  ChatCompletionResponseStreamChoice(
267
  index=i,
268
  delta=DeltaMessage(
269
  content=last_msg_content),
270
+ logprobs=None,
271
  finish_reason=None))
272
  chunk = ChatCompletionStreamResponse(
273
  id=request_id,
274
  object=chunk_object_type,
275
  created=created_time,
276
  choices=[choice_data],
 
277
  model=model_name)
278
  if (request.stream_options and
279
  request.stream_options.include_usage):
 
299
  logprobs = self._create_chat_logprobs(
300
  token_ids=delta_token_ids,
301
  top_logprobs=out_logprobs,
302
+ tokenizer=tokenizer,
303
  num_output_top_logprobs=request.top_logprobs,
304
  )
305
  else:
 
388
  yield "data: [DONE]\n\n"
389
 
390
  async def chat_completion_full_generator(
391
+ self,
392
+ request: ChatCompletionRequest,
393
+ raw_request: Optional[Request],
394
+ result_generator: AsyncIterator[RequestOutput],
395
+ request_id: str,
396
+ conversation: List[ConversationMessage],
397
+ tokenizer: PreTrainedTokenizer,
398
  ) -> Union[ErrorResponse, ChatCompletionResponse]:
399
 
400
  model_name = self.served_model_names[0]
 
422
  token_ids=token_ids,
423
  top_logprobs=out_logprobs,
424
  num_output_top_logprobs=request.top_logprobs,
425
+ tokenizer=tokenizer,
426
  )
427
  else:
428
  logprobs = None
 
477
  return response
478
 
479
  def _get_top_logprobs(
480
+ self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
481
+ tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
482
  return [
483
  ChatCompletionLogProb(
484
+ token=(token := self._get_decoded_token(p[1], p[0],
485
+ tokenizer)),
486
  logprob=max(p[1].logprob, -9999.0),
487
+ bytes=list(token.encode("utf-8", errors="replace")))
 
 
 
488
  for i, p in enumerate(logprobs.items())
489
  if top_logprobs and i < top_logprobs
490
  ]
 
493
  self,
494
  token_ids: GenericSequence[int],
495
  top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
496
+ tokenizer: PreTrainedTokenizer,
497
  num_output_top_logprobs: Optional[int] = None,
498
  ) -> ChatCompletionLogProbs:
499
  """Create OpenAI-style logprobs."""
 
503
  for i, token_id in enumerate(token_ids):
504
  step_top_logprobs = top_logprobs[i]
505
  if step_top_logprobs is None:
506
+ token = tokenizer.decode(token_id)
507
  logprobs_content.append(
508
  ChatCompletionLogProbsContent(
509
+ token=token,
510
+ bytes=list(token.encode("utf-8", errors="replace"))))
 
 
511
  else:
512
  logprobs_content.append(
513
  ChatCompletionLogProbsContent(
 
518
  step_top_logprobs[token_id].decoded_token.encode(
519
  "utf-8", errors="replace")),
520
  top_logprobs=self._get_top_logprobs(
521
+ step_top_logprobs, num_output_top_logprobs,
522
+ tokenizer)))
523
 
524
  return ChatCompletionLogProbs(content=logprobs_content)
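A minimal sketch, not from this commit, of the OpenAI-style `bytes` field that the chat logprobs above now populate; `token_to_bytes` is a hypothetical helper, but the encoding mirrors the `token.encode("utf-8", errors="replace")` calls in the diff.

    def token_to_bytes(token: str) -> list:
        # errors="replace" turns undecodable characters into U+FFFD, as in the diff above
        return list(token.encode("utf-8", errors="replace"))

    print(token_to_bytes("héllo"))  # [104, 195, 169, 108, 108, 111]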
serving_completion.py CHANGED
@@ -2,12 +2,14 @@ import time
2
  from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
3
  Optional)
4
  from typing import Sequence as GenericSequence
5
- from typing import Tuple
6
 
7
  from fastapi import Request
 
8
 
9
  from vllm.config import ModelConfig
10
  from vllm.engine.async_llm_engine import AsyncLLMEngine
 
11
  # yapf conflicts with isort for this block
12
  # yapf: disable
13
  from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
@@ -16,13 +18,11 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
16
  CompletionResponseChoice,
17
  CompletionResponseStreamChoice,
18
  CompletionStreamResponse,
19
- DetokenizeRequest,
20
- DetokenizeResponse,
21
- TokenizeRequest,
22
- TokenizeResponse, UsageInfo)
23
  # yapf: enable
24
  from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
25
- OpenAIServing)
 
26
  from vllm.logger import init_logger
27
  from vllm.model_executor.guided_decoding import (
28
  get_guided_decoding_logits_processor)
@@ -40,38 +40,24 @@ TypeCreateLogProbsFn = Callable[
40
  [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
41
 
42
 
43
- def parse_prompt_format(prompt) -> Tuple[bool, list]:
44
- # get the prompt, openai supports the following
45
- # "a string, array of strings, array of tokens, or array of token arrays."
46
- prompt_is_tokens = False
47
- prompts = [prompt] # case 1: a string
48
- if isinstance(prompt, list):
49
- if len(prompt) == 0:
50
- raise ValueError("please provide at least one prompt")
51
- elif isinstance(prompt[0], str):
52
- prompt_is_tokens = False
53
- prompts = prompt # case 2: array of strings
54
- elif isinstance(prompt[0], int):
55
- prompt_is_tokens = True
56
- prompts = [prompt] # case 3: array of tokens
57
- elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
58
- prompt_is_tokens = True
59
- prompts = prompt # case 4: array of token arrays
60
- else:
61
- raise ValueError("prompt must be a string, array of strings, "
62
- "array of tokens, or array of token arrays")
63
- return prompt_is_tokens, prompts
64
-
65
-
66
  class OpenAIServingCompletion(OpenAIServing):
67
 
68
- def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
69
- served_model_names: List[str],
70
- lora_modules: Optional[List[LoRAModulePath]]):
71
  super().__init__(engine=engine,
72
  model_config=model_config,
73
  served_model_names=served_model_names,
74
- lora_modules=lora_modules)
 
 
75
 
76
  async def create_completion(self, request: CompletionRequest,
77
  raw_request: Request):
@@ -100,36 +86,45 @@ class OpenAIServingCompletion(OpenAIServing):
100
  # Schedule the request and get the result generator.
101
  generators: List[AsyncIterator[RequestOutput]] = []
102
  try:
103
  sampling_params = request.to_sampling_params()
104
- lora_request = self._maybe_get_lora(request)
105
  decoding_config = await self.engine.get_decoding_config()
106
  guided_decoding_backend = request.guided_decoding_backend \
107
  or decoding_config.guided_decoding_backend
108
  guided_decode_logit_processor = (
109
- await get_guided_decoding_logits_processor(
110
- guided_decoding_backend, request, await
111
- self.engine.get_tokenizer()))
112
  if guided_decode_logit_processor is not None:
113
  if sampling_params.logits_processors is None:
114
  sampling_params.logits_processors = []
115
  sampling_params.logits_processors.append(
116
  guided_decode_logit_processor)
117
- prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
118
-
119
- for i, prompt in enumerate(prompts):
120
- if prompt_is_tokens:
121
- prompt_formats = self._validate_prompt_and_tokenize(
122
- request,
123
- prompt_ids=prompt,
124
- truncate_prompt_tokens=sampling_params.
125
- truncate_prompt_tokens)
126
- else:
127
- prompt_formats = self._validate_prompt_and_tokenize(
128
- request,
129
- prompt=prompt,
130
- truncate_prompt_tokens=sampling_params.
131
- truncate_prompt_tokens)
132
- prompt_ids, prompt_text = prompt_formats
 
 
 
133
 
134
  is_tracing_enabled = await self.engine.is_tracing_enabled()
135
  trace_headers = None
@@ -140,13 +135,11 @@ class OpenAIServingCompletion(OpenAIServing):
140
  log_tracing_disabled_warning()
141
 
142
  generator = self.engine.generate(
143
- {
144
- "prompt": prompt_text,
145
- "prompt_token_ids": prompt_ids
146
- },
147
  sampling_params,
148
- f"{request_id}-{i}",
149
  lora_request=lora_request,
 
150
  trace_headers=trace_headers,
151
  )
152
 
@@ -173,7 +166,8 @@ class OpenAIServingCompletion(OpenAIServing):
173
  request_id,
174
  created_time,
175
  model_name,
176
- num_prompts=len(prompts))
 
177
 
178
  # Non-streaming response
179
  final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
@@ -184,8 +178,27 @@ class OpenAIServingCompletion(OpenAIServing):
184
  await self.engine.abort(f"{request_id}-{i}")
185
  return self.create_error_response("Client disconnected")
186
  final_res_batch[i] = res
187
  response = self.request_output_to_completion_response(
188
- final_res_batch, request, request_id, created_time, model_name)
189
  except ValueError as e:
190
  # TODO: Use a vllm-specific Validation Error
191
  return self.create_error_response(str(e))
@@ -212,11 +225,12 @@ class OpenAIServingCompletion(OpenAIServing):
212
  created_time: int,
213
  model_name: str,
214
  num_prompts: int,
 
215
  ) -> AsyncGenerator[str, None]:
216
- assert request.n is not None
217
- previous_texts = [""] * request.n * num_prompts
218
- previous_num_tokens = [0] * request.n * num_prompts
219
- has_echoed = [False] * request.n * num_prompts
220
 
221
  try:
222
  async for prompt_idx, res in result_generator:
@@ -227,7 +241,7 @@ class OpenAIServingCompletion(OpenAIServing):
227
  raise StopAsyncIteration()
228
 
229
  for output in res.outputs:
230
- i = output.index + prompt_idx * request.n
231
  # TODO(simon): optimize the performance by avoiding full
232
  # text O(n^2) sending.
233
 
@@ -262,6 +276,7 @@ class OpenAIServingCompletion(OpenAIServing):
262
  token_ids=delta_token_ids,
263
  top_logprobs=out_logprobs,
264
  num_output_top_logprobs=request.logprobs,
 
265
  initial_text_offset=len(previous_texts[i]),
266
  )
267
  else:
@@ -301,7 +316,7 @@ class OpenAIServingCompletion(OpenAIServing):
301
  else:
302
  chunk.usage = None
303
 
304
- response_json = chunk.model_dump_json(exclude_unset=True)
305
  yield f"data: {response_json}\n\n"
306
 
307
  if (request.stream_options
@@ -314,7 +329,7 @@ class OpenAIServingCompletion(OpenAIServing):
314
  usage=usage,
315
  )
316
  final_usage_data = (final_usage_chunk.model_dump_json(
317
- exclude_unset=True, exclude_none=True))
318
  yield f"data: {final_usage_data}\n\n"
319
 
320
  except ValueError as e:
@@ -330,12 +345,13 @@ class OpenAIServingCompletion(OpenAIServing):
330
  request_id: str,
331
  created_time: int,
332
  model_name: str,
 
333
  ) -> CompletionResponse:
334
  choices: List[CompletionResponseChoice] = []
335
  num_prompt_tokens = 0
336
  num_generated_tokens = 0
 
337
  for final_res in final_res_batch:
338
- assert final_res is not None
339
  prompt_token_ids = final_res.prompt_token_ids
340
  prompt_logprobs = final_res.prompt_logprobs
341
  prompt_text = final_res.prompt
@@ -361,6 +377,7 @@ class OpenAIServingCompletion(OpenAIServing):
361
  logprobs = self._create_completion_logprobs(
362
  token_ids=token_ids,
363
  top_logprobs=out_logprobs,
 
364
  num_output_top_logprobs=request.logprobs,
365
  )
366
  else:
@@ -398,6 +415,7 @@ class OpenAIServingCompletion(OpenAIServing):
398
  token_ids: GenericSequence[int],
399
  top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
400
  num_output_top_logprobs: int,
 
401
  initial_text_offset: int = 0,
402
  ) -> CompletionLogProbs:
403
  """Create logprobs for OpenAI Completion API."""
@@ -411,13 +429,13 @@ class OpenAIServingCompletion(OpenAIServing):
411
  for i, token_id in enumerate(token_ids):
412
  step_top_logprobs = top_logprobs[i]
413
  if step_top_logprobs is None:
414
- token = self.tokenizer.decode(token_id)
415
  out_tokens.append(token)
416
  out_token_logprobs.append(None)
417
  out_top_logprobs.append(None)
418
  else:
419
  token = self._get_decoded_token(step_top_logprobs[token_id],
420
- token_id)
421
  token_logprob = max(step_top_logprobs[token_id].logprob,
422
  -9999.0)
423
  out_tokens.append(token)
@@ -430,7 +448,7 @@ class OpenAIServingCompletion(OpenAIServing):
430
  out_top_logprobs.append({
431
  # Convert float("-inf") to the
432
  # JSON-serializable float that OpenAI uses
433
- self._get_decoded_token(top_lp[1], top_lp[0]):
434
  max(top_lp[1].logprob, -9999.0)
435
  for i, top_lp in enumerate(step_top_logprobs.items())
436
  if num_output_top_logprobs >= i
@@ -447,30 +465,4 @@ class OpenAIServingCompletion(OpenAIServing):
447
  token_logprobs=out_token_logprobs,
448
  tokens=out_tokens,
449
  top_logprobs=out_top_logprobs,
450
- )
451
-
452
- async def create_tokenize(self,
453
- request: TokenizeRequest) -> TokenizeResponse:
454
- error_check_ret = await self._check_model(request)
455
- if error_check_ret is not None:
456
- return error_check_ret
457
-
458
- (input_ids, input_text) = self._validate_prompt_and_tokenize(
459
- request,
460
- prompt=request.prompt,
461
- add_special_tokens=request.add_special_tokens)
462
-
463
- return TokenizeResponse(tokens=input_ids,
464
- count=len(input_ids),
465
- max_model_len=self.max_model_len)
466
-
467
- async def create_detokenize(
468
- self, request: DetokenizeRequest) -> DetokenizeResponse:
469
- error_check_ret = await self._check_model(request)
470
- if error_check_ret is not None:
471
- return error_check_ret
472
-
473
- (input_ids, input_text) = self._validate_prompt_and_tokenize(
474
- request, prompt_ids=request.tokens)
475
-
476
- return DetokenizeResponse(prompt=input_text)
 
2
  from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
3
  Optional)
4
  from typing import Sequence as GenericSequence
5
+ from typing import Tuple, cast
6
 
7
  from fastapi import Request
8
+ from transformers import PreTrainedTokenizer
9
 
10
  from vllm.config import ModelConfig
11
  from vllm.engine.async_llm_engine import AsyncLLMEngine
12
+ from vllm.entrypoints.logger import RequestLogger
13
  # yapf conflicts with isort for this block
14
  # yapf: disable
15
  from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
 
18
  CompletionResponseChoice,
19
  CompletionResponseStreamChoice,
20
  CompletionStreamResponse,
21
+ UsageInfo)
 
 
 
22
  # yapf: enable
23
  from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
24
+ OpenAIServing,
25
+ PromptAdapterPath)
26
  from vllm.logger import init_logger
27
  from vllm.model_executor.guided_decoding import (
28
  get_guided_decoding_logits_processor)
 
40
  [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
41
 
42
 
 
 
43
  class OpenAIServingCompletion(OpenAIServing):
44
 
45
+ def __init__(
46
+ self,
47
+ engine: AsyncLLMEngine,
48
+ model_config: ModelConfig,
49
+ served_model_names: List[str],
50
+ *,
51
+ lora_modules: Optional[List[LoRAModulePath]],
52
+ prompt_adapters: Optional[List[PromptAdapterPath]],
53
+ request_logger: Optional[RequestLogger],
54
+ ):
55
  super().__init__(engine=engine,
56
  model_config=model_config,
57
  served_model_names=served_model_names,
58
+ lora_modules=lora_modules,
59
+ prompt_adapters=prompt_adapters,
60
+ request_logger=request_logger)
61
 
62
  async def create_completion(self, request: CompletionRequest,
63
  raw_request: Request):
 
86
  # Schedule the request and get the result generator.
87
  generators: List[AsyncIterator[RequestOutput]] = []
88
  try:
89
+ (
90
+ lora_request,
91
+ prompt_adapter_request,
92
+ ) = self._maybe_get_adapters(request)
93
+
94
+ tokenizer = await self.engine.get_tokenizer(lora_request)
95
+
96
  sampling_params = request.to_sampling_params()
 
97
  decoding_config = await self.engine.get_decoding_config()
98
  guided_decoding_backend = request.guided_decoding_backend \
99
  or decoding_config.guided_decoding_backend
100
  guided_decode_logit_processor = (
101
+ await
102
+ get_guided_decoding_logits_processor(guided_decoding_backend,
103
+ request, tokenizer))
104
  if guided_decode_logit_processor is not None:
105
  if sampling_params.logits_processors is None:
106
  sampling_params.logits_processors = []
107
  sampling_params.logits_processors.append(
108
  guided_decode_logit_processor)
109
+
110
+ prompts = list(
111
+ self._tokenize_prompt_input_or_inputs(
112
+ request,
113
+ tokenizer,
114
+ request.prompt,
115
+ truncate_prompt_tokens=sampling_params.
116
+ truncate_prompt_tokens,
117
+ add_special_tokens=request.add_special_tokens,
118
+ ))
119
+
120
+ for i, prompt_inputs in enumerate(prompts):
121
+ request_id_item = f"{request_id}-{i}"
122
+
123
+ self._log_inputs(request_id_item,
124
+ prompt_inputs,
125
+ params=sampling_params,
126
+ lora_request=lora_request,
127
+ prompt_adapter_request=prompt_adapter_request)
128
 
129
  is_tracing_enabled = await self.engine.is_tracing_enabled()
130
  trace_headers = None
 
135
  log_tracing_disabled_warning()
136
 
137
  generator = self.engine.generate(
138
+ {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
 
 
 
139
  sampling_params,
140
+ request_id_item,
141
  lora_request=lora_request,
142
+ prompt_adapter_request=prompt_adapter_request,
143
  trace_headers=trace_headers,
144
  )
145
 
 
166
  request_id,
167
  created_time,
168
  model_name,
169
+ num_prompts=len(prompts),
170
+ tokenizer=tokenizer)
171
 
172
  # Non-streaming response
173
  final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
 
178
  await self.engine.abort(f"{request_id}-{i}")
179
  return self.create_error_response("Client disconnected")
180
  final_res_batch[i] = res
181
+
182
+ for i, final_res in enumerate(final_res_batch):
183
+ assert final_res is not None
184
+
185
+ # The output should contain the input text
186
+ # We did not pass it into vLLM engine to avoid being redundant
187
+ # with the inputs token IDs
188
+ if final_res.prompt is None:
189
+ final_res.prompt = prompts[i]["prompt"]
190
+
191
+ final_res_batch_checked = cast(List[RequestOutput],
192
+ final_res_batch)
193
+
194
  response = self.request_output_to_completion_response(
195
+ final_res_batch_checked,
196
+ request,
197
+ request_id,
198
+ created_time,
199
+ model_name,
200
+ tokenizer,
201
+ )
202
  except ValueError as e:
203
  # TODO: Use a vllm-specific Validation Error
204
  return self.create_error_response(str(e))
 
225
  created_time: int,
226
  model_name: str,
227
  num_prompts: int,
228
+ tokenizer: PreTrainedTokenizer,
229
  ) -> AsyncGenerator[str, None]:
230
+ num_choices = 1 if request.n is None else request.n
231
+ previous_texts = [""] * num_choices * num_prompts
232
+ previous_num_tokens = [0] * num_choices * num_prompts
233
+ has_echoed = [False] * num_choices * num_prompts
234
 
235
  try:
236
  async for prompt_idx, res in result_generator:
 
241
  raise StopAsyncIteration()
242
 
243
  for output in res.outputs:
244
+ i = output.index + prompt_idx * num_choices
245
  # TODO(simon): optimize the performance by avoiding full
246
  # text O(n^2) sending.
247
 
 
276
  token_ids=delta_token_ids,
277
  top_logprobs=out_logprobs,
278
  num_output_top_logprobs=request.logprobs,
279
+ tokenizer=tokenizer,
280
  initial_text_offset=len(previous_texts[i]),
281
  )
282
  else:
 
316
  else:
317
  chunk.usage = None
318
 
319
+ response_json = chunk.model_dump_json(exclude_unset=False)
320
  yield f"data: {response_json}\n\n"
321
 
322
  if (request.stream_options
 
329
  usage=usage,
330
  )
331
  final_usage_data = (final_usage_chunk.model_dump_json(
332
+ exclude_unset=False, exclude_none=True))
333
  yield f"data: {final_usage_data}\n\n"
334
 
335
  except ValueError as e:
 
345
  request_id: str,
346
  created_time: int,
347
  model_name: str,
348
+ tokenizer: PreTrainedTokenizer,
349
  ) -> CompletionResponse:
350
  choices: List[CompletionResponseChoice] = []
351
  num_prompt_tokens = 0
352
  num_generated_tokens = 0
353
+
354
  for final_res in final_res_batch:
 
355
  prompt_token_ids = final_res.prompt_token_ids
356
  prompt_logprobs = final_res.prompt_logprobs
357
  prompt_text = final_res.prompt
 
377
  logprobs = self._create_completion_logprobs(
378
  token_ids=token_ids,
379
  top_logprobs=out_logprobs,
380
+ tokenizer=tokenizer,
381
  num_output_top_logprobs=request.logprobs,
382
  )
383
  else:
 
415
  token_ids: GenericSequence[int],
416
  top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
417
  num_output_top_logprobs: int,
418
+ tokenizer: PreTrainedTokenizer,
419
  initial_text_offset: int = 0,
420
  ) -> CompletionLogProbs:
421
  """Create logprobs for OpenAI Completion API."""
 
429
  for i, token_id in enumerate(token_ids):
430
  step_top_logprobs = top_logprobs[i]
431
  if step_top_logprobs is None:
432
+ token = tokenizer.decode(token_id)
433
  out_tokens.append(token)
434
  out_token_logprobs.append(None)
435
  out_top_logprobs.append(None)
436
  else:
437
  token = self._get_decoded_token(step_top_logprobs[token_id],
438
+ token_id, tokenizer)
439
  token_logprob = max(step_top_logprobs[token_id].logprob,
440
  -9999.0)
441
  out_tokens.append(token)
 
448
  out_top_logprobs.append({
449
  # Convert float("-inf") to the
450
  # JSON-serializable float that OpenAI uses
451
+ self._get_decoded_token(top_lp[1], top_lp[0], tokenizer):
452
  max(top_lp[1].logprob, -9999.0)
453
  for i, top_lp in enumerate(step_top_logprobs.items())
454
  if num_output_top_logprobs >= i
 
465
  token_logprobs=out_token_logprobs,
466
  tokens=out_tokens,
467
  top_logprobs=out_top_logprobs,
468
+ )
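A small illustration, not from this commit, of the flat choice index used by the completion streaming code above (i = output.index + prompt_idx * num_choices); the numbers are made up:

    num_prompts, num_choices = 2, 3            # e.g. two prompts with n=3
    previous_texts = [""] * num_choices * num_prompts
    for prompt_idx in range(num_prompts):
        for output_index in range(num_choices):
            i = output_index + prompt_idx * num_choices
            # each (prompt, choice) pair owns exactly one accumulator slot
            assert 0 <= i < len(previous_texts)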
 
 
serving_embedding.py CHANGED
@@ -1,16 +1,16 @@
1
  import base64
2
  import time
3
- from typing import AsyncIterator, List, Optional, Tuple
4
 
5
  import numpy as np
6
  from fastapi import Request
7
 
8
  from vllm.config import ModelConfig
9
  from vllm.engine.async_llm_engine import AsyncLLMEngine
 
10
  from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
11
  EmbeddingResponse,
12
  EmbeddingResponseData, UsageInfo)
13
- from vllm.entrypoints.openai.serving_completion import parse_prompt_format
14
  from vllm.entrypoints.openai.serving_engine import OpenAIServing
15
  from vllm.logger import init_logger
16
  from vllm.outputs import EmbeddingRequestOutput
@@ -28,11 +28,11 @@ def request_output_to_embedding_response(
28
  data: List[EmbeddingResponseData] = []
29
  num_prompt_tokens = 0
30
  for idx, final_res in enumerate(final_res_batch):
31
- assert final_res is not None
32
  prompt_token_ids = final_res.prompt_token_ids
33
  embedding = final_res.outputs.embedding
34
  if encoding_format == "base64":
35
- embedding = base64.b64encode(np.array(embedding))
 
36
  embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
37
  data.append(embedding_data)
38
 
@@ -54,12 +54,20 @@ def request_output_to_embedding_response(
54
 
55
  class OpenAIServingEmbedding(OpenAIServing):
56
 
57
- def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
58
- served_model_names: List[str]):
 
59
  super().__init__(engine=engine,
60
  model_config=model_config,
61
  served_model_names=served_model_names,
62
- lora_modules=None)
 
 
63
  self._check_embedding_mode(model_config.embedding_mode)
64
 
65
  async def create_embedding(self, request: EmbeddingRequest,
@@ -80,32 +88,47 @@ class OpenAIServingEmbedding(OpenAIServing):
80
  "dimensions is currently not supported")
81
 
82
  model_name = request.model
83
- request_id = f"cmpl-{random_uuid()}"
84
  created_time = int(time.monotonic())
85
 
86
  # Schedule the request and get the result generator.
87
- generators = []
88
  try:
89
- prompt_is_tokens, prompts = parse_prompt_format(request.input)
 
90
  pooling_params = request.to_pooling_params()
91
 
92
- for i, prompt in enumerate(prompts):
93
- if prompt_is_tokens:
94
- prompt_formats = self._validate_prompt_and_tokenize(
95
- request, prompt_ids=prompt)
96
- else:
97
- prompt_formats = self._validate_prompt_and_tokenize(
98
- request, prompt=prompt)
 
99
 
100
- prompt_ids, prompt_text = prompt_formats
 
 
 
101
 
102
  generator = self.engine.encode(
103
- {
104
- "prompt": prompt_text,
105
- "prompt_token_ids": prompt_ids
106
- },
107
  pooling_params,
108
- f"{request_id}-{i}",
 
109
  )
110
 
111
  generators.append(generator)
@@ -124,11 +147,17 @@ class OpenAIServingEmbedding(OpenAIServing):
124
  if await raw_request.is_disconnected():
125
  # Abort the request if the client disconnects.
126
  await self.engine.abort(f"{request_id}-{i}")
127
- # TODO: Use a vllm-specific Validation Error
128
  return self.create_error_response("Client disconnected")
129
  final_res_batch[i] = res
 
130
  response = request_output_to_embedding_response(
131
- final_res_batch, request_id, created_time, model_name,
132
  encoding_format)
133
  except ValueError as e:
134
  # TODO: Use a vllm-specific Validation Error
 
1
  import base64
2
  import time
3
+ from typing import AsyncIterator, List, Optional, Tuple, cast
4
 
5
  import numpy as np
6
  from fastapi import Request
7
 
8
  from vllm.config import ModelConfig
9
  from vllm.engine.async_llm_engine import AsyncLLMEngine
10
+ from vllm.entrypoints.logger import RequestLogger
11
  from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
12
  EmbeddingResponse,
13
  EmbeddingResponseData, UsageInfo)
 
14
  from vllm.entrypoints.openai.serving_engine import OpenAIServing
15
  from vllm.logger import init_logger
16
  from vllm.outputs import EmbeddingRequestOutput
 
28
  data: List[EmbeddingResponseData] = []
29
  num_prompt_tokens = 0
30
  for idx, final_res in enumerate(final_res_batch):
 
31
  prompt_token_ids = final_res.prompt_token_ids
32
  embedding = final_res.outputs.embedding
33
  if encoding_format == "base64":
34
+ embedding_bytes = np.array(embedding).tobytes()
35
+ embedding = base64.b64encode(embedding_bytes).decode("utf-8")
36
  embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
37
  data.append(embedding_data)
38
 
 
54
 
55
  class OpenAIServingEmbedding(OpenAIServing):
56
 
57
+ def __init__(
58
+ self,
59
+ engine: AsyncLLMEngine,
60
+ model_config: ModelConfig,
61
+ served_model_names: List[str],
62
+ *,
63
+ request_logger: Optional[RequestLogger],
64
+ ):
65
  super().__init__(engine=engine,
66
  model_config=model_config,
67
  served_model_names=served_model_names,
68
+ lora_modules=None,
69
+ prompt_adapters=None,
70
+ request_logger=request_logger)
71
  self._check_embedding_mode(model_config.embedding_mode)
72
 
73
  async def create_embedding(self, request: EmbeddingRequest,
 
88
  "dimensions is currently not supported")
89
 
90
  model_name = request.model
91
+ request_id = f"embd-{random_uuid()}"
92
  created_time = int(time.monotonic())
93
 
94
  # Schedule the request and get the result generator.
95
+ generators: List[AsyncIterator[EmbeddingRequestOutput]] = []
96
  try:
97
+ (
98
+ lora_request,
99
+ prompt_adapter_request,
100
+ ) = self._maybe_get_adapters(request)
101
+
102
+ tokenizer = await self.engine.get_tokenizer(lora_request)
103
+
104
  pooling_params = request.to_pooling_params()
105
 
106
+ prompts = list(
107
+ self._tokenize_prompt_input_or_inputs(
108
+ request,
109
+ tokenizer,
110
+ request.input,
111
+ ))
112
+
113
+ for i, prompt_inputs in enumerate(prompts):
114
+ request_id_item = f"{request_id}-{i}"
115
+
116
+ self._log_inputs(request_id_item,
117
+ prompt_inputs,
118
+ params=pooling_params,
119
+ lora_request=lora_request,
120
+ prompt_adapter_request=prompt_adapter_request)
121
 
122
+ if prompt_adapter_request is not None:
123
+ raise NotImplementedError(
124
+ "Prompt adapter is not supported "
125
+ "for embedding models")
126
 
127
  generator = self.engine.encode(
128
+ {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
 
 
 
129
  pooling_params,
130
+ request_id_item,
131
+ lora_request=lora_request,
132
  )
133
 
134
  generators.append(generator)
 
147
  if await raw_request.is_disconnected():
148
  # Abort the request if the client disconnects.
149
  await self.engine.abort(f"{request_id}-{i}")
 
150
  return self.create_error_response("Client disconnected")
151
  final_res_batch[i] = res
152
+
153
+ for final_res in final_res_batch:
154
+ assert final_res is not None
155
+
156
+ final_res_batch_checked = cast(List[EmbeddingRequestOutput],
157
+ final_res_batch)
158
+
159
  response = request_output_to_embedding_response(
160
+ final_res_batch_checked, request_id, created_time, model_name,
161
  encoding_format)
162
  except ValueError as e:
163
  # TODO: Use a vllm-specific Validation Error
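A rough client-side counterpart, not from this commit, to the base64 fix above, which now encodes the raw float buffer rather than a stringified NumPy array; the float64 dtype below holds for a plain Python list and is an assumption about what the server actually sends:

    import base64
    import numpy as np

    embedding = [0.1, 0.2, 0.3]
    encoded = base64.b64encode(np.array(embedding).tobytes()).decode("utf-8")
    decoded = np.frombuffer(base64.b64decode(encoded), dtype=np.float64)
    assert np.allclose(decoded, embedding)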
serving_engine.py CHANGED
@@ -1,65 +1,108 @@
1
  import json
 
2
  from dataclasses import dataclass
3
  from http import HTTPStatus
4
- from typing import Any, Dict, List, Optional, Tuple, Union
5
 
6
  from pydantic import Field
 
7
  from typing_extensions import Annotated
8
 
9
  from vllm.config import ModelConfig
10
  from vllm.engine.async_llm_engine import AsyncLLMEngine
 
 
 
11
  from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
12
  CompletionRequest,
13
  DetokenizeRequest,
14
  EmbeddingRequest, ErrorResponse,
15
  ModelCard, ModelList,
16
- ModelPermission, TokenizeRequest)
 
17
  from vllm.logger import init_logger
18
  from vllm.lora.request import LoRARequest
 
 
 
19
  from vllm.sequence import Logprob
20
- from vllm.transformers_utils.tokenizer import get_tokenizer
21
 
22
  logger = init_logger(__name__)
23
 
24
 
25
  @dataclass
26
- class LoRAModulePath:
27
  name: str
28
  local_path: str
29
 
30
 
 
 
31
  class OpenAIServing:
32
 
33
- def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
34
- served_model_names: List[str],
35
- lora_modules: Optional[List[LoRAModulePath]]):
 
36
  super().__init__()
37
 
38
  self.engine = engine
39
  self.model_config = model_config
40
  self.max_model_len = model_config.max_model_len
41
 
42
- # A separate tokenizer to map token IDs to strings.
43
- self.tokenizer = get_tokenizer(
44
- model_config.tokenizer,
45
- tokenizer_mode=model_config.tokenizer_mode,
46
- tokenizer_revision=model_config.tokenizer_revision,
47
- trust_remote_code=model_config.trust_remote_code,
48
- truncation_side="left")
49
-
50
  self.served_model_names = served_model_names
51
 
52
- if lora_modules is None:
53
- self.lora_requests = []
54
- else:
55
  self.lora_requests = [
56
  LoRARequest(
57
  lora_name=lora.name,
58
  lora_int_id=i,
59
- lora_local_path=lora.local_path,
60
  ) for i, lora in enumerate(lora_modules, start=1)
61
  ]
62
 
 
 
63
  async def show_available_models(self) -> ModelList:
64
  """Show available models. Right now we only have one model."""
65
  model_cards = [
@@ -75,7 +118,14 @@ class OpenAIServing:
75
  permission=[ModelPermission()])
76
  for lora in self.lora_requests
77
  ]
 
 
 
 
 
 
78
  model_cards.extend(lora_cards)
 
79
  return ModelList(data=model_cards)
80
 
81
  def create_error_response(
@@ -101,71 +151,82 @@ class OpenAIServing:
101
  return json_str
102
 
103
  async def _check_model(
104
- self, request: Union[ChatCompletionRequest, CompletionRequest,
105
- DetokenizeRequest, EmbeddingRequest,
106
- TokenizeRequest]
107
  ) -> Optional[ErrorResponse]:
108
  if request.model in self.served_model_names:
109
  return None
110
  if request.model in [lora.lora_name for lora in self.lora_requests]:
111
  return None
 
 
 
 
 
112
  return self.create_error_response(
113
  message=f"The model `{request.model}` does not exist.",
114
  err_type="NotFoundError",
115
  status_code=HTTPStatus.NOT_FOUND)
116
 
117
- def _maybe_get_lora(
118
- self, request: Union[CompletionRequest, ChatCompletionRequest,
119
- EmbeddingRequest]
120
- ) -> Optional[LoRARequest]:
121
  if request.model in self.served_model_names:
122
- return None
123
  for lora in self.lora_requests:
124
  if request.model == lora.lora_name:
125
- return lora
 
 
 
126
  # if _check_model has been called earlier, this will be unreachable
127
  raise ValueError(f"The model `{request.model}` does not exist.")
128
 
129
- def _validate_prompt_and_tokenize(
130
- self,
131
- request: Union[ChatCompletionRequest, CompletionRequest,
132
- DetokenizeRequest, EmbeddingRequest,
133
- TokenizeRequest],
134
- prompt: Optional[str] = None,
135
- prompt_ids: Optional[List[int]] = None,
136
- truncate_prompt_tokens: Optional[Annotated[int,
137
- Field(ge=1)]] = None,
138
- add_special_tokens: Optional[bool] = True
139
- ) -> Tuple[List[int], str]:
140
- if not (prompt or prompt_ids):
141
- raise ValueError("Either prompt or prompt_ids should be provided.")
142
- if (prompt and prompt_ids):
143
- raise ValueError(
144
- "Only one of prompt or prompt_ids should be provided.")
145
-
146
- if prompt_ids is None:
147
- # When using OpenAIServingChat for chat completions, for
148
- # most models the special tokens (e.g., BOS) have already
149
- # been added by the chat template. Therefore, we do not
150
- # need to add them again.
151
- # Set add_special_tokens to False (by default) to avoid
152
- # adding the BOS tokens again.
153
- tokenizer_kwargs: Dict[str, Any] = {
154
- "add_special_tokens": add_special_tokens
155
- }
156
- if truncate_prompt_tokens is not None:
157
- tokenizer_kwargs.update({
158
- "truncation": True,
159
- "max_length": truncate_prompt_tokens,
160
- })
161
- input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
162
- elif truncate_prompt_tokens is not None:
163
- input_ids = prompt_ids[-truncate_prompt_tokens:]
164
  else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  input_ids = prompt_ids
 
 
 
 
166
 
167
- input_text = prompt if prompt is not None else self.tokenizer.decode(
168
- prompt_ids)
 
 
 
 
 
 
169
  token_num = len(input_ids)
170
 
171
  # Note: EmbeddingRequest doesn't have max_tokens
@@ -175,13 +236,16 @@ class OpenAIServing:
175
  f"This model's maximum context length is "
176
  f"{self.max_model_len} tokens. However, you requested "
177
  f"{token_num} tokens in the input for embedding "
178
- f"generation. Please reduce the length of the input.", )
179
- return input_ids, input_text
 
180
 
181
  # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
182
  # and does not require model context length validation
183
- if isinstance(request, (TokenizeRequest, DetokenizeRequest)):
184
- return input_ids, input_text
 
 
185
 
186
  if request.max_tokens is None:
187
  if token_num >= self.max_model_len:
@@ -189,7 +253,7 @@ class OpenAIServing:
189
  f"This model's maximum context length is "
190
  f"{self.max_model_len} tokens. However, you requested "
191
  f"{token_num} tokens in the messages, "
192
- f"Please reduce the length of the messages.", )
193
  request.max_tokens = self.max_model_len - token_num
194
 
195
  if token_num + request.max_tokens > self.max_model_len:
@@ -199,11 +263,132 @@ class OpenAIServing:
199
  f"{request.max_tokens + token_num} tokens "
200
  f"({token_num} in the messages, "
201
  f"{request.max_tokens} in the completion). "
202
- f"Please reduce the length of the messages or completion.", )
 
 
203
  else:
204
- return input_ids, input_text
 
205
 
206
- def _get_decoded_token(self, logprob: Logprob, token_id: int) -> str:
 
 
 
 
 
207
  if logprob.decoded_token is not None:
208
  return logprob.decoded_token
209
- return self.tokenizer.decode(token_id)
 
1
  import json
2
+ import pathlib
3
  from dataclasses import dataclass
4
  from http import HTTPStatus
5
+ from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
6
 
7
  from pydantic import Field
8
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
9
  from typing_extensions import Annotated
10
 
11
  from vllm.config import ModelConfig
12
  from vllm.engine.async_llm_engine import AsyncLLMEngine
13
+ from vllm.entrypoints.logger import RequestLogger
14
+ # yapf conflicts with isort for this block
15
+ # yapf: disable
16
  from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
17
  CompletionRequest,
18
  DetokenizeRequest,
19
  EmbeddingRequest, ErrorResponse,
20
  ModelCard, ModelList,
21
+ ModelPermission,
22
+ TokenizeChatRequest,
23
+ TokenizeCompletionRequest,
24
+ TokenizeRequest)
25
+ # yapf: enable
26
+ from vllm.inputs import parse_and_batch_prompt
27
  from vllm.logger import init_logger
28
  from vllm.lora.request import LoRARequest
29
+ from vllm.pooling_params import PoolingParams
30
+ from vllm.prompt_adapter.request import PromptAdapterRequest
31
+ from vllm.sampling_params import SamplingParams
32
  from vllm.sequence import Logprob
 
33
 
34
  logger = init_logger(__name__)
35
 
36
 
37
  @dataclass
38
+ class PromptAdapterPath:
39
  name: str
40
  local_path: str
41
 
42
 
43
+ @dataclass
44
+ class LoRAModulePath:
45
+ name: str
46
+ path: str
47
+
48
+
49
+ AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
50
+ EmbeddingRequest, TokenizeRequest]
51
+
52
+ AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
53
+
54
+
55
+ class TextTokensPrompt(TypedDict):
56
+ prompt: str
57
+ prompt_token_ids: List[int]
58
+
59
+
60
  class OpenAIServing:
61
 
62
+ def __init__(
63
+ self,
64
+ engine: AsyncLLMEngine,
65
+ model_config: ModelConfig,
66
+ served_model_names: List[str],
67
+ *,
68
+ lora_modules: Optional[List[LoRAModulePath]],
69
+ prompt_adapters: Optional[List[PromptAdapterPath]],
70
+ request_logger: Optional[RequestLogger],
71
+ ):
72
  super().__init__()
73
 
74
  self.engine = engine
75
  self.model_config = model_config
76
  self.max_model_len = model_config.max_model_len
77
 
 
 
 
 
 
 
 
 
78
  self.served_model_names = served_model_names
79
 
80
+ self.lora_requests = []
81
+ if lora_modules is not None:
 
82
  self.lora_requests = [
83
  LoRARequest(
84
  lora_name=lora.name,
85
  lora_int_id=i,
86
+ lora_path=lora.path,
87
  ) for i, lora in enumerate(lora_modules, start=1)
88
  ]
89
 
90
+ self.prompt_adapter_requests = []
91
+ if prompt_adapters is not None:
92
+ for i, prompt_adapter in enumerate(prompt_adapters, start=1):
93
+ with pathlib.Path(prompt_adapter.local_path,
94
+ "adapter_config.json").open() as f:
95
+ adapter_config = json.load(f)
96
+ num_virtual_tokens = adapter_config["num_virtual_tokens"]
97
+ self.prompt_adapter_requests.append(
98
+ PromptAdapterRequest(
99
+ prompt_adapter_name=prompt_adapter.name,
100
+ prompt_adapter_id=i,
101
+ prompt_adapter_local_path=prompt_adapter.local_path,
102
+ prompt_adapter_num_virtual_tokens=num_virtual_tokens))
103
+
104
+ self.request_logger = request_logger
105
+
106
  async def show_available_models(self) -> ModelList:
107
  """Show available models. Right now we only have one model."""
108
  model_cards = [
 
118
  permission=[ModelPermission()])
119
  for lora in self.lora_requests
120
  ]
121
+ prompt_adapter_cards = [
122
+ ModelCard(id=prompt_adapter.prompt_adapter_name,
123
+ root=self.served_model_names[0],
124
+ permission=[ModelPermission()])
125
+ for prompt_adapter in self.prompt_adapter_requests
126
+ ]
127
  model_cards.extend(lora_cards)
128
+ model_cards.extend(prompt_adapter_cards)
129
  return ModelList(data=model_cards)
130
 
131
  def create_error_response(
 
151
  return json_str
152
 
153
  async def _check_model(
154
+ self,
155
+ request: AnyRequest,
 
156
  ) -> Optional[ErrorResponse]:
157
  if request.model in self.served_model_names:
158
  return None
159
  if request.model in [lora.lora_name for lora in self.lora_requests]:
160
  return None
161
+ if request.model in [
162
+ prompt_adapter.prompt_adapter_name
163
+ for prompt_adapter in self.prompt_adapter_requests
164
+ ]:
165
+ return None
166
  return self.create_error_response(
167
  message=f"The model `{request.model}` does not exist.",
168
  err_type="NotFoundError",
169
  status_code=HTTPStatus.NOT_FOUND)
170
 
171
+ def _maybe_get_adapters(
172
+ self, request: AnyRequest
173
+ ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
174
+ None, PromptAdapterRequest]]:
175
  if request.model in self.served_model_names:
176
+ return None, None
177
  for lora in self.lora_requests:
178
  if request.model == lora.lora_name:
179
+ return lora, None
180
+ for prompt_adapter in self.prompt_adapter_requests:
181
+ if request.model == prompt_adapter.prompt_adapter_name:
182
+ return None, prompt_adapter
183
  # if _check_model has been called earlier, this will be unreachable
184
  raise ValueError(f"The model `{request.model}` does not exist.")
185
 
186
+ def _normalize_prompt_text_to_input(
187
+ self,
188
+ request: AnyRequest,
189
+ tokenizer: AnyTokenizer,
190
+ prompt: str,
191
+ truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
192
+ add_special_tokens: bool,
193
+ ) -> TextTokensPrompt:
194
+ if truncate_prompt_tokens is None:
195
+ encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
 
 
 
196
  else:
197
+ encoded = tokenizer(prompt,
198
+ add_special_tokens=add_special_tokens,
199
+ truncation=True,
200
+ max_length=truncate_prompt_tokens)
201
+
202
+ input_ids = encoded.input_ids
203
+
204
+ input_text = prompt
205
+
206
+ return self._validate_input(request, input_ids, input_text)
207
+
208
+ def _normalize_prompt_tokens_to_input(
209
+ self,
210
+ request: AnyRequest,
211
+ tokenizer: AnyTokenizer,
212
+ prompt_ids: List[int],
213
+ truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
214
+ ) -> TextTokensPrompt:
215
+ if truncate_prompt_tokens is None:
216
  input_ids = prompt_ids
217
+ else:
218
+ input_ids = prompt_ids[-truncate_prompt_tokens:]
219
+
220
+ input_text = tokenizer.decode(input_ids)
221
 
222
+ return self._validate_input(request, input_ids, input_text)
223
+
224
+ def _validate_input(
225
+ self,
226
+ request: AnyRequest,
227
+ input_ids: List[int],
228
+ input_text: str,
229
+ ) -> TextTokensPrompt:
230
  token_num = len(input_ids)
231
 
232
  # Note: EmbeddingRequest doesn't have max_tokens
 
236
  f"This model's maximum context length is "
237
  f"{self.max_model_len} tokens. However, you requested "
238
  f"{token_num} tokens in the input for embedding "
239
+ f"generation. Please reduce the length of the input.")
240
+ return TextTokensPrompt(prompt=input_text,
241
+ prompt_token_ids=input_ids)
242
 
243
  # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
244
  # and does not require model context length validation
245
+ if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
246
+ DetokenizeRequest)):
247
+ return TextTokensPrompt(prompt=input_text,
248
+ prompt_token_ids=input_ids)
249
 
250
  if request.max_tokens is None:
251
  if token_num >= self.max_model_len:
 
253
  f"This model's maximum context length is "
254
  f"{self.max_model_len} tokens. However, you requested "
255
  f"{token_num} tokens in the messages, "
256
+ f"Please reduce the length of the messages.")
257
  request.max_tokens = self.max_model_len - token_num
258
 
259
  if token_num + request.max_tokens > self.max_model_len:
 
263
  f"{request.max_tokens + token_num} tokens "
264
  f"({token_num} in the messages, "
265
  f"{request.max_tokens} in the completion). "
266
+ f"Please reduce the length of the messages or completion.")
267
+
268
+ return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
269
+
270
+ def _tokenize_prompt_input(
271
+ self,
272
+ request: AnyRequest,
273
+ tokenizer: AnyTokenizer,
274
+ prompt_input: Union[str, List[int]],
275
+ truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
276
+ add_special_tokens: bool = True,
277
+ ) -> TextTokensPrompt:
278
+ """
279
+ A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
280
+ that assumes single input.
281
+ """
282
+ return next(
283
+ self._tokenize_prompt_inputs(
284
+ request,
285
+ tokenizer,
286
+ [prompt_input],
287
+ truncate_prompt_tokens=truncate_prompt_tokens,
288
+ add_special_tokens=add_special_tokens,
289
+ ))
290
+
291
+ def _tokenize_prompt_inputs(
292
+ self,
293
+ request: AnyRequest,
294
+ tokenizer: AnyTokenizer,
295
+ prompt_inputs: Iterable[Union[str, List[int]]],
296
+ truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
297
+ add_special_tokens: bool = True,
298
+ ) -> Iterator[TextTokensPrompt]:
299
+ """
300
+ A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
301
+ that assumes multiple inputs.
302
+ """
303
+ for text in prompt_inputs:
304
+ if isinstance(text, str):
305
+ yield self._normalize_prompt_text_to_input(
306
+ request,
307
+ tokenizer,
308
+ prompt=text,
309
+ truncate_prompt_tokens=truncate_prompt_tokens,
310
+ add_special_tokens=add_special_tokens,
311
+ )
312
+ else:
313
+ yield self._normalize_prompt_tokens_to_input(
314
+ request,
315
+ tokenizer,
316
+ prompt_ids=text,
317
+ truncate_prompt_tokens=truncate_prompt_tokens,
318
+ )
319
+
320
+ def _tokenize_prompt_input_or_inputs(
321
+ self,
322
+ request: AnyRequest,
323
+ tokenizer: AnyTokenizer,
324
+ input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
325
+ truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
326
+ add_special_tokens: bool = True,
327
+ ) -> Iterator[TextTokensPrompt]:
328
+ """
329
+ Tokenize/detokenize depending on the input format.
330
+
331
+ According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
332
+ , each input can be a string or array of tokens. Note that each request
333
+ can pass one or more inputs.
334
+ """
335
+ for prompt_input in parse_and_batch_prompt(input_or_inputs):
336
+ # Although our type checking is based on mypy,
337
+ # VSCode Pyright extension should still work properly
338
+ # "is True" is required for Pyright to perform type narrowing
339
+ # See: https://github.com/microsoft/pyright/issues/7672
340
+ if prompt_input["is_tokens"] is False:
341
+ yield self._normalize_prompt_text_to_input(
342
+ request,
343
+ tokenizer,
344
+ prompt=prompt_input["content"],
345
+ truncate_prompt_tokens=truncate_prompt_tokens,
346
+ add_special_tokens=add_special_tokens,
347
+ )
348
+ else:
349
+ yield self._normalize_prompt_tokens_to_input(
350
+ request,
351
+ tokenizer,
352
+ prompt_ids=prompt_input["content"],
353
+ truncate_prompt_tokens=truncate_prompt_tokens,
354
+ )
355
+
356
+ def _log_inputs(
357
+ self,
358
+ request_id: str,
359
+ inputs: Union[str, List[int], TextTokensPrompt],
360
+ params: Optional[Union[SamplingParams, PoolingParams]],
361
+ lora_request: Optional[LoRARequest],
362
+ prompt_adapter_request: Optional[PromptAdapterRequest],
363
+ ) -> None:
364
+ if self.request_logger is None:
365
+ return
366
+
367
+ if isinstance(inputs, str):
368
+ prompt = inputs
369
+ prompt_token_ids = None
370
+ elif isinstance(inputs, list):
371
+ prompt = None
372
+ prompt_token_ids = inputs
373
  else:
374
+ prompt = inputs["prompt"]
375
+ prompt_token_ids = inputs["prompt_token_ids"]
376
+
377
+ self.request_logger.log_inputs(
378
+ request_id,
379
+ prompt,
380
+ prompt_token_ids,
381
+ params=params,
382
+ lora_request=lora_request,
383
+ prompt_adapter_request=prompt_adapter_request,
384
+ )
385
 
386
+ @staticmethod
387
+ def _get_decoded_token(
388
+ logprob: Logprob,
389
+ token_id: int,
390
+ tokenizer: AnyTokenizer,
391
+ ) -> str:
392
  if logprob.decoded_token is not None:
393
  return logprob.decoded_token
394
+ return tokenizer.decode(token_id)
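A toy illustration, not from this commit, of the truncation rule in _normalize_prompt_tokens_to_input above (keep only the last truncate_prompt_tokens IDs) and of the TextTokensPrompt shape every helper returns:

    prompt_ids = [1, 2, 3, 4, 5, 6]
    truncate_prompt_tokens = 4
    input_ids = prompt_ids[-truncate_prompt_tokens:]   # keep only the tail
    assert input_ids == [3, 4, 5, 6]

    # shape handed back to callers: prompt text plus token IDs
    text_tokens_prompt = {"prompt": "<decoded text>", "prompt_token_ids": input_ids}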
serving_tokenization.py ADDED
@@ -0,0 +1,135 @@
 
 
1
+ from typing import List, Optional, Union
2
+
3
+ from vllm.config import ModelConfig
4
+ from vllm.engine.async_llm_engine import AsyncLLMEngine
5
+ # yapf conflicts with isort for this block
6
+ # yapf: disable
7
+ from vllm.entrypoints.chat_utils import (ConversationMessage,
8
+ load_chat_template,
9
+ parse_chat_message_content)
10
+ from vllm.entrypoints.logger import RequestLogger
11
+ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
12
+ DetokenizeResponse,
13
+ ErrorResponse,
14
+ TokenizeChatRequest,
15
+ TokenizeRequest,
16
+ TokenizeResponse)
17
+ # yapf: enable
18
+ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
19
+ OpenAIServing)
20
+ from vllm.utils import random_uuid
21
+
22
+
23
+ class OpenAIServingTokenization(OpenAIServing):
24
+
25
+ def __init__(
26
+ self,
27
+ engine: AsyncLLMEngine,
28
+ model_config: ModelConfig,
29
+ served_model_names: List[str],
30
+ *,
31
+ lora_modules: Optional[List[LoRAModulePath]],
32
+ request_logger: Optional[RequestLogger],
33
+ chat_template: Optional[str],
34
+ ):
35
+ super().__init__(engine=engine,
36
+ model_config=model_config,
37
+ served_model_names=served_model_names,
38
+ lora_modules=lora_modules,
39
+ prompt_adapters=None,
40
+ request_logger=request_logger)
41
+
42
+ # If this is None we use the tokenizer's default chat template
43
+ self.chat_template = load_chat_template(chat_template)
44
+
45
+ async def create_tokenize(
46
+ self,
47
+ request: TokenizeRequest,
48
+ ) -> Union[TokenizeResponse, ErrorResponse]:
49
+ error_check_ret = await self._check_model(request)
50
+ if error_check_ret is not None:
51
+ return error_check_ret
52
+
53
+ request_id = f"tokn-{random_uuid()}"
54
+
55
+ (
56
+ lora_request,
57
+ prompt_adapter_request,
58
+ ) = self._maybe_get_adapters(request)
59
+
60
+ tokenizer = await self.engine.get_tokenizer(lora_request)
61
+
62
+ if isinstance(request, TokenizeChatRequest):
63
+ model_config = self.model_config
64
+
65
+ conversation: List[ConversationMessage] = []
66
+
67
+ for message in request.messages:
68
+ result = parse_chat_message_content(message, model_config,
69
+ tokenizer)
70
+ conversation.extend(result.messages)
71
+
72
+ prompt = tokenizer.apply_chat_template(
73
+ add_generation_prompt=request.add_generation_prompt,
74
+ conversation=conversation,
75
+ tokenize=False,
76
+ chat_template=self.chat_template)
77
+ assert isinstance(prompt, str)
78
+ else:
79
+ prompt = request.prompt
80
+
81
+ self._log_inputs(request_id,
82
+ prompt,
83
+ params=None,
84
+ lora_request=lora_request,
85
+ prompt_adapter_request=prompt_adapter_request)
86
+
87
+ # Silently ignore prompt adapter since it does not affect tokenization
88
+
89
+ prompt_input = self._tokenize_prompt_input(
90
+ request,
91
+ tokenizer,
92
+ prompt,
93
+ add_special_tokens=request.add_special_tokens,
94
+ )
95
+ input_ids = prompt_input["prompt_token_ids"]
96
+
97
+ return TokenizeResponse(tokens=input_ids,
98
+ count=len(input_ids),
99
+ max_model_len=self.max_model_len)
100
+
101
+ async def create_detokenize(
102
+ self,
103
+ request: DetokenizeRequest,
104
+ ) -> Union[DetokenizeResponse, ErrorResponse]:
105
+ error_check_ret = await self._check_model(request)
106
+ if error_check_ret is not None:
107
+ return error_check_ret
108
+
109
+ request_id = f"tokn-{random_uuid()}"
110
+
111
+ (
112
+ lora_request,
113
+ prompt_adapter_request,
114
+ ) = self._maybe_get_adapters(request)
115
+
116
+ tokenizer = await self.engine.get_tokenizer(lora_request)
117
+
118
+ self._log_inputs(request_id,
119
+ request.tokens,
120
+ params=None,
121
+ lora_request=lora_request,
122
+ prompt_adapter_request=prompt_adapter_request)
123
+
124
+ if prompt_adapter_request is not None:
125
+ raise NotImplementedError("Prompt adapter is not supported "
126
+ "for tokenization")
127
+
128
+ prompt_input = self._tokenize_prompt_input(
129
+ request,
130
+ tokenizer,
131
+ request.tokens,
132
+ )
133
+ input_text = prompt_input["prompt"]
134
+
135
+ return DetokenizeResponse(prompt=input_text)
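A hedged usage sketch, not from this commit, of the new tokenization handlers, assuming the server from api_server.py runs on localhost:8000, that /detokenize is registered alongside /tokenize, and that the request schemas expose the model, prompt, and tokens fields used above; the model name is a placeholder:

    import requests

    base = "http://localhost:8000"
    model = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # placeholder model name

    tok = requests.post(f"{base}/tokenize",
                        json={"model": model, "prompt": "Hello, world!"}).json()
    print(tok["tokens"], tok["count"], tok["max_model_len"])

    detok = requests.post(f"{base}/detokenize",
                          json={"model": model, "tokens": tok["tokens"]}).json()
    print(detok["prompt"])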