sofianhw committed on
Commit
75da468
1 Parent(s): b7d4623
Files changed (4)
  1. api_server.py +112 -79
  2. protocol.py +501 -66
  3. serving_completion.py +453 -248
  4. serving_engine.py +143 -76
api_server.py CHANGED
@@ -1,30 +1,44 @@
1
- import argparse
2
  import asyncio
3
- import json
 
 
4
  from contextlib import asynccontextmanager
5
- from aioprometheus import MetricsMiddleware
6
- from aioprometheus.asgi.starlette import metrics
 
7
  import fastapi
8
  import uvicorn
9
- from http import HTTPStatus
10
  from fastapi import Request
11
  from fastapi.exceptions import RequestValidationError
12
  from fastapi.middleware.cors import CORSMiddleware
13
- from fastapi.responses import JSONResponse, StreamingResponse, Response
 
 
14
 
 
 
15
  from vllm.engine.arg_utils import AsyncEngineArgs
16
  from vllm.engine.async_llm_engine import AsyncLLMEngine
17
- from vllm.engine.metrics import add_global_metrics_labels
18
- from protocol import CompletionRequest, ChatCompletionRequest, ErrorResponse
19
  from vllm.logger import init_logger
20
- from serving_chat import OpenAIServingChat
21
- from serving_completion import OpenAIServingCompletion
22
 
23
  TIMEOUT_KEEP_ALIVE = 5 # seconds
24
 
25
- openai_serving_chat: OpenAIServingChat = None
26
- openai_serving_completion: OpenAIServingCompletion = None
27
- logger = init_logger(__name__)
28
 
29
 
30
  @asynccontextmanager
@@ -36,7 +50,9 @@ async def lifespan(app: fastapi.FastAPI):
36
  await engine.do_log_stats()
37
 
38
  if not engine_args.disable_log_stats:
39
- asyncio.create_task(_force_log())
 
 
40
 
41
  yield
42
 
@@ -45,62 +61,15 @@ app = fastapi.FastAPI(lifespan=lifespan)
45
 
46
 
47
  def parse_args():
48
- parser = argparse.ArgumentParser(
49
- description="vLLM OpenAI-Compatible RESTful API server.")
50
- parser.add_argument("--host", type=str, default=None, help="host name")
51
- parser.add_argument("--port", type=int, default=8000, help="port number")
52
- parser.add_argument("--allow-credentials",
53
- action="store_true",
54
- help="allow credentials")
55
- parser.add_argument("--allowed-origins",
56
- type=json.loads,
57
- default=["*"],
58
- help="allowed origins")
59
- parser.add_argument("--allowed-methods",
60
- type=json.loads,
61
- default=["*"],
62
- help="allowed methods")
63
- parser.add_argument("--allowed-headers",
64
- type=json.loads,
65
- default=["*"],
66
- help="allowed headers")
67
- parser.add_argument("--served-model-name",
68
- type=str,
69
- default=None,
70
- help="The model name used in the API. If not "
71
- "specified, the model name will be the same as "
72
- "the huggingface name.")
73
- parser.add_argument("--chat-template",
74
- type=str,
75
- default=None,
76
- help="The file path to the chat template, "
77
- "or the template in single-line form "
78
- "for the specified model")
79
- parser.add_argument("--response-role",
80
- type=str,
81
- default="assistant",
82
- help="The role name to return if "
83
- "`request.add_generation_prompt=true`.")
84
- parser.add_argument("--ssl-keyfile",
85
- type=str,
86
- default=None,
87
- help="The file path to the SSL key file")
88
- parser.add_argument("--ssl-certfile",
89
- type=str,
90
- default=None,
91
- help="The file path to the SSL cert file")
92
- parser.add_argument(
93
- "--root-path",
94
- type=str,
95
- default=None,
96
- help="FastAPI root_path when app is behind a path based routing proxy")
97
-
98
- parser = AsyncEngineArgs.add_cli_args(parser)
99
  return parser.parse_args()
100
 
101
 
102
- app.add_middleware(MetricsMiddleware) # Trace HTTP server metrics
103
- app.add_route("/metrics", metrics) # Exposes HTTP metrics
 
 
 
104
 
105
 
106
  @app.exception_handler(RequestValidationError)
@@ -112,6 +81,7 @@ async def validation_exception_handler(_, exc):
112
  @app.get("/health")
113
  async def health() -> Response:
114
  """Health check."""
 
115
  return Response(status_code=200)
116
 
117
 
@@ -121,6 +91,12 @@ async def show_available_models():
121
  return JSONResponse(content=models.model_dump())
122
 
123
 
 
 
 
 
 
 
124
  @app.post("/api/v1/chat/completions")
125
  async def create_chat_completion(request: ChatCompletionRequest,
126
  raw_request: Request):
@@ -133,6 +109,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
133
  return StreamingResponse(content=generator,
134
  media_type="text/event-stream")
135
  else:
 
136
  return JSONResponse(content=generator.model_dump())
137
 
138
 
@@ -150,6 +127,17 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
150
  return JSONResponse(content=generator.model_dump())
151
 
152
153
  if __name__ == "__main__":
154
  args = parse_args()
155
 
@@ -161,28 +149,73 @@ if __name__ == "__main__":
161
  allow_headers=args.allowed_headers,
162
  )
163
 
164
- logger.info(f"args: {args}")
165
 
166
  if args.served_model_name is not None:
167
- served_model = args.served_model_name
168
  else:
169
- served_model = args.model
170
 
171
  engine_args = AsyncEngineArgs.from_cli_args(args)
172
- engine = AsyncLLMEngine.from_engine_args(engine_args)
173
- openai_serving_chat = OpenAIServingChat(engine, served_model,
174
  args.response_role,
 
175
  args.chat_template)
176
- openai_serving_completion = OpenAIServingCompletion(engine, served_model)
177
-
178
- # Register labels for metrics
179
- add_global_metrics_labels(model_name=engine_args.model)
180
-
181
  app.root_path = args.root_path
182
  uvicorn.run(app,
183
  host=args.host,
184
  port=args.port,
185
- log_level="info",
186
  timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
187
  ssl_keyfile=args.ssl_keyfile,
188
- ssl_certfile=args.ssl_certfile)
 
 
 
 
1
  import asyncio
2
+ import importlib
3
+ import inspect
4
+ import re
5
  from contextlib import asynccontextmanager
6
+ from http import HTTPStatus
7
+ from typing import Optional, Set
8
+
9
  import fastapi
10
  import uvicorn
 
11
  from fastapi import Request
12
  from fastapi.exceptions import RequestValidationError
13
  from fastapi.middleware.cors import CORSMiddleware
14
+ from fastapi.responses import JSONResponse, Response, StreamingResponse
15
+ from prometheus_client import make_asgi_app
16
+ from starlette.routing import Mount
17
 
18
+ import vllm
19
+ import vllm.envs as envs
20
  from vllm.engine.arg_utils import AsyncEngineArgs
21
  from vllm.engine.async_llm_engine import AsyncLLMEngine
22
+ from vllm.entrypoints.openai.cli_args import make_arg_parser
23
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
24
+ ChatCompletionResponse,
25
+ CompletionRequest,
26
+ EmbeddingRequest, ErrorResponse)
27
+ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
28
+ from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
29
+ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
30
  from vllm.logger import init_logger
31
+ from vllm.usage.usage_lib import UsageContext
 
32
 
33
  TIMEOUT_KEEP_ALIVE = 5 # seconds
34
 
35
+ openai_serving_chat: OpenAIServingChat
36
+ openai_serving_completion: OpenAIServingCompletion
37
+ openai_serving_embedding: OpenAIServingEmbedding
38
+
39
+ logger = init_logger('vllm.entrypoints.openai.api_server')
40
+
41
+ _running_tasks: Set[asyncio.Task] = set()
42
 
43
 
44
  @asynccontextmanager
 
50
  await engine.do_log_stats()
51
 
52
  if not engine_args.disable_log_stats:
53
+ task = asyncio.create_task(_force_log())
54
+ _running_tasks.add(task)
55
+ task.add_done_callback(_running_tasks.remove)
56
 
57
  yield
58
 
 
61
 
62
 
63
  def parse_args():
64
+ parser = make_arg_parser()
65
  return parser.parse_args()
66
 
67
 
68
+ # Add prometheus asgi middleware to route /metrics requests
69
+ route = Mount("/metrics", make_asgi_app())
70
+ # Workaround for 307 Redirect for /metrics
71
+ route.path_regex = re.compile('^/metrics(?P<path>.*)$')
72
+ app.routes.append(route)
73
 
74
 
75
  @app.exception_handler(RequestValidationError)
 
81
  @app.get("/health")
82
  async def health() -> Response:
83
  """Health check."""
84
+ await openai_serving_chat.engine.check_health()
85
  return Response(status_code=200)
86
 
87
 
 
91
  return JSONResponse(content=models.model_dump())
92
 
93
 
94
+ @app.get("/version")
95
+ async def show_version():
96
+ ver = {"version": vllm.__version__}
97
+ return JSONResponse(content=ver)
98
+
99
+
100
  @app.post("/api/v1/chat/completions")
101
  async def create_chat_completion(request: ChatCompletionRequest,
102
  raw_request: Request):
 
109
  return StreamingResponse(content=generator,
110
  media_type="text/event-stream")
111
  else:
112
+ assert isinstance(generator, ChatCompletionResponse)
113
  return JSONResponse(content=generator.model_dump())
114
 
115
 
 
127
  return JSONResponse(content=generator.model_dump())
128
 
129
 
130
+ @app.post("/api/v1/embeddings")
131
+ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
132
+ generator = await openai_serving_embedding.create_embedding(
133
+ request, raw_request)
134
+ if isinstance(generator, ErrorResponse):
135
+ return JSONResponse(content=generator.model_dump(),
136
+ status_code=generator.code)
137
+ else:
138
+ return JSONResponse(content=generator.model_dump())
139
+
140
+
141
  if __name__ == "__main__":
142
  args = parse_args()
143
 
 
149
  allow_headers=args.allowed_headers,
150
  )
151
 
152
+ if token := envs.VLLM_API_KEY or args.api_key:
153
+
154
+ @app.middleware("http")
155
+ async def authentication(request: Request, call_next):
156
+ root_path = "" if args.root_path is None else args.root_path
157
+ if request.method == "OPTIONS":
158
+ return await call_next(request)
159
+ if not request.url.path.startswith(f"{root_path}/v1"):
160
+ return await call_next(request)
161
+ if request.headers.get("Authorization") != "Bearer " + token:
162
+ return JSONResponse(content={"error": "Unauthorized"},
163
+ status_code=401)
164
+ return await call_next(request)
165
+
166
+ for middleware in args.middleware:
167
+ module_path, object_name = middleware.rsplit(".", 1)
168
+ imported = getattr(importlib.import_module(module_path), object_name)
169
+ if inspect.isclass(imported):
170
+ app.add_middleware(imported)
171
+ elif inspect.iscoroutinefunction(imported):
172
+ app.middleware("http")(imported)
173
+ else:
174
+ raise ValueError(f"Invalid middleware {middleware}. "
175
+ f"Must be a function or a class.")
176
+
177
+ logger.info("vLLM API server version %s", vllm.__version__)
178
+ logger.info("args: %s", args)
179
 
180
  if args.served_model_name is not None:
181
+ served_model_names = args.served_model_name
182
  else:
183
+ served_model_names = [args.model]
184
 
185
  engine_args = AsyncEngineArgs.from_cli_args(args)
186
+ engine = AsyncLLMEngine.from_engine_args(
187
+ engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
188
+
189
+ event_loop: Optional[asyncio.AbstractEventLoop]
190
+ try:
191
+ event_loop = asyncio.get_running_loop()
192
+ except RuntimeError:
193
+ event_loop = None
194
+
195
+ if event_loop is not None and event_loop.is_running():
196
+ # If the current is instanced by Ray Serve,
197
+ # there is already a running event loop
198
+ model_config = event_loop.run_until_complete(engine.get_model_config())
199
+ else:
200
+ # When using single vLLM without engine_use_ray
201
+ model_config = asyncio.run(engine.get_model_config())
202
+
203
+ openai_serving_chat = OpenAIServingChat(engine, model_config,
204
+ served_model_names,
205
  args.response_role,
206
+ args.lora_modules,
207
  args.chat_template)
208
+ openai_serving_completion = OpenAIServingCompletion(
209
+ engine, model_config, served_model_names, args.lora_modules)
210
+ openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
211
+ served_model_names)
 
212
  app.root_path = args.root_path
213
  uvicorn.run(app,
214
  host=args.host,
215
  port=args.port,
216
+ log_level=args.uvicorn_log_level,
217
  timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
218
  ssl_keyfile=args.ssl_keyfile,
219
+ ssl_certfile=args.ssl_certfile,
220
+ ssl_ca_certs=args.ssl_ca_certs,
221
+ ssl_cert_reqs=args.ssl_cert_reqs)
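
For reviewers, a quick way to see what the rewritten api_server.py exposes: the sketch below is an editor's illustration (not part of the commit) that exercises the /health, /version, and /api/v1/chat/completions routes defined above. It assumes a locally running server on the default port 8000; the base URL, model name, and bearer token are placeholder values, and the Authorization header only matters if an API key was configured via --api-key or VLLM_API_KEY.

```python
# Editor's sketch: exercising the routes defined in the new api_server.py.
# Assumes a local server on the default port; names below are placeholders.
import requests

BASE = "http://localhost:8000"
HEADERS = {"Authorization": "Bearer my-secret-key"}  # only needed if an API key is set

print(requests.get(f"{BASE}/health").status_code)  # 200 when the engine is healthy
print(requests.get(f"{BASE}/version").json())      # {"version": "<vllm version>"}

resp = requests.post(
    f"{BASE}/api/v1/chat/completions",
    headers=HEADERS,
    json={
        "model": "my-model",  # placeholder served model name
        "messages": [{"role": "user", "content": "Hello!"}],
    },
)
print(resp.json())
```
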
protocol.py CHANGED
@@ -1,15 +1,58 @@
1
  # Adapted from
2
  # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
3
  import time
4
- from typing import Dict, List, Literal, Optional, Union
5
 
6
- from pydantic import BaseModel, Field
 
 
 
 
7
 
8
- from vllm.utils import random_uuid
9
  from vllm.sampling_params import SamplingParams
10
 
11
 
12
- class ErrorResponse(BaseModel):
13
  object: str = "error"
14
  message: str
15
  type: str
@@ -17,7 +60,7 @@ class ErrorResponse(BaseModel):
17
  code: int
18
 
19
 
20
- class ModelPermission(BaseModel):
21
  id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
22
  object: str = "model_permission"
23
  created: int = Field(default_factory=lambda: int(time.time()))
@@ -29,57 +72,176 @@ class ModelPermission(BaseModel):
29
  allow_fine_tuning: bool = False
30
  organization: str = "*"
31
  group: Optional[str] = None
32
- is_blocking: str = False
33
 
34
 
35
- class ModelCard(BaseModel):
36
  id: str
37
  object: str = "model"
38
  created: int = Field(default_factory=lambda: int(time.time()))
39
  owned_by: str = "vllm"
40
  root: Optional[str] = None
41
  parent: Optional[str] = None
 
42
  permission: List[ModelPermission] = Field(default_factory=list)
43
 
44
 
45
- class ModelList(BaseModel):
46
  object: str = "list"
47
  data: List[ModelCard] = Field(default_factory=list)
48
 
49
 
50
- class UsageInfo(BaseModel):
51
  prompt_tokens: int = 0
52
  total_tokens: int = 0
53
  completion_tokens: Optional[int] = 0
54
 
55
 
56
- class ChatCompletionRequest(BaseModel):
57
  model: str
58
- messages: Union[str, List[Dict[str, str]]]
59
- temperature: Optional[float] = 0.7
60
- top_p: Optional[float] = 1.0
61
- n: Optional[int] = 1
62
  max_tokens: Optional[int] = None
63
  stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
64
  stream: Optional[bool] = False
65
- presence_penalty: Optional[float] = 0.0
66
- frequency_penalty: Optional[float] = 0.0
67
- logit_bias: Optional[Dict[str, float]] = None
 
 
68
  user: Optional[str] = None
69
- # Additional parameters supported by vLLM
 
70
  best_of: Optional[int] = None
 
71
  top_k: Optional[int] = -1
 
 
 
 
72
  ignore_eos: Optional[bool] = False
73
- use_beam_search: Optional[bool] = False
74
  stop_token_ids: Optional[List[int]] = Field(default_factory=list)
75
  skip_special_tokens: Optional[bool] = True
76
  spaces_between_special_tokens: Optional[bool] = True
77
- add_generation_prompt: Optional[bool] = True
78
- echo: Optional[bool] = False
79
- repetition_penalty: Optional[float] = 1.0
80
- min_p: Optional[float] = 0.0
81
 
82
  def to_sampling_params(self) -> SamplingParams:
83
  return SamplingParams(
84
  n=self.n,
85
  presence_penalty=self.presence_penalty,
@@ -88,49 +250,173 @@ class ChatCompletionRequest(BaseModel):
88
  temperature=self.temperature,
89
  top_p=self.top_p,
90
  min_p=self.min_p,
 
91
  stop=self.stop,
92
  stop_token_ids=self.stop_token_ids,
93
  max_tokens=self.max_tokens,
 
 
 
94
  best_of=self.best_of,
95
  top_k=self.top_k,
96
  ignore_eos=self.ignore_eos,
97
  use_beam_search=self.use_beam_search,
 
98
  skip_special_tokens=self.skip_special_tokens,
99
  spaces_between_special_tokens=self.spaces_between_special_tokens,
 
 
 
100
  )
101
 
102
-
103
- class CompletionRequest(BaseModel):
104
  model: str
105
- # a string, array of strings, array of tokens, or array of token arrays
106
  prompt: Union[List[int], List[List[int]], str, List[str]]
107
- suffix: Optional[str] = None
108
- max_tokens: Optional[int] = 16
109
- temperature: Optional[float] = 1.0
110
- top_p: Optional[float] = 1.0
111
- n: Optional[int] = 1
112
- stream: Optional[bool] = False
113
- logprobs: Optional[int] = None
114
  echo: Optional[bool] = False
115
- stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
116
- presence_penalty: Optional[float] = 0.0
117
  frequency_penalty: Optional[float] = 0.0
118
- best_of: Optional[int] = None
119
  logit_bias: Optional[Dict[str, float]] = None
120
  user: Optional[str] = None
121
- # Additional parameters supported by vLLM
122
- top_k: Optional[int] = -1
123
- ignore_eos: Optional[bool] = False
124
  use_beam_search: Optional[bool] = False
125
  stop_token_ids: Optional[List[int]] = Field(default_factory=list)
 
 
126
  skip_special_tokens: Optional[bool] = True
127
  spaces_between_special_tokens: Optional[bool] = True
128
- repetition_penalty: Optional[float] = 1.0
129
- min_p: Optional[float] = 0.0
130
 
131
  def to_sampling_params(self):
132
  echo_without_generation = self.echo and self.max_tokens == 0
133
134
  return SamplingParams(
135
  n=self.n,
136
  best_of=self.best_of,
@@ -141,33 +427,88 @@ class CompletionRequest(BaseModel):
141
  top_p=self.top_p,
142
  top_k=self.top_k,
143
  min_p=self.min_p,
 
144
  stop=self.stop,
145
  stop_token_ids=self.stop_token_ids,
146
  ignore_eos=self.ignore_eos,
147
  max_tokens=self.max_tokens if not echo_without_generation else 1,
 
148
  logprobs=self.logprobs,
149
  use_beam_search=self.use_beam_search,
 
150
  prompt_logprobs=self.logprobs if self.echo else None,
151
  skip_special_tokens=self.skip_special_tokens,
152
  spaces_between_special_tokens=(self.spaces_between_special_tokens),
 
 
 
 
153
  )
154
155
 
156
- class LogProbs(BaseModel):
157
  text_offset: List[int] = Field(default_factory=list)
158
  token_logprobs: List[Optional[float]] = Field(default_factory=list)
159
  tokens: List[str] = Field(default_factory=list)
160
- top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None
161
 
162
 
163
- class CompletionResponseChoice(BaseModel):
164
  index: int
165
  text: str
166
- logprobs: Optional[LogProbs] = None
167
- finish_reason: Optional[Literal["stop", "length"]] = None
168
-
169
-
170
- class CompletionResponse(BaseModel):
171
  id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
172
  object: str = "text_completion"
173
  created: int = Field(default_factory=lambda: int(time.time()))
@@ -176,14 +517,21 @@ class CompletionResponse(BaseModel):
176
  usage: UsageInfo
177
 
178
 
179
- class CompletionResponseStreamChoice(BaseModel):
180
  index: int
181
  text: str
182
- logprobs: Optional[LogProbs] = None
183
- finish_reason: Optional[Literal["stop", "length"]] = None
184
-
185
-
186
- class CompletionStreamResponse(BaseModel):
187
  id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
188
  object: str = "text_completion"
189
  created: int = Field(default_factory=lambda: int(time.time()))
@@ -192,41 +540,128 @@ class CompletionStreamResponse(BaseModel):
192
  usage: Optional[UsageInfo] = Field(default=None)
193
 
194
 
195
- class ChatMessage(BaseModel):
196
  role: str
197
  content: str
198
 
199
 
200
- class ChatCompletionResponseChoice(BaseModel):
201
  index: int
202
  message: ChatMessage
203
- finish_reason: Optional[Literal["stop", "length"]] = None
 
 
204
 
205
 
206
- class ChatCompletionResponse(BaseModel):
207
  id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
208
- object: str = "chat.completion"
209
  created: int = Field(default_factory=lambda: int(time.time()))
210
  model: str
211
  choices: List[ChatCompletionResponseChoice]
212
  usage: UsageInfo
213
 
214
 
215
- class DeltaMessage(BaseModel):
216
  role: Optional[str] = None
217
  content: Optional[str] = None
 
218
 
219
 
220
- class ChatCompletionResponseStreamChoice(BaseModel):
221
  index: int
222
  delta: DeltaMessage
223
- finish_reason: Optional[Literal["stop", "length"]] = None
 
 
224
 
225
 
226
- class ChatCompletionStreamResponse(BaseModel):
227
  id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
228
- object: str = "chat.completion.chunk"
229
  created: int = Field(default_factory=lambda: int(time.time()))
230
  model: str
231
  choices: List[ChatCompletionResponseStreamChoice]
232
- usage: Optional[UsageInfo] = Field(default=None)
1
  # Adapted from
2
  # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
3
  import time
4
+ from typing import Any, Dict, List, Literal, Optional, Union
5
 
6
+ import openai.types.chat
7
+ import torch
8
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
9
+ # pydantic needs the TypedDict from typing_extensions
10
+ from typing_extensions import Annotated, Required, TypedDict
11
 
12
+ from vllm.pooling_params import PoolingParams
13
  from vllm.sampling_params import SamplingParams
14
+ from vllm.utils import random_uuid
15
+
16
+
17
+ class CustomChatCompletionContentPartParam(TypedDict, total=False):
18
+ __pydantic_config__ = ConfigDict(extra="allow") # type: ignore
19
+
20
+ type: Required[str]
21
+ """The type of the content part."""
22
+
23
+
24
+ ChatCompletionContentPartParam = Union[
25
+ openai.types.chat.ChatCompletionContentPartParam,
26
+ CustomChatCompletionContentPartParam]
27
+
28
+
29
+ class CustomChatCompletionMessageParam(TypedDict, total=False):
30
+ """Enables custom roles in the Chat Completion API."""
31
+ role: Required[str]
32
+ """The role of the message's author."""
33
+
34
+ content: Union[str, List[ChatCompletionContentPartParam]]
35
+ """The contents of the message."""
36
+
37
+ name: str
38
+ """An optional name for the participant.
39
+
40
+ Provides the model information to differentiate between participants of the
41
+ same role.
42
+ """
43
+
44
+
45
+ ChatCompletionMessageParam = Union[
46
+ openai.types.chat.ChatCompletionMessageParam,
47
+ CustomChatCompletionMessageParam]
48
+
49
+
50
+ class OpenAIBaseModel(BaseModel):
51
+ # OpenAI API does not allow extra fields
52
+ model_config = ConfigDict(extra="forbid")
53
 
54
 
55
+ class ErrorResponse(OpenAIBaseModel):
56
  object: str = "error"
57
  message: str
58
  type: str
 
60
  code: int
61
 
62
 
63
+ class ModelPermission(OpenAIBaseModel):
64
  id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
65
  object: str = "model_permission"
66
  created: int = Field(default_factory=lambda: int(time.time()))
 
72
  allow_fine_tuning: bool = False
73
  organization: str = "*"
74
  group: Optional[str] = None
75
+ is_blocking: bool = False
76
 
77
 
78
+ class ModelCard(OpenAIBaseModel):
79
  id: str
80
  object: str = "model"
81
  created: int = Field(default_factory=lambda: int(time.time()))
82
  owned_by: str = "vllm"
83
  root: Optional[str] = None
84
  parent: Optional[str] = None
85
+ max_model_len: Optional[int] = None
86
  permission: List[ModelPermission] = Field(default_factory=list)
87
 
88
 
89
+ class ModelList(OpenAIBaseModel):
90
  object: str = "list"
91
  data: List[ModelCard] = Field(default_factory=list)
92
 
93
 
94
+ class UsageInfo(OpenAIBaseModel):
95
  prompt_tokens: int = 0
96
  total_tokens: int = 0
97
  completion_tokens: Optional[int] = 0
98
 
99
 
100
+ class ResponseFormat(OpenAIBaseModel):
101
+ # type must be "json_object" or "text"
102
+ type: Literal["text", "json_object"]
103
+
104
+
105
+ class FunctionDefinition(OpenAIBaseModel):
106
+ name: str
107
+ description: Optional[str] = None
108
+ parameters: Optional[Dict[str, Any]] = None
109
+
110
+
111
+ class ChatCompletionToolsParam(OpenAIBaseModel):
112
+ type: Literal["function"] = "function"
113
+ function: FunctionDefinition
114
+
115
+
116
+ class ChatCompletionNamedFunction(OpenAIBaseModel):
117
+ name: str
118
+
119
+
120
+ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
121
+ function: ChatCompletionNamedFunction
122
+ type: Literal["function"] = "function"
123
+
124
+
125
+ class ChatCompletionRequest(OpenAIBaseModel):
126
+ # Ordered by official OpenAI API documentation
127
+ # https://platform.openai.com/docs/api-reference/chat/create
128
+ messages: List[ChatCompletionMessageParam]
129
  model: str
130
+ frequency_penalty: Optional[float] = 0.0
131
+ logit_bias: Optional[Dict[str, float]] = None
132
+ logprobs: Optional[bool] = False
133
+ top_logprobs: Optional[int] = 0
134
  max_tokens: Optional[int] = None
135
+ n: Optional[int] = 1
136
+ presence_penalty: Optional[float] = 0.0
137
+ response_format: Optional[ResponseFormat] = None
138
+ seed: Optional[int] = Field(None,
139
+ ge=torch.iinfo(torch.long).min,
140
+ le=torch.iinfo(torch.long).max)
141
  stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
142
  stream: Optional[bool] = False
143
+ temperature: Optional[float] = 0.7
144
+ top_p: Optional[float] = 1.0
145
+ tools: Optional[List[ChatCompletionToolsParam]] = None
146
+ tool_choice: Optional[Union[Literal["none"],
147
+ ChatCompletionNamedToolChoiceParam]] = "none"
148
  user: Optional[str] = None
149
+
150
+ # doc: begin-chat-completion-sampling-params
151
  best_of: Optional[int] = None
152
+ use_beam_search: Optional[bool] = False
153
  top_k: Optional[int] = -1
154
+ min_p: Optional[float] = 0.0
155
+ repetition_penalty: Optional[float] = 1.0
156
+ length_penalty: Optional[float] = 1.0
157
+ early_stopping: Optional[bool] = False
158
  ignore_eos: Optional[bool] = False
159
+ min_tokens: Optional[int] = 0
160
  stop_token_ids: Optional[List[int]] = Field(default_factory=list)
161
  skip_special_tokens: Optional[bool] = True
162
  spaces_between_special_tokens: Optional[bool] = True
163
+ # doc: end-chat-completion-sampling-params
164
+
165
+ # doc: begin-chat-completion-extra-params
166
+ echo: Optional[bool] = Field(
167
+ default=False,
168
+ description=(
169
+ "If true, the new message will be prepended with the last message "
170
+ "if they belong to the same role."),
171
+ )
172
+ add_generation_prompt: Optional[bool] = Field(
173
+ default=True,
174
+ description=
175
+ ("If true, the generation prompt will be added to the chat template. "
176
+ "This is a parameter used by chat template in tokenizer config of the "
177
+ "model."),
178
+ )
179
+ add_special_tokens: Optional[bool] = Field(
180
+ default=False,
181
+ description=(
182
+ "If true, special tokens (e.g. BOS) will be added to the prompt "
183
+ "on top of what is added by the chat template. "
184
+ "For most models, the chat template takes care of adding the "
185
+ "special tokens so this should be set to False (as is the "
186
+ "default)."),
187
+ )
188
+ include_stop_str_in_output: Optional[bool] = Field(
189
+ default=False,
190
+ description=(
191
+ "Whether to include the stop string in the output. "
192
+ "This is only applied when the stop or stop_token_ids is set."),
193
+ )
194
+ guided_json: Optional[Union[str, dict, BaseModel]] = Field(
195
+ default=None,
196
+ description=("If specified, the output will follow the JSON schema."),
197
+ )
198
+ guided_regex: Optional[str] = Field(
199
+ default=None,
200
+ description=(
201
+ "If specified, the output will follow the regex pattern."),
202
+ )
203
+ guided_choice: Optional[List[str]] = Field(
204
+ default=None,
205
+ description=(
206
+ "If specified, the output will be exactly one of the choices."),
207
+ )
208
+ guided_grammar: Optional[str] = Field(
209
+ default=None,
210
+ description=(
211
+ "If specified, the output will follow the context free grammar."),
212
+ )
213
+ guided_decoding_backend: Optional[str] = Field(
214
+ default=None,
215
+ description=(
216
+ "If specified, will override the default guided decoding backend "
217
+ "of the server for this specific request. If set, must be either "
218
+ "'outlines' / 'lm-format-enforcer'"))
219
+ guided_whitespace_pattern: Optional[str] = Field(
220
+ default=None,
221
+ description=(
222
+ "If specified, will override the default whitespace pattern "
223
+ "for guided json decoding."))
224
+
225
+ # doc: end-chat-completion-extra-params
226
 
227
  def to_sampling_params(self) -> SamplingParams:
228
+ # We now allow logprobs being true without top_logprobs.
229
+
230
+ logits_processors = None
231
+ if self.logit_bias:
232
+
233
+ def logit_bias_logits_processor(
234
+ token_ids: List[int],
235
+ logits: torch.Tensor) -> torch.Tensor:
236
+ assert self.logit_bias is not None
237
+ for token_id, bias in self.logit_bias.items():
238
+ # Clamp the bias between -100 and 100 per OpenAI API spec
239
+ bias = min(100, max(-100, bias))
240
+ logits[int(token_id)] += bias
241
+ return logits
242
+
243
+ logits_processors = [logit_bias_logits_processor]
244
+
245
  return SamplingParams(
246
  n=self.n,
247
  presence_penalty=self.presence_penalty,
 
250
  temperature=self.temperature,
251
  top_p=self.top_p,
252
  min_p=self.min_p,
253
+ seed=self.seed,
254
  stop=self.stop,
255
  stop_token_ids=self.stop_token_ids,
256
  max_tokens=self.max_tokens,
257
+ min_tokens=self.min_tokens,
258
+ logprobs=self.top_logprobs if self.logprobs else None,
259
+ prompt_logprobs=self.top_logprobs if self.echo else None,
260
  best_of=self.best_of,
261
  top_k=self.top_k,
262
  ignore_eos=self.ignore_eos,
263
  use_beam_search=self.use_beam_search,
264
+ early_stopping=self.early_stopping,
265
  skip_special_tokens=self.skip_special_tokens,
266
  spaces_between_special_tokens=self.spaces_between_special_tokens,
267
+ include_stop_str_in_output=self.include_stop_str_in_output,
268
+ length_penalty=self.length_penalty,
269
+ logits_processors=logits_processors,
270
  )
271
 
272
+ @model_validator(mode="before")
273
+ @classmethod
274
+ def check_guided_decoding_count(cls, data):
275
+ guide_count = sum([
276
+ "guided_json" in data and data["guided_json"] is not None,
277
+ "guided_regex" in data and data["guided_regex"] is not None,
278
+ "guided_choice" in data and data["guided_choice"] is not None
279
+ ])
280
+ # you can only use one kind of guided decoding
281
+ if guide_count > 1:
282
+ raise ValueError(
283
+ "You can only use one kind of guided decoding "
284
+ "('guided_json', 'guided_regex' or 'guided_choice').")
285
+ # you can only either use guided decoding or tools, not both
286
+ if guide_count > 1 and "tool_choice" in data and data[
287
+ "tool_choice"] != "none":
288
+ raise ValueError(
289
+ "You can only either use guided decoding or tools, not both.")
290
+ return data
291
+
292
+ @model_validator(mode="before")
293
+ @classmethod
294
+ def check_tool_choice(cls, data):
295
+ if "tool_choice" in data and data["tool_choice"] != "none":
296
+ if not isinstance(data["tool_choice"], dict):
297
+ raise ValueError("Currently only named tools are supported.")
298
+ if "tools" not in data or data["tools"] is None:
299
+ raise ValueError(
300
+ "When using `tool_choice`, `tools` must be set.")
301
+ return data
302
+
303
+ @model_validator(mode="before")
304
+ @classmethod
305
+ def check_logprobs(cls, data):
306
+ if "top_logprobs" in data and data["top_logprobs"] is not None:
307
+ if "logprobs" not in data or data["logprobs"] is False:
308
+ raise ValueError(
309
+ "when using `top_logprobs`, `logprobs` must be set to true."
310
+ )
311
+ elif not 0 <= data["top_logprobs"] <= 20:
312
+ raise ValueError(
313
+ "`top_logprobs` must be a value in the interval [0, 20].")
314
+ return data
315
+
316
+
317
+ class CompletionRequest(OpenAIBaseModel):
318
+ # Ordered by official OpenAI API documentation
319
+ # https://platform.openai.com/docs/api-reference/completions/create
320
  model: str
 
321
  prompt: Union[List[int], List[List[int]], str, List[str]]
322
+ best_of: Optional[int] = None
323
  echo: Optional[bool] = False
 
 
324
  frequency_penalty: Optional[float] = 0.0
 
325
  logit_bias: Optional[Dict[str, float]] = None
326
+ logprobs: Optional[int] = None
327
+ max_tokens: Optional[int] = 16
328
+ n: int = 1
329
+ presence_penalty: Optional[float] = 0.0
330
+ seed: Optional[int] = Field(None,
331
+ ge=torch.iinfo(torch.long).min,
332
+ le=torch.iinfo(torch.long).max)
333
+ stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
334
+ stream: Optional[bool] = False
335
+ suffix: Optional[str] = None
336
+ temperature: Optional[float] = 1.0
337
+ top_p: Optional[float] = 1.0
338
  user: Optional[str] = None
339
+
340
+ # doc: begin-completion-sampling-params
 
341
  use_beam_search: Optional[bool] = False
342
+ top_k: Optional[int] = -1
343
+ min_p: Optional[float] = 0.0
344
+ repetition_penalty: Optional[float] = 1.0
345
+ length_penalty: Optional[float] = 1.0
346
+ early_stopping: Optional[bool] = False
347
  stop_token_ids: Optional[List[int]] = Field(default_factory=list)
348
+ ignore_eos: Optional[bool] = False
349
+ min_tokens: Optional[int] = 0
350
  skip_special_tokens: Optional[bool] = True
351
  spaces_between_special_tokens: Optional[bool] = True
352
+ truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
353
+ # doc: end-completion-sampling-params
354
+
355
+ # doc: begin-completion-extra-params
356
+ include_stop_str_in_output: Optional[bool] = Field(
357
+ default=False,
358
+ description=(
359
+ "Whether to include the stop string in the output. "
360
+ "This is only applied when the stop or stop_token_ids is set."),
361
+ )
362
+ response_format: Optional[ResponseFormat] = Field(
363
+ default=None,
364
+ description=
365
+ ("Similar to chat completion, this parameter specifies the format of "
366
+ "output. Only {'type': 'json_object'} or {'type': 'text' } is "
367
+ "supported."),
368
+ )
369
+ guided_json: Optional[Union[str, dict, BaseModel]] = Field(
370
+ default=None,
371
+ description=("If specified, the output will follow the JSON schema."),
372
+ )
373
+ guided_regex: Optional[str] = Field(
374
+ default=None,
375
+ description=(
376
+ "If specified, the output will follow the regex pattern."),
377
+ )
378
+ guided_choice: Optional[List[str]] = Field(
379
+ default=None,
380
+ description=(
381
+ "If specified, the output will be exactly one of the choices."),
382
+ )
383
+ guided_grammar: Optional[str] = Field(
384
+ default=None,
385
+ description=(
386
+ "If specified, the output will follow the context free grammar."),
387
+ )
388
+ guided_decoding_backend: Optional[str] = Field(
389
+ default=None,
390
+ description=(
391
+ "If specified, will override the default guided decoding backend "
392
+ "of the server for this specific request. If set, must be one of "
393
+ "'outlines' / 'lm-format-enforcer'"))
394
+ guided_whitespace_pattern: Optional[str] = Field(
395
+ default=None,
396
+ description=(
397
+ "If specified, will override the default whitespace pattern "
398
+ "for guided json decoding."))
399
+
400
+ # doc: end-completion-extra-params
401
 
402
  def to_sampling_params(self):
403
  echo_without_generation = self.echo and self.max_tokens == 0
404
 
405
+ logits_processors = None
406
+ if self.logit_bias:
407
+
408
+ def logit_bias_logits_processor(
409
+ token_ids: List[int],
410
+ logits: torch.Tensor) -> torch.Tensor:
411
+ assert self.logit_bias is not None
412
+ for token_id, bias in self.logit_bias.items():
413
+ # Clamp the bias between -100 and 100 per OpenAI API spec
414
+ bias = min(100, max(-100, bias))
415
+ logits[int(token_id)] += bias
416
+ return logits
417
+
418
+ logits_processors = [logit_bias_logits_processor]
419
+
420
  return SamplingParams(
421
  n=self.n,
422
  best_of=self.best_of,
 
427
  top_p=self.top_p,
428
  top_k=self.top_k,
429
  min_p=self.min_p,
430
+ seed=self.seed,
431
  stop=self.stop,
432
  stop_token_ids=self.stop_token_ids,
433
  ignore_eos=self.ignore_eos,
434
  max_tokens=self.max_tokens if not echo_without_generation else 1,
435
+ min_tokens=self.min_tokens,
436
  logprobs=self.logprobs,
437
  use_beam_search=self.use_beam_search,
438
+ early_stopping=self.early_stopping,
439
  prompt_logprobs=self.logprobs if self.echo else None,
440
  skip_special_tokens=self.skip_special_tokens,
441
  spaces_between_special_tokens=(self.spaces_between_special_tokens),
442
+ include_stop_str_in_output=self.include_stop_str_in_output,
443
+ length_penalty=self.length_penalty,
444
+ logits_processors=logits_processors,
445
+ truncate_prompt_tokens=self.truncate_prompt_tokens,
446
  )
447
 
448
+ @model_validator(mode="before")
449
+ @classmethod
450
+ def check_guided_decoding_count(cls, data):
451
+ guide_count = sum([
452
+ "guided_json" in data and data["guided_json"] is not None,
453
+ "guided_regex" in data and data["guided_regex"] is not None,
454
+ "guided_choice" in data and data["guided_choice"] is not None
455
+ ])
456
+ if guide_count > 1:
457
+ raise ValueError(
458
+ "You can only use one kind of guided decoding "
459
+ "('guided_json', 'guided_regex' or 'guided_choice').")
460
+ return data
461
+
462
+ @model_validator(mode="before")
463
+ @classmethod
464
+ def check_logprobs(cls, data):
465
+ if "logprobs" in data and data[
466
+ "logprobs"] is not None and not 0 <= data["logprobs"] <= 5:
467
+ raise ValueError(("if passed, `logprobs` must be a value",
468
+ " in the interval [0, 5]."))
469
+ return data
470
+
471
+
472
+ class EmbeddingRequest(BaseModel):
473
+ # Ordered by official OpenAI API documentation
474
+ # https://platform.openai.com/docs/api-reference/embeddings
475
+ model: str
476
+ input: Union[List[int], List[List[int]], str, List[str]]
477
+ encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
478
+ dimensions: Optional[int] = None
479
+ user: Optional[str] = None
480
+
481
+ # doc: begin-embedding-pooling-params
482
+ additional_data: Optional[Any] = None
483
+
484
+ # doc: end-embedding-pooling-params
485
+
486
+ def to_pooling_params(self):
487
+ return PoolingParams(additional_data=self.additional_data)
488
+
489
 
490
+ class CompletionLogProbs(OpenAIBaseModel):
491
  text_offset: List[int] = Field(default_factory=list)
492
  token_logprobs: List[Optional[float]] = Field(default_factory=list)
493
  tokens: List[str] = Field(default_factory=list)
494
+ top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
495
 
496
 
497
+ class CompletionResponseChoice(OpenAIBaseModel):
498
  index: int
499
  text: str
500
+ logprobs: Optional[CompletionLogProbs] = None
501
+ finish_reason: Optional[str] = None
502
+ stop_reason: Optional[Union[int, str]] = Field(
503
+ default=None,
504
+ description=(
505
+ "The stop string or token id that caused the completion "
506
+ "to stop, None if the completion finished for some other reason "
507
+ "including encountering the EOS token"),
508
+ )
509
+
510
+
511
+ class CompletionResponse(OpenAIBaseModel):
512
  id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
513
  object: str = "text_completion"
514
  created: int = Field(default_factory=lambda: int(time.time()))
 
517
  usage: UsageInfo
518
 
519
 
520
+ class CompletionResponseStreamChoice(OpenAIBaseModel):
521
  index: int
522
  text: str
523
+ logprobs: Optional[CompletionLogProbs] = None
524
+ finish_reason: Optional[str] = None
525
+ stop_reason: Optional[Union[int, str]] = Field(
526
+ default=None,
527
+ description=(
528
+ "The stop string or token id that caused the completion "
529
+ "to stop, None if the completion finished for some other reason "
530
+ "including encountering the EOS token"),
531
+ )
532
+
533
+
534
+ class CompletionStreamResponse(OpenAIBaseModel):
535
  id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
536
  object: str = "text_completion"
537
  created: int = Field(default_factory=lambda: int(time.time()))
 
540
  usage: Optional[UsageInfo] = Field(default=None)
541
 
542
 
543
+ class EmbeddingResponseData(BaseModel):
544
+ index: int
545
+ object: str = "embedding"
546
+ embedding: List[float]
547
+
548
+
549
+ class EmbeddingResponse(BaseModel):
550
+ id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
551
+ object: str = "list"
552
+ created: int = Field(default_factory=lambda: int(time.time()))
553
+ model: str
554
+ data: List[EmbeddingResponseData]
555
+ usage: UsageInfo
556
+
557
+
558
+ class FunctionCall(OpenAIBaseModel):
559
+ name: str
560
+ arguments: str
561
+
562
+
563
+ class ToolCall(OpenAIBaseModel):
564
+ id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
565
+ type: Literal["function"] = "function"
566
+ function: FunctionCall
567
+
568
+
569
+ class ChatMessage(OpenAIBaseModel):
570
  role: str
571
  content: str
572
+ tool_calls: List[ToolCall] = Field(default_factory=list)
573
+
574
+
575
+ class ChatCompletionLogProb(OpenAIBaseModel):
576
+ token: str
577
+ logprob: float = -9999.0
578
+ bytes: Optional[List[int]] = None
579
+
580
+
581
+ class ChatCompletionLogProbsContent(ChatCompletionLogProb):
582
+ top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)
583
+
584
+
585
+ class ChatCompletionLogProbs(OpenAIBaseModel):
586
+ content: Optional[List[ChatCompletionLogProbsContent]] = None
587
 
588
 
589
+ class ChatCompletionResponseChoice(OpenAIBaseModel):
590
  index: int
591
  message: ChatMessage
592
+ logprobs: Optional[ChatCompletionLogProbs] = None
593
+ finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
594
+ stop_reason: Optional[Union[int, str]] = None
595
 
596
 
597
+ class ChatCompletionResponse(OpenAIBaseModel):
598
  id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
599
+ object: Literal["chat.completion"] = "chat.completion"
600
  created: int = Field(default_factory=lambda: int(time.time()))
601
  model: str
602
  choices: List[ChatCompletionResponseChoice]
603
  usage: UsageInfo
604
 
605
 
606
+ class DeltaMessage(OpenAIBaseModel):
607
  role: Optional[str] = None
608
  content: Optional[str] = None
609
+ tool_calls: List[ToolCall] = Field(default_factory=list)
610
 
611
 
612
+ class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
613
  index: int
614
  delta: DeltaMessage
615
+ logprobs: Optional[ChatCompletionLogProbs] = None
616
+ finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
617
+ stop_reason: Optional[Union[int, str]] = None
618
 
619
 
620
+ class ChatCompletionStreamResponse(OpenAIBaseModel):
621
  id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
622
+ object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
623
  created: int = Field(default_factory=lambda: int(time.time()))
624
  model: str
625
  choices: List[ChatCompletionResponseStreamChoice]
626
+ usage: Optional[UsageInfo] = Field(default=None)
627
+
628
+
629
+ class BatchRequestInput(OpenAIBaseModel):
630
+ """
631
+ The per-line object of the batch input file.
632
+
633
+ NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
634
+ """
635
+
636
+ # A developer-provided per-request id that will be used to match outputs to
637
+ # inputs. Must be unique for each request in a batch.
638
+ custom_id: str
639
+
640
+ # The HTTP method to be used for the request. Currently only POST is
641
+ # supported.
642
+ method: str
643
+
644
+ # The OpenAI API relative URL to be used for the request. Currently
645
+ # /v1/chat/completions is supported.
646
+ url: str
647
+
648
+ # The parameters of the request.
649
+ body: Union[ChatCompletionRequest, ]
650
+
651
+
652
+ class BatchRequestOutput(OpenAIBaseModel):
653
+ """
654
+ The per-line object of the batch output and error files
655
+ """
656
+
657
+ id: str
658
+
659
+ # A developer-provided per-request id that will be used to match outputs to
660
+ # inputs.
661
+ custom_id: str
662
+
663
+ response: Optional[ChatCompletionResponse]
664
+
665
+ # For requests that failed with a non-HTTP error, this will contain more
666
+ # information on the cause of the failure.
667
+ error: Optional[Any]
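
As context for the new to_sampling_params() logic above, the snippet below is an editor's illustration (not code from the commit) of how the logit_bias field is turned into a logits processor: each bias is clamped to [-100, 100], as the OpenAI API specifies, and added to the corresponding token's logit before sampling.

```python
# Editor's sketch of the logit_bias handling built inside to_sampling_params().
from typing import Dict, List

import torch


def make_logit_bias_processor(logit_bias: Dict[str, float]):
    def logit_bias_logits_processor(token_ids: List[int],
                                    logits: torch.Tensor) -> torch.Tensor:
        for token_id, bias in logit_bias.items():
            bias = min(100, max(-100, bias))  # clamp per the OpenAI API spec
            logits[int(token_id)] += bias
        return logits

    return logit_bias_logits_processor


processor = make_logit_bias_processor({"42": -100.0, "7": 2.5})
logits = processor([], torch.zeros(64))
print(logits[42].item(), logits[7].item())  # -100.0 2.5
```
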
serving_completion.py CHANGED
@@ -1,290 +1,495 @@
 
1
  import time
 
 
 
 
 
 
2
  from fastapi import Request
3
- from typing import AsyncGenerator, AsyncIterator
4
- from vllm.logger import init_logger
5
- from vllm.utils import random_uuid
6
  from vllm.engine.async_llm_engine import AsyncLLMEngine
7
- from protocol import (
8
- CompletionRequest,
9
- CompletionResponse,
10
- CompletionResponseChoice,
11
- CompletionResponseStreamChoice,
12
- CompletionStreamResponse,
13
- LogProbs,
14
- UsageInfo,
15
- )
 
 
 
 
16
  from vllm.outputs import RequestOutput
17
- from serving_engine import OpenAIServing
 
18
 
19
  logger = init_logger(__name__)
20
 
21
 
22
- async def completion_stream_generator(
23
- request: CompletionRequest,
24
- result_generator: AsyncIterator[RequestOutput],
25
- echo_without_generation, create_logprobs_fn, request_id, created_time,
26
- model_name) -> AsyncGenerator[str, None]:
27
- previous_texts = [""] * request.n
28
- previous_num_tokens = [0] * request.n
29
- has_echoed = [False] * request.n
30
-
31
- async for res in result_generator:
32
- # TODO: handle client disconnect for streaming
33
- for output in res.outputs:
34
- i = output.index
35
- delta_text = output.text[len(previous_texts[i]):]
36
- token_ids = output.token_ids[previous_num_tokens[i]:]
37
- if request.logprobs is not None:
38
- top_logprobs = output.logprobs[previous_num_tokens[i]:]
39
- else:
40
- top_logprobs = None
41
- offsets = len(previous_texts[i])
42
- if request.echo and not has_echoed[i]:
43
- if not echo_without_generation:
44
- delta_text = res.prompt + delta_text
45
- token_ids = res.prompt_token_ids + token_ids
46
- if top_logprobs:
47
- top_logprobs = res.prompt_logprobs + top_logprobs
48
- else: # only just return the prompt
49
- delta_text = res.prompt
50
- token_ids = res.prompt_token_ids
51
- if top_logprobs:
52
- top_logprobs = res.prompt_logprobs
53
- has_echoed[i] = True
54
- if request.logprobs is not None:
55
- logprobs = create_logprobs_fn(
56
- token_ids=token_ids,
57
- top_logprobs=top_logprobs,
58
- num_output_top_logprobs=request.logprobs,
59
- initial_text_offset=offsets,
60
- )
61
- else:
62
- logprobs = None
63
- previous_texts[i] = output.text
64
- previous_num_tokens[i] = len(output.token_ids)
65
- finish_reason = output.finish_reason
66
- response_json = CompletionStreamResponse(
67
- id=request_id,
68
- created=created_time,
69
- model=model_name,
70
- choices=[
71
- CompletionResponseStreamChoice(
72
- index=i,
73
- text=delta_text,
74
- logprobs=logprobs,
75
- finish_reason=finish_reason,
76
- )
77
- ]).model_dump_json(exclude_unset=True)
78
- yield f"data: {response_json}\n\n"
79
-
80
- if output.finish_reason is not None:
81
- logprobs = LogProbs() if request.logprobs is not None else None
82
- prompt_tokens = len(res.prompt_token_ids)
83
- completion_tokens = len(output.token_ids)
84
- final_usage = UsageInfo(
85
- prompt_tokens=prompt_tokens,
86
- completion_tokens=completion_tokens,
87
- total_tokens=prompt_tokens + completion_tokens,
88
- )
89
- response_json = CompletionStreamResponse(
90
- id=request_id,
91
- created=created_time,
92
- model=model_name,
93
- choices=[
94
- CompletionResponseStreamChoice(
95
- index=i,
96
- text="",
97
- logprobs=logprobs,
98
- finish_reason=output.finish_reason,
99
- )
100
- ],
101
- usage=final_usage,
102
- ).model_dump_json(exclude_unset=True)
103
- yield f"data: {response_json}\n\n"
104
-
105
- yield "data: [DONE]\n\n"
106
-
107
-
108
- def parse_prompt_format(prompt) -> tuple[bool, list]:
109
- # get the prompt, openai supports the following
110
- # "a string, array of strings, array of tokens, or array of token arrays."
111
- prompt_is_tokens = False
112
- prompts = [prompt] # case 1: a string
113
- if isinstance(prompt, list):
114
- if len(prompt) == 0:
115
- raise ValueError("please provide at least one prompt")
116
- elif isinstance(prompt[0], str):
117
- prompt_is_tokens = False
118
- prompts = prompt # case 2: array of strings
119
- elif isinstance(prompt[0], int):
120
- prompt_is_tokens = True
121
- prompts = [prompt] # case 3: array of tokens
122
- elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
123
- prompt_is_tokens = True
124
- prompts = prompt # case 4: array of token arrays
125
  else:
126
- raise ValueError(
127
- "prompt must be a string, array of strings, array of tokens, or array of token arrays"
128
- )
129
- return prompt_is_tokens, prompts
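
The parse_prompt_format helper removed above normalizes the four prompt shapes the completions API accepts (a string, a list of strings, a list of token ids, or a list of token-id lists). The standalone sketch below is an editor's illustration mirroring that logic, not code from the commit:

```python
# Editor's sketch of the prompt normalization performed by parse_prompt_format.
from typing import List, Tuple, Union

Prompt = Union[str, List[str], List[int], List[List[int]]]


def normalize_prompt(prompt: Prompt) -> Tuple[bool, list]:
    prompt_is_tokens = False
    prompts = [prompt]  # case 1: a single string
    if isinstance(prompt, list):
        if len(prompt) == 0:
            raise ValueError("please provide at least one prompt")
        elif isinstance(prompt[0], str):
            prompts = prompt  # case 2: a list of strings
        elif isinstance(prompt[0], int):
            prompt_is_tokens = True
            prompts = [prompt]  # case 3: a single list of token ids
        elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
            prompt_is_tokens = True
            prompts = prompt  # case 4: a list of token-id lists
        else:
            raise ValueError("prompt must be a string, array of strings, "
                             "array of tokens, or array of token arrays")
    return prompt_is_tokens, prompts


print(normalize_prompt("Hello"))           # (False, ['Hello'])
print(normalize_prompt([[1, 2, 3], [4]]))  # (True, [[1, 2, 3], [4]])
```
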
130
-
131
-
132
- def request_output_to_completion_response(final_res: RequestOutput, request,
133
- echo_without_generation,
134
- create_logprobs_fn, request_id,
135
- created_time,
136
- model_name) -> CompletionResponse:
137
- assert final_res is not None
138
- choices = []
139
- prompt_token_ids = final_res.prompt_token_ids
140
- prompt_logprobs = final_res.prompt_logprobs
141
- prompt_text = final_res.prompt
142
- for output in final_res.outputs:
143
- if request.logprobs is not None:
144
- if not echo_without_generation:
145
- token_ids = output.token_ids
146
- top_logprobs = output.logprobs
147
- if request.echo:
148
- token_ids = prompt_token_ids + token_ids
149
- top_logprobs = prompt_logprobs + top_logprobs
150
  else:
151
- token_ids = prompt_token_ids
152
- top_logprobs = prompt_logprobs
153
- logprobs = create_logprobs_fn(
154
- token_ids=token_ids,
155
- top_logprobs=top_logprobs,
156
- num_output_top_logprobs=request.logprobs,
157
- )
158
- else:
159
- logprobs = None
160
- if not echo_without_generation:
161
- output_text = output.text
162
- if request.echo:
163
- output_text = prompt_text + output_text
164
- else:
165
- output_text = prompt_text
166
- choice_data = CompletionResponseChoice(
167
- index=output.index,
168
- text=output_text,
169
- logprobs=logprobs,
170
- finish_reason=output.finish_reason,
171
- )
172
- choices.append(choice_data)
173
 
174
- num_prompt_tokens = len(final_res.prompt_token_ids)
175
- num_generated_tokens = sum(
176
- len(output.token_ids) for output in final_res.outputs)
177
- usage = UsageInfo(
178
- prompt_tokens=num_prompt_tokens,
179
- completion_tokens=num_generated_tokens,
180
- total_tokens=num_prompt_tokens + num_generated_tokens,
181
- )
182
 
183
- return CompletionResponse(
184
- id=request_id,
185
- created=created_time,
186
- model=model_name,
187
- choices=choices,
188
- usage=usage,
189
- )
190
 
 
 
 
 
 
 
191
 
192
- class OpenAIServingCompletion(OpenAIServing):
 
 
 
 
193
 
194
- def __init__(self, engine: AsyncLLMEngine, served_model: str):
195
- super().__init__(engine=engine, served_model=served_model)
196
 
197
- async def create_completion(self, request: CompletionRequest,
198
- raw_request: Request):
 
 
 
 
199
  """Completion API similar to OpenAI's API.
200
 
201
- See https://platform.openai.com/docs/api-reference/completions/create
202
- for the API specification. This API mimics the OpenAI Completion API.
 
203
 
204
- NOTE: Currently we do not support the following features:
205
- - suffix (the language models we currently support do not support
206
- suffix)
207
- - logit_bias (to be supported by vLLM engine)
208
  """
209
  error_check_ret = await self._check_model(request)
210
  if error_check_ret is not None:
211
  return error_check_ret
212
 
213
- # OpenAI API supports echoing the prompt when max_tokens is 0.
214
- echo_without_generation = request.echo and request.max_tokens == 0
215
-
216
- # Return error for unsupported features.
217
- if request.suffix is not None:
218
- return self.create_error_response(
219
- "suffix is not currently supported")
220
- if request.logit_bias is not None and len(request.logit_bias) > 0:
221
- return self.create_error_response(
222
- "logit_bias is not currently supported")
223
-
224
- model_name = request.model
225
- request_id = f"cmpl-{random_uuid()}"
226
- created_time = int(time.monotonic())
227
-
228
- # Schedule the request and get the result generator.
229
  try:
230
- sampling_params = request.to_sampling_params()
231
 
232
- prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
 
233
 
234
- if len(prompts) > 1:
235
- raise ValueError(
236
- "Batching in completion API is not supported.")
237
- prompt = prompts[0]
238
 
239
- if prompt_is_tokens:
240
- input_ids = self._validate_prompt_and_tokenize(
241
- request, prompt_ids=prompt)
242
- else:
243
- input_ids = self._validate_prompt_and_tokenize(request,
244
- prompt=prompt)
 
 
245
 
246
- result_generator = self.engine.generate(None,
247
- sampling_params,
248
- request_id,
249
- prompt_token_ids=input_ids)
250
  except ValueError as e:
251
  return self.create_error_response(str(e))
252
 
253
- # Similar to the OpenAI API, when n != best_of, we do not stream the
254
- # results. In addition, we do not stream the results when use beam search.
255
- stream = (request.stream
256
- and (request.best_of is None or request.n == request.best_of)
257
- and not request.use_beam_search)
258
-
 
 
 
259
  # Streaming response
260
- if stream:
261
- return completion_stream_generator(request, result_generator,
262
- echo_without_generation,
263
- self._create_logprobs,
264
- request_id, created_time,
265
- model_name)
266
-
267
- # Non-streaming response
268
- final_res: RequestOutput = None
269
  async for res in result_generator:
270
- if await raw_request.is_disconnected():
271
  # Abort the request if the client disconnects.
272
  await self.engine.abort(request_id)
273
  return self.create_error_response("Client disconnected")
274
  final_res = res
275
- response = request_output_to_completion_response(
276
- final_res, request, echo_without_generation, self._create_logprobs,
277
- request_id, created_time, model_name)
278
 
279
- # When user requests streaming but we don't stream, we still need to
280
- # return a streaming response with a single event.
281
- if request.stream:
282
- response_json = response.model_dump_json()
 
 
283
 
284
- async def fake_stream_generator() -> AsyncGenerator[str, None]:
285
- yield f"data: {response_json}\n\n"
286
- yield "data: [DONE]\n\n"
 
 
 
 
 
287
 
288
- return fake_stream_generator()
289
 
290
- return response
1
+ import codecs
2
  import time
3
+ from dataclasses import dataclass
4
+ from typing import (AsyncGenerator, AsyncIterator, Dict, Iterable, List,
5
+ Optional)
6
+ from typing import Sequence as GenericSequence
7
+ from typing import TypedDict, Union, cast, final
8
+
9
  from fastapi import Request
10
+ from openai.types.chat import ChatCompletionContentPartTextParam
11
+
12
+ from vllm.config import ModelConfig
13
  from vllm.engine.async_llm_engine import AsyncLLMEngine
14
+ from vllm.entrypoints.openai.protocol import (
15
+ ChatCompletionContentPartParam, ChatCompletionLogProb,
16
+ ChatCompletionLogProbs, ChatCompletionLogProbsContent,
17
+ ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam,
18
+ ChatCompletionRequest, ChatCompletionResponse,
19
+ ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
20
+ ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
21
+ FunctionCall, ToolCall, UsageInfo)
22
+ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
23
+ OpenAIServing)
24
+ from vllm.logger import init_logger
25
+ from vllm.model_executor.guided_decoding import (
26
+ get_guided_decoding_logits_processor)
27
  from vllm.outputs import RequestOutput
28
+ from vllm.sequence import Logprob
29
+ from vllm.utils import random_uuid
30
 
31
  logger = init_logger(__name__)
32
 
33
 
34
+ @final # So that it should be compatible with Dict[str, str]
35
+ class ConversationMessage(TypedDict):
36
+ role: str
37
+ content: str
38
+
39
+
40
+ @dataclass(frozen=True)
41
+ class ChatMessageParseResult:
42
+ messages: List[ConversationMessage]
43
+
44
+
45
+ class OpenAIServingChat(OpenAIServing):
46
+
47
+ def __init__(self,
48
+ engine: AsyncLLMEngine,
49
+ model_config: ModelConfig,
50
+ served_model_names: List[str],
51
+ response_role: str,
52
+ lora_modules: Optional[List[LoRAModulePath]] = None,
53
+ chat_template: Optional[str] = None):
54
+ super().__init__(engine=engine,
55
+ model_config=model_config,
56
+ served_model_names=served_model_names,
57
+ lora_modules=lora_modules)
58
+
59
+ self.response_role = response_role
60
+ self._load_chat_template(chat_template)
61
+
62
+ def _load_chat_template(self, chat_template: Optional[str]):
63
+ tokenizer = self.tokenizer
64
+
65
+ if chat_template is not None:
66
+ try:
67
+ with open(chat_template, "r") as f:
68
+ tokenizer.chat_template = f.read()
69
+ except OSError as e:
70
+ JINJA_CHARS = "{}\n"
71
+ if not any(c in chat_template for c in JINJA_CHARS):
72
+ msg = (f"The supplied chat template ({chat_template}) "
73
+ f"looks like a file path, but it failed to be "
74
+ f"opened. Reason: {e}")
75
+ raise ValueError(msg) from e
76
+
77
+ # If opening the file fails, treat the argument as the template
78
+ # itself and decode it so escape sequences are interpreted correctly
79
+ tokenizer.chat_template = codecs.decode(
80
+ chat_template, "unicode_escape")
81
+
82
+ logger.info("Using supplied chat template:\n%s",
83
+ tokenizer.chat_template)
84
+ elif tokenizer.chat_template is not None:
85
+ logger.info("Using default chat template:\n%s",
86
+ tokenizer.chat_template)
87
  else:
88
+ logger.warning(
89
+ "No chat template provided. Chat API will not work.")
90
+
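
For reference, the fallback branch above means that when --chat-template is not a readable file, the string itself is used as the template after unicode-escape decoding. A minimal standalone sketch of that behaviour (the template text is only an example):

import codecs

# An inline Jinja template passed on the command line, with literal "\n" escapes.
inline_template = ("{% for m in messages %}"
                   "{{ m['role'] }}: {{ m['content'] }}\\n"
                   "{% endfor %}")

# Mirrors the fallback in _load_chat_template: unicode_escape turns the literal
# backslash-n into a real newline before the template is handed to the tokenizer.
decoded = codecs.decode(inline_template, "unicode_escape")
assert "\n" in decoded
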
91
+ def _parse_chat_message_content_parts(
92
+ self,
93
+ role: str,
94
+ parts: Iterable[ChatCompletionContentPartParam],
95
+ ) -> ChatMessageParseResult:
96
+ texts: List[str] = []
97
+
98
+ for _, part in enumerate(parts):
99
+ part_type = part["type"]
100
+ if part_type == "text":
101
+ text = cast(ChatCompletionContentPartTextParam, part)["text"]
102
+
103
+ texts.append(text)
104
  else:
105
+ raise NotImplementedError(f"Unknown part type: {part_type}")
106
 
107
+ messages = [ConversationMessage(role=role, content="\n".join(texts))]
108
 
109
+ return ChatMessageParseResult(messages=messages)
110
 
111
+ def _parse_chat_message_content(
112
+ self,
113
+ message: ChatCompletionMessageParam,
114
+ ) -> ChatMessageParseResult:
115
+ role = message["role"]
116
+ content = message.get("content")
117
 
118
+ if content is None:
119
+ return ChatMessageParseResult(messages=[])
120
+ if isinstance(content, str):
121
+ messages = [ConversationMessage(role=role, content=content)]
122
+ return ChatMessageParseResult(messages=messages)
123
 
124
+ return self._parse_chat_message_content_parts(role, content)
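
The two code paths above accept either plain-string content or a list of typed parts, of which only "text" parts are handled; anything else raises NotImplementedError. The two request shapes below are therefore equivalent once parsed (values are illustrative):

# Plain string content -> one ConversationMessage
msg_plain = {"role": "user", "content": "What is the capital of France?"}

# List of content parts -> text parts are joined with "\n" by
# _parse_chat_message_content_parts
msg_parts = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is the capital"},
        {"type": "text", "text": "of France?"},
    ],
}
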
 
125
 
126
+ async def create_chat_completion(
127
+ self,
128
+ request: ChatCompletionRequest,
129
+ raw_request: Optional[Request] = None
130
+ ) -> Union[ErrorResponse, AsyncGenerator[str, None],
131
+ ChatCompletionResponse]:
132
  """Completion API similar to OpenAI's API.
133
 
134
+ See https://platform.openai.com/docs/api-reference/chat/create
135
+ for the API specification. This API mimics the OpenAI
136
+ ChatCompletion API.
137
 
138
+ NOTE: Currently we do not support the following feature:
139
+ - function_call (Users should implement this by themselves)
 
  """
141
  error_check_ret = await self._check_model(request)
142
  if error_check_ret is not None:
143
  return error_check_ret
144
 
 
145
  try:
146
+ conversation: List[ConversationMessage] = []
147
 
148
+ for msg in request.messages:
149
+ parsed_msg = self._parse_chat_message_content(msg)
150
 
151
+ conversation.extend(parsed_msg.messages)
152
 
153
+ prompt = self.tokenizer.apply_chat_template(
154
+ conversation=conversation,
155
+ tokenize=False,
156
+ add_generation_prompt=request.add_generation_prompt,
157
+ )
158
+ except Exception as e:
159
+ logger.error("Error in applying chat template from request: %s", e)
160
+ return self.create_error_response(str(e))
161
 
162
+ request_id = f"cmpl-{random_uuid()}"
163
+ try:
164
+ # Tokenize/detokenize depending on prompt format (string/token list)
165
+ prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
166
+ request,
167
+ prompt=prompt,
168
+ add_special_tokens=request.add_special_tokens)
169
+ sampling_params = request.to_sampling_params()
170
+ lora_request = self._maybe_get_lora(request)
171
+ decoding_config = await self.engine.get_decoding_config()
172
+ guided_decoding_backend = request.guided_decoding_backend \
173
+ or decoding_config.guided_decoding_backend
174
+ guided_decode_logits_processor = (
175
+ await get_guided_decoding_logits_processor(
176
+ guided_decoding_backend, request, await
177
+ self.engine.get_tokenizer()))
178
+ if guided_decode_logits_processor:
179
+ if sampling_params.logits_processors is None:
180
+ sampling_params.logits_processors = []
181
+ sampling_params.logits_processors.append(
182
+ guided_decode_logits_processor)
183
  except ValueError as e:
184
  return self.create_error_response(str(e))
185
 
186
+ result_generator = self.engine.generate(
187
+ {
188
+ "prompt": prompt_text,
189
+ "prompt_token_ids": prompt_ids
190
+ },
191
+ sampling_params,
192
+ request_id,
193
+ lora_request,
194
+ )
195
  # Streaming response
196
+ if request.stream:
197
+ return self.chat_completion_stream_generator(
198
+ request, result_generator, request_id, conversation)
199
+ else:
200
+ try:
201
+ return await self.chat_completion_full_generator(
202
+ request, raw_request, result_generator, request_id,
203
+ conversation)
204
+ except ValueError as e:
205
+ # TODO: Use a vllm-specific Validation Error
206
+ return self.create_error_response(str(e))
207
+
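
Since create_chat_completion follows the OpenAI Chat Completions API, it can be exercised with the official openai client. A minimal sketch, assuming the server is reachable at http://localhost:8000 and serves a model registered as "my-model" (both placeholders):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Non-streaming request: served by chat_completion_full_generator.
resp = client.chat.completions.create(
    model="my-model",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(resp.choices[0].message.content)
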
208
+ def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
209
+ if request.add_generation_prompt:
210
+ return self.response_role
211
+ else:
212
+ return request.messages[-1]["role"]
213
+
214
+ async def chat_completion_stream_generator(
215
+ self, request: ChatCompletionRequest,
216
+ result_generator: AsyncIterator[RequestOutput], request_id: str,
217
+ conversation: List[ConversationMessage]
218
+ ) -> AsyncGenerator[str, None]:
219
+ model_name = self.served_model_names[0]
220
+ created_time = int(time.time())
221
+ chunk_object_type = "chat.completion.chunk"
222
+ first_iteration = True
223
+
224
+ # Send response for each token for each request.n (index)
225
+ assert request.n is not None
226
+ previous_texts = [""] * request.n
227
+ previous_num_tokens = [0] * request.n
228
+ finish_reason_sent = [False] * request.n
229
+ try:
230
+ async for res in result_generator:
231
+ # We need to do it here, because if there are exceptions in
232
+ # the result_generator, it needs to be sent as the FIRST
233
+ # response (by the try...except).
234
+ if first_iteration:
235
+ # Send first response for each request.n (index) with
236
+ # the role
237
+ role = self.get_chat_request_role(request)
238
+ for i in range(request.n):
239
+ choice_data = ChatCompletionResponseStreamChoice(
240
+ index=i,
241
+ delta=DeltaMessage(role=role),
242
+ logprobs=None,
243
+ finish_reason=None)
244
+ chunk = ChatCompletionStreamResponse(
245
+ id=request_id,
246
+ object=chunk_object_type,
247
+ created=created_time,
248
+ choices=[choice_data],
249
+ model=model_name)
250
+ data = chunk.model_dump_json(exclude_unset=True)
251
+ yield f"data: {data}\n\n"
252
+
253
+ # Send response to echo the input portion of the
254
+ # last message
255
+ if request.echo:
256
+ last_msg_content = ""
257
+ if conversation and conversation[-1].get(
258
+ "content") and conversation[-1].get(
259
+ "role") == role:
260
+ last_msg_content = conversation[-1]["content"]
261
+
262
+ if last_msg_content:
263
+ for i in range(request.n):
264
+ choice_data = (
265
+ ChatCompletionResponseStreamChoice(
266
+ index=i,
267
+ delta=DeltaMessage(
268
+ content=last_msg_content),
269
+ finish_reason=None))
270
+ chunk = ChatCompletionStreamResponse(
271
+ id=request_id,
272
+ object=chunk_object_type,
273
+ created=created_time,
274
+ choices=[choice_data],
275
+ logprobs=None,
276
+ model=model_name)
277
+ data = chunk.model_dump_json(
278
+ exclude_unset=True)
279
+ yield f"data: {data}\n\n"
280
+ first_iteration = False
281
+
282
+ for output in res.outputs:
283
+ i = output.index
284
+
285
+ if finish_reason_sent[i]:
286
+ continue
287
+
288
+ delta_token_ids = output.token_ids[previous_num_tokens[i]:]
289
+ top_logprobs = output.logprobs[
290
+ previous_num_tokens[i]:] if output.logprobs else None
291
+
292
+ if request.logprobs:
293
+ logprobs = self._create_chat_logprobs(
294
+ token_ids=delta_token_ids,
295
+ top_logprobs=top_logprobs,
296
+ num_output_top_logprobs=request.top_logprobs,
297
+ )
298
+ else:
299
+ logprobs = None
300
+
301
+ delta_text = output.text[len(previous_texts[i]):]
302
+ previous_texts[i] = output.text
303
+ previous_num_tokens[i] = len(output.token_ids)
304
+
305
+ if request.tool_choice and type(
306
+ request.tool_choice
307
+ ) is ChatCompletionNamedToolChoiceParam:
308
+ delta_message = DeltaMessage(tool_calls=[
309
+ ToolCall(function=FunctionCall(
310
+ name=request.tool_choice.function.name,
311
+ arguments=delta_text))
312
+ ])
313
+ else:
314
+ delta_message = DeltaMessage(content=delta_text)
315
+
316
+ if output.finish_reason is None:
317
+ # Send token-by-token response for each request.n
318
+
319
+ choice_data = ChatCompletionResponseStreamChoice(
320
+ index=i,
321
+ delta=delta_message,
322
+ logprobs=logprobs,
323
+ finish_reason=None)
324
+ chunk = ChatCompletionStreamResponse(
325
+ id=request_id,
326
+ object=chunk_object_type,
327
+ created=created_time,
328
+ choices=[choice_data],
329
+ model=model_name)
330
+ data = chunk.model_dump_json(exclude_unset=True)
331
+ yield f"data: {data}\n\n"
332
+ else:
333
+ # Send the finish response for each request.n only once
334
+ prompt_tokens = len(res.prompt_token_ids)
335
+ final_usage = UsageInfo(
336
+ prompt_tokens=prompt_tokens,
337
+ completion_tokens=previous_num_tokens[i],
338
+ total_tokens=prompt_tokens +
339
+ previous_num_tokens[i],
340
+ )
341
+ choice_data = ChatCompletionResponseStreamChoice(
342
+ index=i,
343
+ delta=delta_message,
344
+ logprobs=logprobs,
345
+ finish_reason=output.finish_reason,
346
+ stop_reason=output.stop_reason)
347
+ chunk = ChatCompletionStreamResponse(
348
+ id=request_id,
349
+ object=chunk_object_type,
350
+ created=created_time,
351
+ choices=[choice_data],
352
+ model=model_name)
353
+ if final_usage is not None:
354
+ chunk.usage = final_usage
355
+ data = chunk.model_dump_json(exclude_unset=True,
356
+ exclude_none=True)
357
+ yield f"data: {data}\n\n"
358
+ finish_reason_sent[i] = True
359
+ except ValueError as e:
360
+ # TODO: Use a vllm-specific Validation Error
361
+ data = self.create_streaming_error_response(str(e))
362
+ yield f"data: {data}\n\n"
363
+ # Send the final done message after all request.n choices are finished
364
+ yield "data: [DONE]\n\n"
365
+
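
When request.stream is set, each yield above is one Server-Sent Events frame and the stream is closed by a literal [DONE] sentinel. A rough sketch of a raw consumer using the requests package (URL and model name are placeholders):

import json
import requests

with requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "my-model",
            "messages": [{"role": "user", "content": "Hello!"}],
            "stream": True,
        },
        stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":  # sentinel emitted after the last chunk
            break
        chunk = json.loads(payload)
        delta = chunk["choices"][0]["delta"]
        print(delta.get("content", ""), end="")
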
366
+ async def chat_completion_full_generator(
367
+ self, request: ChatCompletionRequest, raw_request: Optional[Request],
368
+ result_generator: AsyncIterator[RequestOutput], request_id: str,
369
+ conversation: List[ConversationMessage]
370
+ ) -> Union[ErrorResponse, ChatCompletionResponse]:
371
+
372
+ model_name = self.served_model_names[0]
373
+ created_time = int(time.time())
374
+ final_res: Optional[RequestOutput] = None
375
+
376
  async for res in result_generator:
377
+ if raw_request is not None and await raw_request.is_disconnected():
378
  # Abort the request if the client disconnects.
379
  await self.engine.abort(request_id)
380
  return self.create_error_response("Client disconnected")
381
  final_res = res
382
+ assert final_res is not None
383
 
384
+ choices = []
385
+
386
+ role = self.get_chat_request_role(request)
387
+ for output in final_res.outputs:
388
+ token_ids = output.token_ids
389
+ top_logprobs = output.logprobs
390
 
391
+ if request.logprobs:
392
+ logprobs = self._create_chat_logprobs(
393
+ token_ids=token_ids,
394
+ top_logprobs=top_logprobs,
395
+ num_output_top_logprobs=request.top_logprobs,
396
+ )
397
+ else:
398
+ logprobs = None
399
 
400
+ if request.tool_choice and type(
401
+ request.tool_choice) is ChatCompletionNamedToolChoiceParam:
402
+ message = ChatMessage(
403
+ role=role,
404
+ content="",
405
+ tool_calls=[
406
+ ToolCall(function=FunctionCall(
407
+ name=request.tool_choice.function.name,
408
+ arguments=output.text))
409
+ ])
410
+ elif not request.tool_choice or request.tool_choice == "none":
411
+ message = ChatMessage(role=role, content=output.text)
412
+
413
+ choice_data = ChatCompletionResponseChoice(
414
+ index=output.index,
415
+ message=message,
416
+ logprobs=logprobs,
417
+ finish_reason=output.finish_reason,
418
+ stop_reason=output.stop_reason)
419
+ choices.append(choice_data)
420
+
421
+ if request.echo:
422
+ last_msg_content = ""
423
+ if conversation and conversation[-1].get(
424
+ "content") and conversation[-1].get("role") == role:
425
+ last_msg_content = conversation[-1]["content"]
426
+
427
+ for choice in choices:
428
+ full_message = last_msg_content + choice.message.content
429
+ choice.message.content = full_message
430
+
431
+ num_prompt_tokens = len(final_res.prompt_token_ids)
432
+ num_generated_tokens = sum(
433
+ len(output.token_ids) for output in final_res.outputs)
434
+ usage = UsageInfo(
435
+ prompt_tokens=num_prompt_tokens,
436
+ completion_tokens=num_generated_tokens,
437
+ total_tokens=num_prompt_tokens + num_generated_tokens,
438
+ )
439
+ response = ChatCompletionResponse(
440
+ id=request_id,
441
+ created=created_time,
442
+ model=model_name,
443
+ choices=choices,
444
+ usage=usage,
445
+ )
446
 
447
+ return response
448
+
449
+ def _get_top_logprobs(
450
+ self, logprobs: Dict[int, Logprob],
451
+ top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
452
+ return [
453
+ ChatCompletionLogProb(
454
+ token=self._get_decoded_token(p[1], p[0]),
455
+ logprob=max(p[1].logprob, -9999.0),
456
+ bytes=list(
457
+ self._get_decoded_token(p[1],
458
+ p[0]).encode("utf-8",
459
+ errors="replace")))
460
+ for i, p in enumerate(logprobs.items())
461
+ if top_logprobs and i < top_logprobs
462
+ ]
463
+
464
+ def _create_chat_logprobs(
465
+ self,
466
+ token_ids: GenericSequence[int],
467
+ top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
468
+ num_output_top_logprobs: Optional[int] = None,
469
+ ) -> ChatCompletionLogProbs:
470
+ """Create OpenAI-style logprobs."""
471
+
472
+ logprobs_content = []
473
+
474
+ for i, token_id in enumerate(token_ids):
475
+ step_top_logprobs = top_logprobs[i]
476
+ if step_top_logprobs is None:
477
+ logprobs_content.append(
478
+ ChatCompletionLogProbsContent(
479
+ token=self.tokenizer.decode(token_id),
480
+ bytes=list(
481
+ self.tokenizer.decode(token_id).encode(
482
+ "utf-8", errors="replace"))))
483
+ else:
484
+ logprobs_content.append(
485
+ ChatCompletionLogProbsContent(
486
+ token=step_top_logprobs[token_id].decoded_token,
487
+ logprob=max(step_top_logprobs[token_id].logprob,
488
+ -9999.0),
489
+ bytes=list(
490
+ step_top_logprobs[token_id].decoded_token.encode(
491
+ "utf-8", errors="replace")),
492
+ top_logprobs=self._get_top_logprobs(
493
+ step_top_logprobs, num_output_top_logprobs)))
494
+
495
+ return ChatCompletionLogProbs(content=logprobs_content)
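
The helpers above populate the OpenAI-style logprobs objects (token, logprob, bytes, and optional top_logprobs), with log-probabilities clamped at -9999.0. A short sketch of requesting them through the openai client (server URL and model name are placeholders):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="my-model",
    messages=[{"role": "user", "content": "Hi"}],
    logprobs=True,   # populates ChatCompletionLogProbs via _create_chat_logprobs
    top_logprobs=2,  # forwarded as num_output_top_logprobs
)
for item in resp.choices[0].logprobs.content:
    print(item.token, item.logprob, [alt.token for alt in item.top_logprobs])
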
serving_engine.py CHANGED
@@ -1,92 +1,81 @@
1
- import asyncio
 
2
  from http import HTTPStatus
3
- from typing import Dict, List, Optional, Union
4
- from vllm.logger import init_logger
5
- from vllm.transformers_utils.tokenizer import get_tokenizer
6
  from vllm.engine.async_llm_engine import AsyncLLMEngine
7
- from protocol import (CompletionRequest,
8
- ChatCompletionRequest,
9
- ErrorResponse, LogProbs,
10
  ModelCard, ModelList,
11
  ModelPermission)
12
 
13
  logger = init_logger(__name__)
14
 
15
 
16
- class OpenAIServing:
17
 
18
- def __init__(self, engine: AsyncLLMEngine, served_model: str):
19
- self.engine = engine
20
- self.served_model = served_model
21
-
22
- self.max_model_len = 0
23
- self.tokenizer = None
24
 
25
- try:
26
- event_loop = asyncio.get_running_loop()
27
- except RuntimeError:
28
- event_loop = None
29
 
30
- if event_loop is not None and event_loop.is_running(
31
- ): # If the current is instanced by Ray Serve, there is already a running event loop
32
- event_loop.create_task(self._post_init())
33
- else: # When using single vLLM without engine_use_ray
34
- asyncio.run(self._post_init())
35
 
36
- async def _post_init(self):
37
- engine_model_config = await self.engine.get_model_config()
38
- self.max_model_len = engine_model_config.max_model_len
39
 
40
  # A separate tokenizer to map token IDs to strings.
41
  self.tokenizer = get_tokenizer(
42
- engine_model_config.tokenizer,
43
- tokenizer_mode=engine_model_config.tokenizer_mode,
44
- trust_remote_code=engine_model_config.trust_remote_code)
45
 
46
  async def show_available_models(self) -> ModelList:
47
  """Show available models. Right now we only have one model."""
48
  model_cards = [
49
- ModelCard(id=self.served_model,
50
- root=self.served_model,
51
  permission=[ModelPermission()])
 
52
  ]
 
53
  return ModelList(data=model_cards)
54
 
55
- def _create_logprobs(
56
- self,
57
- token_ids: List[int],
58
- top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
59
- num_output_top_logprobs: Optional[int] = None,
60
- initial_text_offset: int = 0,
61
- ) -> LogProbs:
62
- """Create OpenAI-style logprobs."""
63
- logprobs = LogProbs()
64
- last_token_len = 0
65
- if num_output_top_logprobs:
66
- logprobs.top_logprobs = []
67
- for i, token_id in enumerate(token_ids):
68
- step_top_logprobs = top_logprobs[i]
69
- if step_top_logprobs is not None:
70
- token_logprob = step_top_logprobs[token_id]
71
- else:
72
- token_logprob = None
73
- token = self.tokenizer.convert_ids_to_tokens(token_id)
74
- logprobs.tokens.append(token)
75
- logprobs.token_logprobs.append(token_logprob)
76
- if len(logprobs.text_offset) == 0:
77
- logprobs.text_offset.append(initial_text_offset)
78
- else:
79
- logprobs.text_offset.append(logprobs.text_offset[-1] +
80
- last_token_len)
81
- last_token_len = len(token)
82
-
83
- if num_output_top_logprobs:
84
- logprobs.top_logprobs.append({
85
- self.tokenizer.convert_ids_to_tokens(i): p
86
- for i, p in step_top_logprobs.items()
87
- } if step_top_logprobs else None)
88
- return logprobs
89
-
90
  def create_error_response(
91
  self,
92
  message: str,
@@ -96,38 +85,116 @@ class OpenAIServing:
96
  type=err_type,
97
  code=status_code.value)
98
 
99
- async def _check_model(self, request) -> Optional[ErrorResponse]:
100
- if request.model == self.served_model:
101
- return
102
  return self.create_error_response(
103
  message=f"The model `{request.model}` does not exist.",
104
  err_type="NotFoundError",
105
  status_code=HTTPStatus.NOT_FOUND)
106
 
107
  def _validate_prompt_and_tokenize(
108
  self,
109
- request: Union[ChatCompletionRequest, CompletionRequest],
 
110
  prompt: Optional[str] = None,
111
- prompt_ids: Optional[List[int]] = None) -> List[int]:
112
  if not (prompt or prompt_ids):
113
  raise ValueError("Either prompt or prompt_ids should be provided.")
114
  if (prompt and prompt_ids):
115
  raise ValueError(
116
  "Only one of prompt or prompt_ids should be provided.")
117
 
118
- input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
119
- prompt).input_ids
120
  token_num = len(input_ids)
121
 
122
  if request.max_tokens is None:
123
  request.max_tokens = self.max_model_len - token_num
124
 
125
  if token_num + request.max_tokens > self.max_model_len:
126
  raise ValueError(
127
- f"This model's maximum context length is {self.max_model_len} tokens. "
128
- f"However, you requested {request.max_tokens + token_num} tokens "
 
129
  f"({token_num} in the messages, "
130
  f"{request.max_tokens} in the completion). "
131
  f"Please reduce the length of the messages or completion.", )
132
  else:
133
- return input_ids
1
+ import json
2
+ from dataclasses import dataclass
3
  from http import HTTPStatus
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+
6
+ from pydantic import Field
7
+ from typing_extensions import Annotated
8
+
9
+ from vllm.config import ModelConfig
10
  from vllm.engine.async_llm_engine import AsyncLLMEngine
11
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
12
+ CompletionRequest,
13
+ EmbeddingRequest, ErrorResponse,
14
  ModelCard, ModelList,
15
  ModelPermission)
16
+ from vllm.logger import init_logger
17
+ from vllm.lora.request import LoRARequest
18
+ from vllm.sequence import Logprob
19
+ from vllm.transformers_utils.tokenizer import get_tokenizer
20
 
21
  logger = init_logger(__name__)
22
 
23
 
24
+ @dataclass
25
+ class LoRAModulePath:
26
+ name: str
27
+ local_path: str
28
 
 
29
 
30
+ class OpenAIServing:
31
 
32
+ def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
33
+ served_model_names: List[str],
34
+ lora_modules: Optional[List[LoRAModulePath]]):
35
+ super().__init__()
 
36
 
37
+ self.engine = engine
38
+ self.max_model_len = model_config.max_model_len
 
39
 
40
  # A separate tokenizer to map token IDs to strings.
41
  self.tokenizer = get_tokenizer(
42
+ model_config.tokenizer,
43
+ tokenizer_mode=model_config.tokenizer_mode,
44
+ tokenizer_revision=model_config.tokenizer_revision,
45
+ trust_remote_code=model_config.trust_remote_code,
46
+ truncation_side="left")
47
+
48
+ self.served_model_names = served_model_names
49
+
50
+ if lora_modules is None:
51
+ self.lora_requests = []
52
+ else:
53
+ self.lora_requests = [
54
+ LoRARequest(
55
+ lora_name=lora.name,
56
+ lora_int_id=i,
57
+ lora_local_path=lora.local_path,
58
+ ) for i, lora in enumerate(lora_modules, start=1)
59
+ ]
60
 
61
  async def show_available_models(self) -> ModelList:
62
  """Show available models. Right now we only have one model."""
63
  model_cards = [
64
+ ModelCard(id=served_model_name,
65
+ max_model_len=self.max_model_len,
66
+ root=self.served_model_names[0],
67
+ permission=[ModelPermission()])
68
+ for served_model_name in self.served_model_names
69
+ ]
70
+ lora_cards = [
71
+ ModelCard(id=lora.lora_name,
72
+ root=self.served_model_names[0],
73
  permission=[ModelPermission()])
74
+ for lora in self.lora_requests
75
  ]
76
+ model_cards.extend(lora_cards)
77
  return ModelList(data=model_cards)
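
With the changes above, /v1/models now returns one card per served model name plus one per LoRA adapter registered through lora_modules. A quick way to inspect the list (the URL is a placeholder):

import requests

resp = requests.get("http://localhost:8000/v1/models")
for card in resp.json()["data"]:
    # Base models report max_model_len; LoRA cards point back to the base
    # model through "root".
    print(card["id"], card.get("root"), card.get("max_model_len"))
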
78
 
79
  def create_error_response(
80
  self,
81
  message: str,
 
85
  type=err_type,
86
  code=status_code.value)
87
 
88
+ def create_streaming_error_response(
89
+ self,
90
+ message: str,
91
+ err_type: str = "BadRequestError",
92
+ status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
93
+ json_str = json.dumps({
94
+ "error":
95
+ self.create_error_response(message=message,
96
+ err_type=err_type,
97
+ status_code=status_code).model_dump()
98
+ })
99
+ return json_str
100
+
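
Because errors raised mid-stream can no longer change the HTTP status, create_streaming_error_response wraps the usual ErrorResponse in a JSON envelope that the generators yield as a data: frame. A standalone approximation of the resulting shape (built here without the protocol classes, so the field set is only indicative):

import json
from http import HTTPStatus

# Approximation of the envelope; the real payload is ErrorResponse.model_dump().
error_frame = "data: " + json.dumps({
    "error": {
        "message": "prompt too long",
        "type": "BadRequestError",
        "code": HTTPStatus.BAD_REQUEST.value,
    }
}) + "\n\n"
print(error_frame)
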
101
+ async def _check_model(
102
+ self, request: Union[CompletionRequest, ChatCompletionRequest,
103
+ EmbeddingRequest]
104
+ ) -> Optional[ErrorResponse]:
105
+ if request.model in self.served_model_names:
106
+ return None
107
+ if request.model in [lora.lora_name for lora in self.lora_requests]:
108
+ return None
109
  return self.create_error_response(
110
  message=f"The model `{request.model}` does not exist.",
111
  err_type="NotFoundError",
112
  status_code=HTTPStatus.NOT_FOUND)
113
 
114
+ def _maybe_get_lora(
115
+ self, request: Union[CompletionRequest, ChatCompletionRequest,
116
+ EmbeddingRequest]
117
+ ) -> Optional[LoRARequest]:
118
+ if request.model in self.served_model_names:
119
+ return None
120
+ for lora in self.lora_requests:
121
+ if request.model == lora.lora_name:
122
+ return lora
123
+ # If _check_model has been called earlier, this will be unreachable
124
+ raise ValueError(f"The model `{request.model}` does not exist.")
125
+
126
  def _validate_prompt_and_tokenize(
127
  self,
128
+ request: Union[ChatCompletionRequest, CompletionRequest,
129
+ EmbeddingRequest],
130
  prompt: Optional[str] = None,
131
+ prompt_ids: Optional[List[int]] = None,
132
+ truncate_prompt_tokens: Optional[Annotated[int,
133
+ Field(ge=1)]] = None,
134
+ add_special_tokens: Optional[bool] = True
135
+ ) -> Tuple[List[int], str]:
136
  if not (prompt or prompt_ids):
137
  raise ValueError("Either prompt or prompt_ids should be provided.")
138
  if (prompt and prompt_ids):
139
  raise ValueError(
140
  "Only one of prompt or prompt_ids should be provided.")
141
 
142
+ if prompt_ids is None:
143
+ # When using OpenAIServingChat for chat completions, for
144
+ # most models the special tokens (e.g., BOS) have already
145
+ # been added by the chat template. Therefore, we do not
146
+ # need to add them again.
147
+ # Set add_special_tokens to False (by default) to avoid
148
+ # adding the BOS tokens again.
149
+ tokenizer_kwargs: Dict[str, Any] = {
150
+ "add_special_tokens": add_special_tokens
151
+ }
152
+ if truncate_prompt_tokens is not None:
153
+ tokenizer_kwargs.update({
154
+ "truncation": True,
155
+ "max_length": truncate_prompt_tokens,
156
+ })
157
+ input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
158
+ elif truncate_prompt_tokens is not None:
159
+ input_ids = prompt_ids[-truncate_prompt_tokens:]
160
+ else:
161
+ input_ids = prompt_ids
162
+
163
+ input_text = prompt if prompt is not None else self.tokenizer.decode(
164
+ prompt_ids)
165
  token_num = len(input_ids)
166
 
167
+ # Note: EmbeddingRequest doesn't have max_tokens
168
+ if isinstance(request, EmbeddingRequest):
169
+ if token_num > self.max_model_len:
170
+ raise ValueError(
171
+ f"This model's maximum context length is "
172
+ f"{self.max_model_len} tokens. However, you requested "
173
+ f"{token_num} tokens in the input for embedding "
174
+ f"generation. Please reduce the length of the input.", )
175
+ return input_ids, input_text
176
+
177
  if request.max_tokens is None:
178
+ if token_num >= self.max_model_len:
179
+ raise ValueError(
180
+ f"This model's maximum context length is "
181
+ f"{self.max_model_len} tokens. However, you requested "
182
+ f"{token_num} tokens in the messages, "
183
+ f"Please reduce the length of the messages.", )
184
  request.max_tokens = self.max_model_len - token_num
185
 
186
  if token_num + request.max_tokens > self.max_model_len:
187
  raise ValueError(
188
+ f"This model's maximum context length is "
189
+ f"{self.max_model_len} tokens. However, you requested "
190
+ f"{request.max_tokens + token_num} tokens "
191
  f"({token_num} in the messages, "
192
  f"{request.max_tokens} in the completion). "
193
  f"Please reduce the length of the messages or completion.", )
194
  else:
195
+ return input_ids, input_text
196
+
197
+ def _get_decoded_token(self, logprob: Logprob, token_id: int) -> str:
198
+ if logprob.decoded_token is not None:
199
+ return logprob.decoded_token
200
+ return self.tokenizer.decode(token_id)
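
The truncate_prompt_tokens handling in _validate_prompt_and_tokenize above keeps only the trailing tokens, whether the prompt arrives as text (via the tokenizer's truncation settings) or as a pre-tokenized list. A standalone sketch of the token-list branch:

truncate_prompt_tokens = 4
prompt_ids = [101, 7592, 1010, 2088, 999, 102]

# Mirrors the `elif truncate_prompt_tokens is not None` branch: keep the tail.
input_ids = prompt_ids[-truncate_prompt_tokens:]
assert input_ids == [1010, 2088, 999, 102]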