sofianhw committed
Commit 690332d
1 parent: 0a67f63

add api server and openapi

Files changed (9)
  1. Dockerfile +1 -0
  2. README.md +1 -1
  3. api_server.py +188 -0
  4. entrypoint.sh +7 -1
  5. main.py +0 -76
  6. protocol.py +232 -0
  7. serving_chat.py +265 -0
  8. serving_completion.py +290 -0
  9. serving_engine.py +133 -0
Dockerfile CHANGED
@@ -14,6 +14,7 @@ RUN pip3 install "torch==2.1.1"
 # This build is slow but NVIDIA does not provide binaries. Increase MAX_JOBS as needed.
 # RUN pip3 install "git+https://github.com/stanford-futuredata/megablocks.git"
 RUN pip3 install vllm
+RUN pip3 install openai
 RUN pip3 install "xformers==0.0.23" "transformers==4.36.0" "fschat[model_worker]==0.2.34"

 RUN git clone https://github.com/NVIDIA/apex && \
README.md CHANGED
@@ -2,7 +2,7 @@
 title: Test Docker
 emoji: 🔥
 colorFrom: purple
-colorTo: gray
+colorTo: white
 sdk: docker
 pinned: false
 license: mit
api_server.py ADDED
@@ -0,0 +1,188 @@
+import argparse
+import asyncio
+import json
+from contextlib import asynccontextmanager
+from aioprometheus import MetricsMiddleware
+from aioprometheus.asgi.starlette import metrics
+import fastapi
+import uvicorn
+from http import HTTPStatus
+from fastapi import Request
+from fastapi.exceptions import RequestValidationError
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse, StreamingResponse, Response
+
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.metrics import add_global_metrics_labels
+from protocol import CompletionRequest, ChatCompletionRequest, ErrorResponse
+from vllm.logger import init_logger
+from serving_chat import OpenAIServingChat
+from serving_completion import OpenAIServingCompletion
+
+TIMEOUT_KEEP_ALIVE = 5  # seconds
+
+openai_serving_chat: OpenAIServingChat = None
+openai_serving_completion: OpenAIServingCompletion = None
+logger = init_logger(__name__)
+
+
+@asynccontextmanager
+async def lifespan(app: fastapi.FastAPI):
+
+    async def _force_log():
+        while True:
+            await asyncio.sleep(10)
+            await engine.do_log_stats()
+
+    if not engine_args.disable_log_stats:
+        asyncio.create_task(_force_log())
+
+    yield
+
+
+app = fastapi.FastAPI(lifespan=lifespan)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="vLLM OpenAI-Compatible RESTful API server.")
+    parser.add_argument("--host", type=str, default=None, help="host name")
+    parser.add_argument("--port", type=int, default=8000, help="port number")
+    parser.add_argument("--allow-credentials",
+                        action="store_true",
+                        help="allow credentials")
+    parser.add_argument("--allowed-origins",
+                        type=json.loads,
+                        default=["*"],
+                        help="allowed origins")
+    parser.add_argument("--allowed-methods",
+                        type=json.loads,
+                        default=["*"],
+                        help="allowed methods")
+    parser.add_argument("--allowed-headers",
+                        type=json.loads,
+                        default=["*"],
+                        help="allowed headers")
+    parser.add_argument("--served-model-name",
+                        type=str,
+                        default=None,
+                        help="The model name used in the API. If not "
+                        "specified, the model name will be the same as "
+                        "the huggingface name.")
+    parser.add_argument("--chat-template",
+                        type=str,
+                        default=None,
+                        help="The file path to the chat template, "
+                        "or the template in single-line form "
+                        "for the specified model")
+    parser.add_argument("--response-role",
+                        type=str,
+                        default="assistant",
+                        help="The role name to return if "
+                        "`request.add_generation_prompt=true`.")
+    parser.add_argument("--ssl-keyfile",
+                        type=str,
+                        default=None,
+                        help="The file path to the SSL key file")
+    parser.add_argument("--ssl-certfile",
+                        type=str,
+                        default=None,
+                        help="The file path to the SSL cert file")
+    parser.add_argument(
+        "--root-path",
+        type=str,
+        default=None,
+        help="FastAPI root_path when app is behind a path based routing proxy")
+
+    parser = AsyncEngineArgs.add_cli_args(parser)
+    return parser.parse_args()
+
+
+app.add_middleware(MetricsMiddleware)  # Trace HTTP server metrics
+app.add_route("/metrics", metrics)  # Exposes HTTP metrics
+
+
+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(_, exc):
+    err = openai_serving_chat.create_error_response(message=str(exc))
+    return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
+
+
+@app.get("/health")
+async def health() -> Response:
+    """Health check."""
+    return Response(status_code=200)
+
+
+@app.get("/api/v1/models")
+async def show_available_models():
+    models = await openai_serving_chat.show_available_models()
+    return JSONResponse(content=models.model_dump())
+
+
+@app.post("/api/v1/chat/completions")
+async def create_chat_completion(request: ChatCompletionRequest,
+                                 raw_request: Request):
+    generator = await openai_serving_chat.create_chat_completion(
+        request, raw_request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    if request.stream:
+        return StreamingResponse(content=generator,
+                                 media_type="text/event-stream")
+    else:
+        return JSONResponse(content=generator.model_dump())
+
+
+@app.post("/api/v1/completions")
+async def create_completion(request: CompletionRequest, raw_request: Request):
+    generator = await openai_serving_completion.create_completion(
+        request, raw_request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    if request.stream:
+        return StreamingResponse(content=generator,
+                                 media_type="text/event-stream")
+    else:
+        return JSONResponse(content=generator.model_dump())
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=args.allowed_origins,
+        allow_credentials=args.allow_credentials,
+        allow_methods=args.allowed_methods,
+        allow_headers=args.allowed_headers,
+    )
+
+    logger.info(f"args: {args}")
+
+    if args.served_model_name is not None:
+        served_model = args.served_model_name
+    else:
+        served_model = args.model
+
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
+    openai_serving_chat = OpenAIServingChat(engine, served_model,
+                                            args.response_role,
+                                            args.chat_template)
+    openai_serving_completion = OpenAIServingCompletion(engine, served_model)
+
+    # Register labels for metrics
+    add_global_metrics_labels(model_name=engine_args.model)
+
+    app.root_path = args.root_path
+    uvicorn.run(app,
+                host=args.host,
+                port=args.port,
+                log_level="info",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
+                ssl_keyfile=args.ssl_keyfile,
+                ssl_certfile=args.ssl_certfile)
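Note: the routes above are mounted under `/api/v1/...`, so an OpenAI-style client has to point its base URL at that prefix. Below is a minimal sketch of calling the chat endpoint with the `openai` package that the Dockerfile now installs; it assumes an openai 1.x client, that the server started by entrypoint.sh is reachable on localhost:7860, and that the model name matches whatever was passed as `--served-model-name` (the name here is hypothetical).

```python
# Minimal sketch: call the new chat endpoint with the OpenAI Python client (v1.x assumed).
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:7860/api/v1",  # routes are registered under /api/v1
    api_key="not-used",  # this server does no auth, but the client requires a value
)

resp = client.chat.completions.create(
    model="my-served-model",  # hypothetical; must equal the served model name
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    max_tokens=64,
)
print(resp.choices[0].message.content)
```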
entrypoint.sh CHANGED
@@ -30,7 +30,13 @@ if [[ ! -z "${ROOT_PATH}" ]]; then
 fi

 # Run the provided command
-exec python3 -u -m vllm.entrypoints.openai.api_server \
+# exec python3 -u -m vllm.entrypoints.openai.api_server \
+#   --model "${HF_MODEL}" \
+#   --host 0.0.0.0 \
+#   --port 7860 \
+#   ${additional_args}
+
+exec python3 -u api_server.py \
   --model "${HF_MODEL}" \
   --host 0.0.0.0 \
   --port 7860 \
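Since the entrypoint now launches the bundled api_server.py instead of vLLM's built-in OpenAI server, a quick smoke test against `/health` and `/api/v1/models` confirms the custom routes are live. A sketch under the assumptions that `requests` is installed and the server is reachable on localhost:7860:

```python
# Sketch of a startup smoke test; assumes `requests` is available and the server
# from entrypoint.sh is listening on localhost:7860.
import requests

base = "http://localhost:7860"

# /health returns an empty 200 response once the engine is up.
assert requests.get(f"{base}/health", timeout=5).status_code == 200

# /api/v1/models lists the single served model (a ModelList from protocol.py).
models = requests.get(f"{base}/api/v1/models", timeout=5).json()
print([m["id"] for m in models["data"]])
```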
main.py DELETED
@@ -1,76 +0,0 @@
-import os
-import copy
-import time
-import llama_cpp
-from llama_cpp import Llama
-from huggingface_hub import hf_hub_download
-
-import uvicorn
-from fastapi import FastAPI, Request
-
-
-llm = Llama(
-    model_path=hf_hub_download(
-        repo_id=os.environ.get("REPO_ID", "TheBloke/Llama-2-7b-Chat-GGUF"),
-        filename=os.environ.get("MODEL_FILE", "llama-2-7b-chat.Q5_0.gguf"),
-    ),
-    n_ctx=2048,
-    n_gpu_layers=50,  # change n_gpu_layers if you have more or less VRAM
-)
-
-history = []
-
-system_message = """
-You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
-If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
-"""
-
-
-def generate_text(message, history):
-    temp = ""
-    input_prompt = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n "
-    for interaction in history:
-        input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s> [INST] "
-
-    input_prompt = input_prompt + str(message) + " [/INST] "
-
-    output = llm(
-        input_prompt,
-        temperature=0.15,
-        top_p=0.1,
-        top_k=40,
-        repeat_penalty=1.1,
-        max_tokens=1024,
-        stop=[
-            "<|prompter|>",
-            "<|endoftext|>",
-            "<|endoftext|> \n",
-            "ASSISTANT:",
-            "USER:",
-            "SYSTEM:",
-        ],
-    )
-    # for out in output:
-    #     stream = copy.deepcopy(out)
-    #     temp += stream["choices"][0]["text"]
-    #     yield temp
-
-    history = ["init", input_prompt]
-
-    print(history)
-    print(output)
-    return output
-
-app = FastAPI()
-
-@app.post("/api/generate")
-async def generate(request: Request):
-    # Receive the request as JSON
-    data = await request.json()
-    # Check if the event is a completed order
-    if data['message']:
-        response = generate_text(data['message'], history)
-        return {"status": "success", "data": response}
-    else:
-        # If the event is not what we're looking for, ignore it
-        return {"status": "ignored"}
protocol.py ADDED
@@ -0,0 +1,232 @@
+# Adapted from
+# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
+import time
+from typing import Dict, List, Literal, Optional, Union
+
+from pydantic import BaseModel, Field
+
+from vllm.utils import random_uuid
+from vllm.sampling_params import SamplingParams
+
+
+class ErrorResponse(BaseModel):
+    object: str = "error"
+    message: str
+    type: str
+    param: Optional[str] = None
+    code: int
+
+
+class ModelPermission(BaseModel):
+    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
+    object: str = "model_permission"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    allow_create_engine: bool = False
+    allow_sampling: bool = True
+    allow_logprobs: bool = True
+    allow_search_indices: bool = False
+    allow_view: bool = True
+    allow_fine_tuning: bool = False
+    organization: str = "*"
+    group: Optional[str] = None
+    is_blocking: bool = False
+
+
+class ModelCard(BaseModel):
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: str = "vllm"
+    root: Optional[str] = None
+    parent: Optional[str] = None
+    permission: List[ModelPermission] = Field(default_factory=list)
+
+
+class ModelList(BaseModel):
+    object: str = "list"
+    data: List[ModelCard] = Field(default_factory=list)
+
+
+class UsageInfo(BaseModel):
+    prompt_tokens: int = 0
+    total_tokens: int = 0
+    completion_tokens: Optional[int] = 0
+
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: Union[str, List[Dict[str, str]]]
+    temperature: Optional[float] = 0.7
+    top_p: Optional[float] = 1.0
+    n: Optional[int] = 1
+    max_tokens: Optional[int] = None
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    stream: Optional[bool] = False
+    presence_penalty: Optional[float] = 0.0
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    user: Optional[str] = None
+    # Additional parameters supported by vLLM
+    best_of: Optional[int] = None
+    top_k: Optional[int] = -1
+    ignore_eos: Optional[bool] = False
+    use_beam_search: Optional[bool] = False
+    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    skip_special_tokens: Optional[bool] = True
+    spaces_between_special_tokens: Optional[bool] = True
+    add_generation_prompt: Optional[bool] = True
+    echo: Optional[bool] = False
+    repetition_penalty: Optional[float] = 1.0
+    min_p: Optional[float] = 0.0
+
+    def to_sampling_params(self) -> SamplingParams:
+        return SamplingParams(
+            n=self.n,
+            presence_penalty=self.presence_penalty,
+            frequency_penalty=self.frequency_penalty,
+            repetition_penalty=self.repetition_penalty,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            min_p=self.min_p,
+            stop=self.stop,
+            stop_token_ids=self.stop_token_ids,
+            max_tokens=self.max_tokens,
+            best_of=self.best_of,
+            top_k=self.top_k,
+            ignore_eos=self.ignore_eos,
+            use_beam_search=self.use_beam_search,
+            skip_special_tokens=self.skip_special_tokens,
+            spaces_between_special_tokens=self.spaces_between_special_tokens,
+        )
+
+
+class CompletionRequest(BaseModel):
+    model: str
+    # a string, array of strings, array of tokens, or array of token arrays
+    prompt: Union[List[int], List[List[int]], str, List[str]]
+    suffix: Optional[str] = None
+    max_tokens: Optional[int] = 16
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
+    n: Optional[int] = 1
+    stream: Optional[bool] = False
+    logprobs: Optional[int] = None
+    echo: Optional[bool] = False
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    presence_penalty: Optional[float] = 0.0
+    frequency_penalty: Optional[float] = 0.0
+    best_of: Optional[int] = None
+    logit_bias: Optional[Dict[str, float]] = None
+    user: Optional[str] = None
+    # Additional parameters supported by vLLM
+    top_k: Optional[int] = -1
+    ignore_eos: Optional[bool] = False
+    use_beam_search: Optional[bool] = False
+    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    skip_special_tokens: Optional[bool] = True
+    spaces_between_special_tokens: Optional[bool] = True
+    repetition_penalty: Optional[float] = 1.0
+    min_p: Optional[float] = 0.0
+
+    def to_sampling_params(self):
+        echo_without_generation = self.echo and self.max_tokens == 0
+
+        return SamplingParams(
+            n=self.n,
+            best_of=self.best_of,
+            presence_penalty=self.presence_penalty,
+            frequency_penalty=self.frequency_penalty,
+            repetition_penalty=self.repetition_penalty,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            top_k=self.top_k,
+            min_p=self.min_p,
+            stop=self.stop,
+            stop_token_ids=self.stop_token_ids,
+            ignore_eos=self.ignore_eos,
+            max_tokens=self.max_tokens if not echo_without_generation else 1,
+            logprobs=self.logprobs,
+            use_beam_search=self.use_beam_search,
+            prompt_logprobs=self.logprobs if self.echo else None,
+            skip_special_tokens=self.skip_special_tokens,
+            spaces_between_special_tokens=(self.spaces_between_special_tokens),
+        )
+
+
+class LogProbs(BaseModel):
+    text_offset: List[int] = Field(default_factory=list)
+    token_logprobs: List[Optional[float]] = Field(default_factory=list)
+    tokens: List[str] = Field(default_factory=list)
+    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None
+
+
+class CompletionResponseChoice(BaseModel):
+    index: int
+    text: str
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class CompletionResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionResponseChoice]
+    usage: UsageInfo
+
+
+class CompletionResponseStreamChoice(BaseModel):
+    index: int
+    text: str
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class CompletionStreamResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionResponseStreamChoice]
+    usage: Optional[UsageInfo] = Field(default=None)
+
+
+class ChatMessage(BaseModel):
+    role: str
+    content: str
+
+
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
+    object: str = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseChoice]
+    usage: UsageInfo
+
+
+class DeltaMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
+    object: str = "chat.completion.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]
+    usage: Optional[UsageInfo] = Field(default=None)
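The request models above double as the bridge to vLLM's sampler: `to_sampling_params()` copies the OpenAI-style fields plus the vLLM extras (top_k, min_p, repetition_penalty, ...) into a `SamplingParams` object. A small sketch, assuming vLLM is installed so protocol.py imports cleanly; `max_tokens` is set explicitly here because the server normally fills it in during prompt validation.

```python
# Sketch: how a parsed request body becomes vLLM SamplingParams.
# Assumes vllm is importable; the values are illustrative.
from protocol import ChatCompletionRequest

request = ChatCompletionRequest(
    model="my-served-model",  # must match the served model to pass _check_model
    messages=[{"role": "user", "content": "Hi!"}],
    temperature=0.2,
    top_k=40,        # vLLM-specific extension
    max_tokens=64,   # normally filled in by _validate_prompt_and_tokenize
)

params = request.to_sampling_params()
print(params.temperature, params.top_k, params.max_tokens)
```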
serving_chat.py ADDED
@@ -0,0 +1,265 @@
+import time
+import codecs
+from fastapi import Request
+from typing import AsyncGenerator, AsyncIterator, Union
+from vllm.logger import init_logger
+from vllm.utils import random_uuid
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from protocol import (
+    ChatCompletionRequest, ChatCompletionResponse,
+    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
+    UsageInfo)
+from vllm.outputs import RequestOutput
+from serving_engine import OpenAIServing
+
+logger = init_logger(__name__)
+
+
+class OpenAIServingChat(OpenAIServing):
+
+    def __init__(self,
+                 engine: AsyncLLMEngine,
+                 served_model: str,
+                 response_role: str,
+                 chat_template=None):
+        super().__init__(engine=engine, served_model=served_model)
+        self.response_role = response_role
+        self._load_chat_template(chat_template)
+
+    async def create_chat_completion(
+        self, request: ChatCompletionRequest, raw_request: Request
+    ) -> Union[ErrorResponse, AsyncGenerator[str, None],
+               ChatCompletionResponse]:
+        """Completion API similar to OpenAI's API.
+
+        See https://platform.openai.com/docs/api-reference/chat/create
+        for the API specification. This API mimics the OpenAI ChatCompletion API.
+
+        NOTE: Currently we do not support the following features:
+            - function_call (Users should implement this by themselves)
+            - logit_bias (to be supported by vLLM engine)
+        """
+        error_check_ret = await self._check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        if request.logit_bias is not None and len(request.logit_bias) > 0:
+            # TODO: support logit_bias in vLLM engine.
+            return self.create_error_response(
+                "logit_bias is not currently supported")
+
+        try:
+            prompt = self.tokenizer.apply_chat_template(
+                conversation=request.messages,
+                tokenize=False,
+                add_generation_prompt=request.add_generation_prompt)
+        except Exception as e:
+            logger.error(
+                f"Error in applying chat template from request: {str(e)}")
+            return self.create_error_response(str(e))
+
+        request_id = f"cmpl-{random_uuid()}"
+        try:
+            token_ids = self._validate_prompt_and_tokenize(request,
+                                                           prompt=prompt)
+            sampling_params = request.to_sampling_params()
+        except ValueError as e:
+            return self.create_error_response(str(e))
+
+        result_generator = self.engine.generate(prompt, sampling_params,
+                                                request_id, token_ids)
+        # Streaming response
+        if request.stream:
+            return self.chat_completion_stream_generator(
+                request, result_generator, request_id)
+        else:
+            return await self.chat_completion_full_generator(
+                request, raw_request, result_generator, request_id)
+
+    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
+        if request.add_generation_prompt:
+            return self.response_role
+        else:
+            return request.messages[-1].role
+
+    async def chat_completion_stream_generator(
+            self, request: ChatCompletionRequest,
+            result_generator: AsyncIterator[RequestOutput], request_id: str
+    ) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
+
+        model_name = request.model
+        created_time = int(time.monotonic())
+        chunk_object_type = "chat.completion.chunk"
+
+        # Send first response for each request.n (index) with the role
+        role = self.get_chat_request_role(request)
+        for i in range(request.n):
+            choice_data = ChatCompletionResponseStreamChoice(
+                index=i, delta=DeltaMessage(role=role), finish_reason=None)
+            chunk = ChatCompletionStreamResponse(id=request_id,
+                                                 object=chunk_object_type,
+                                                 created=created_time,
+                                                 choices=[choice_data],
+                                                 model=model_name)
+            data = chunk.model_dump_json(exclude_unset=True)
+            yield f"data: {data}\n\n"
+
+        # Send response to echo the input portion of the last message
+        if request.echo:
+            last_msg_content = ""
+            if request.messages and isinstance(
+                    request.messages, list) and request.messages[-1].get(
+                        "content") and request.messages[-1].get(
+                            "role") == role:
+                last_msg_content = request.messages[-1]["content"]
+            if last_msg_content:
+                for i in range(request.n):
+                    choice_data = ChatCompletionResponseStreamChoice(
+                        index=i,
+                        delta=DeltaMessage(content=last_msg_content),
+                        finish_reason=None)
+                    chunk = ChatCompletionStreamResponse(
+                        id=request_id,
+                        object=chunk_object_type,
+                        created=created_time,
+                        choices=[choice_data],
+                        model=model_name)
+                    data = chunk.model_dump_json(exclude_unset=True)
+                    yield f"data: {data}\n\n"
+
+        # Send response for each token for each request.n (index)
+        previous_texts = [""] * request.n
+        previous_num_tokens = [0] * request.n
+        finish_reason_sent = [False] * request.n
+        async for res in result_generator:
+            res: RequestOutput
+            for output in res.outputs:
+                i = output.index
+
+                if finish_reason_sent[i]:
+                    continue
+
+                delta_text = output.text[len(previous_texts[i]):]
+                previous_texts[i] = output.text
+                previous_num_tokens[i] = len(output.token_ids)
+
+                if output.finish_reason is None:
+                    # Send token-by-token response for each request.n
+                    choice_data = ChatCompletionResponseStreamChoice(
+                        index=i,
+                        delta=DeltaMessage(content=delta_text),
+                        finish_reason=None)
+                    chunk = ChatCompletionStreamResponse(
+                        id=request_id,
+                        object=chunk_object_type,
+                        created=created_time,
+                        choices=[choice_data],
+                        model=model_name)
+                    data = chunk.model_dump_json(exclude_unset=True)
+                    yield f"data: {data}\n\n"
+                else:
+                    # Send the finish response for each request.n only once
+                    prompt_tokens = len(res.prompt_token_ids)
+                    final_usage = UsageInfo(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=previous_num_tokens[i],
+                        total_tokens=prompt_tokens + previous_num_tokens[i],
+                    )
+                    choice_data = ChatCompletionResponseStreamChoice(
+                        index=i,
+                        delta=DeltaMessage(content=delta_text),
+                        finish_reason=output.finish_reason)
+                    chunk = ChatCompletionStreamResponse(
+                        id=request_id,
+                        object=chunk_object_type,
+                        created=created_time,
+                        choices=[choice_data],
+                        model=model_name)
+                    if final_usage is not None:
+                        chunk.usage = final_usage
+                    data = chunk.model_dump_json(exclude_unset=True,
+                                                 exclude_none=True)
+                    yield f"data: {data}\n\n"
+                    finish_reason_sent[i] = True
+        # Send the final done message after all response.n are finished
+        yield "data: [DONE]\n\n"
+
+    async def chat_completion_full_generator(
+            self, request: ChatCompletionRequest, raw_request: Request,
+            result_generator: AsyncIterator[RequestOutput],
+            request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
+
+        model_name = request.model
+        created_time = int(time.monotonic())
+        final_res: RequestOutput = None
+
+        async for res in result_generator:
+            if await raw_request.is_disconnected():
+                # Abort the request if the client disconnects.
+                await self.engine.abort(request_id)
+                return self.create_error_response("Client disconnected")
+            final_res = res
+        assert final_res is not None
+
+        choices = []
+        role = self.get_chat_request_role(request)
+        for output in final_res.outputs:
+            choice_data = ChatCompletionResponseChoice(
+                index=output.index,
+                message=ChatMessage(role=role, content=output.text),
+                finish_reason=output.finish_reason,
+            )
+            choices.append(choice_data)
+
+        if request.echo:
+            last_msg_content = ""
+            if request.messages and isinstance(
+                    request.messages, list) and request.messages[-1].get(
+                        "content") and request.messages[-1].get(
+                            "role") == role:
+                last_msg_content = request.messages[-1]["content"]
+
+            for choice in choices:
+                full_message = last_msg_content + choice.message.content
+                choice.message.content = full_message
+
+        num_prompt_tokens = len(final_res.prompt_token_ids)
+        num_generated_tokens = sum(
+            len(output.token_ids) for output in final_res.outputs)
+        usage = UsageInfo(
+            prompt_tokens=num_prompt_tokens,
+            completion_tokens=num_generated_tokens,
+            total_tokens=num_prompt_tokens + num_generated_tokens,
+        )
+        response = ChatCompletionResponse(
+            id=request_id,
+            created=created_time,
+            model=model_name,
+            choices=choices,
+            usage=usage,
+        )
+
+        return response
+
+    def _load_chat_template(self, chat_template):
+        if chat_template is not None:
+            try:
+                with open(chat_template, "r") as f:
+                    self.tokenizer.chat_template = f.read()
+            except OSError:
+                # If opening a file fails, treat the argument as the template
+                # itself and decode it so escapes are interpreted correctly.
+                self.tokenizer.chat_template = codecs.decode(
+                    chat_template, "unicode_escape")
+
+            logger.info(
+                f"Using supplied chat template:\n{self.tokenizer.chat_template}"
+            )
+        elif self.tokenizer.chat_template is not None:
+            logger.info(
+                f"Using default chat template:\n{self.tokenizer.chat_template}"
+            )
+        else:
+            logger.warning(
+                "No chat template provided. Chat API will not work.")
serving_completion.py ADDED
@@ -0,0 +1,290 @@
+import time
+from fastapi import Request
+from typing import AsyncGenerator, AsyncIterator
+from vllm.logger import init_logger
+from vllm.utils import random_uuid
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from protocol import (
+    CompletionRequest,
+    CompletionResponse,
+    CompletionResponseChoice,
+    CompletionResponseStreamChoice,
+    CompletionStreamResponse,
+    LogProbs,
+    UsageInfo,
+)
+from vllm.outputs import RequestOutput
+from serving_engine import OpenAIServing
+
+logger = init_logger(__name__)
+
+
+async def completion_stream_generator(
+        request: CompletionRequest,
+        result_generator: AsyncIterator[RequestOutput],
+        echo_without_generation, create_logprobs_fn, request_id, created_time,
+        model_name) -> AsyncGenerator[str, None]:
+    previous_texts = [""] * request.n
+    previous_num_tokens = [0] * request.n
+    has_echoed = [False] * request.n
+
+    async for res in result_generator:
+        # TODO: handle client disconnect for streaming
+        for output in res.outputs:
+            i = output.index
+            delta_text = output.text[len(previous_texts[i]):]
+            token_ids = output.token_ids[previous_num_tokens[i]:]
+            if request.logprobs is not None:
+                top_logprobs = output.logprobs[previous_num_tokens[i]:]
+            else:
+                top_logprobs = None
+            offsets = len(previous_texts[i])
+            if request.echo and not has_echoed[i]:
+                if not echo_without_generation:
+                    delta_text = res.prompt + delta_text
+                    token_ids = res.prompt_token_ids + token_ids
+                    if top_logprobs:
+                        top_logprobs = res.prompt_logprobs + top_logprobs
+                else:  # only just return the prompt
+                    delta_text = res.prompt
+                    token_ids = res.prompt_token_ids
+                    if top_logprobs:
+                        top_logprobs = res.prompt_logprobs
+                has_echoed[i] = True
+            if request.logprobs is not None:
+                logprobs = create_logprobs_fn(
+                    token_ids=token_ids,
+                    top_logprobs=top_logprobs,
+                    num_output_top_logprobs=request.logprobs,
+                    initial_text_offset=offsets,
+                )
+            else:
+                logprobs = None
+            previous_texts[i] = output.text
+            previous_num_tokens[i] = len(output.token_ids)
+            finish_reason = output.finish_reason
+            response_json = CompletionStreamResponse(
+                id=request_id,
+                created=created_time,
+                model=model_name,
+                choices=[
+                    CompletionResponseStreamChoice(
+                        index=i,
+                        text=delta_text,
+                        logprobs=logprobs,
+                        finish_reason=finish_reason,
+                    )
+                ]).model_dump_json(exclude_unset=True)
+            yield f"data: {response_json}\n\n"
+
+            if output.finish_reason is not None:
+                logprobs = LogProbs() if request.logprobs is not None else None
+                prompt_tokens = len(res.prompt_token_ids)
+                completion_tokens = len(output.token_ids)
+                final_usage = UsageInfo(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=prompt_tokens + completion_tokens,
+                )
+                response_json = CompletionStreamResponse(
+                    id=request_id,
+                    created=created_time,
+                    model=model_name,
+                    choices=[
+                        CompletionResponseStreamChoice(
+                            index=i,
+                            text="",
+                            logprobs=logprobs,
+                            finish_reason=output.finish_reason,
+                        )
+                    ],
+                    usage=final_usage,
+                ).model_dump_json(exclude_unset=True)
+                yield f"data: {response_json}\n\n"
+
+    yield "data: [DONE]\n\n"
+
+
+def parse_prompt_format(prompt) -> tuple[bool, list]:
+    # get the prompt, openai supports the following
+    # "a string, array of strings, array of tokens, or array of token arrays."
+    prompt_is_tokens = False
+    prompts = [prompt]  # case 1: a string
+    if isinstance(prompt, list):
+        if len(prompt) == 0:
+            raise ValueError("please provide at least one prompt")
+        elif isinstance(prompt[0], str):
+            prompt_is_tokens = False
+            prompts = prompt  # case 2: array of strings
+        elif isinstance(prompt[0], int):
+            prompt_is_tokens = True
+            prompts = [prompt]  # case 3: array of tokens
+        elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
+            prompt_is_tokens = True
+            prompts = prompt  # case 4: array of token arrays
+        else:
+            raise ValueError(
+                "prompt must be a string, array of strings, array of tokens, or array of token arrays"
+            )
+    return prompt_is_tokens, prompts
+
+
+def request_output_to_completion_response(final_res: RequestOutput, request,
+                                          echo_without_generation,
+                                          create_logprobs_fn, request_id,
+                                          created_time,
+                                          model_name) -> CompletionResponse:
+    assert final_res is not None
+    choices = []
+    prompt_token_ids = final_res.prompt_token_ids
+    prompt_logprobs = final_res.prompt_logprobs
+    prompt_text = final_res.prompt
+    for output in final_res.outputs:
+        if request.logprobs is not None:
+            if not echo_without_generation:
+                token_ids = output.token_ids
+                top_logprobs = output.logprobs
+                if request.echo:
+                    token_ids = prompt_token_ids + token_ids
+                    top_logprobs = prompt_logprobs + top_logprobs
+            else:
+                token_ids = prompt_token_ids
+                top_logprobs = prompt_logprobs
+            logprobs = create_logprobs_fn(
+                token_ids=token_ids,
+                top_logprobs=top_logprobs,
+                num_output_top_logprobs=request.logprobs,
+            )
+        else:
+            logprobs = None
+        if not echo_without_generation:
+            output_text = output.text
+            if request.echo:
+                output_text = prompt_text + output_text
+        else:
+            output_text = prompt_text
+        choice_data = CompletionResponseChoice(
+            index=output.index,
+            text=output_text,
+            logprobs=logprobs,
+            finish_reason=output.finish_reason,
+        )
+        choices.append(choice_data)
+
+    num_prompt_tokens = len(final_res.prompt_token_ids)
+    num_generated_tokens = sum(
+        len(output.token_ids) for output in final_res.outputs)
+    usage = UsageInfo(
+        prompt_tokens=num_prompt_tokens,
+        completion_tokens=num_generated_tokens,
+        total_tokens=num_prompt_tokens + num_generated_tokens,
+    )
+
+    return CompletionResponse(
+        id=request_id,
+        created=created_time,
+        model=model_name,
+        choices=choices,
+        usage=usage,
+    )
+
+
+class OpenAIServingCompletion(OpenAIServing):
+
+    def __init__(self, engine: AsyncLLMEngine, served_model: str):
+        super().__init__(engine=engine, served_model=served_model)
+
+    async def create_completion(self, request: CompletionRequest,
+                                raw_request: Request):
+        """Completion API similar to OpenAI's API.
+
+        See https://platform.openai.com/docs/api-reference/completions/create
+        for the API specification. This API mimics the OpenAI Completion API.
+
+        NOTE: Currently we do not support the following features:
+            - suffix (the language models we currently support do not support
+            suffix)
+            - logit_bias (to be supported by vLLM engine)
+        """
+        error_check_ret = await self._check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        # OpenAI API supports echoing the prompt when max_tokens is 0.
+        echo_without_generation = request.echo and request.max_tokens == 0
+
+        # Return error for unsupported features.
+        if request.suffix is not None:
+            return self.create_error_response(
+                "suffix is not currently supported")
+        if request.logit_bias is not None and len(request.logit_bias) > 0:
+            return self.create_error_response(
+                "logit_bias is not currently supported")
+
+        model_name = request.model
+        request_id = f"cmpl-{random_uuid()}"
+        created_time = int(time.monotonic())
+
+        # Schedule the request and get the result generator.
+        try:
+            sampling_params = request.to_sampling_params()
+
+            prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
+
+            if len(prompts) > 1:
+                raise ValueError(
+                    "Batching in completion API is not supported.")
+            prompt = prompts[0]
+
+            if prompt_is_tokens:
+                input_ids = self._validate_prompt_and_tokenize(
+                    request, prompt_ids=prompt)
+            else:
+                input_ids = self._validate_prompt_and_tokenize(request,
+                                                               prompt=prompt)
+
+            result_generator = self.engine.generate(None,
+                                                    sampling_params,
+                                                    request_id,
+                                                    prompt_token_ids=input_ids)
+        except ValueError as e:
+            return self.create_error_response(str(e))
+
+        # Similar to the OpenAI API, when n != best_of, we do not stream the
+        # results. In addition, we do not stream when using beam search.
+        stream = (request.stream
+                  and (request.best_of is None or request.n == request.best_of)
+                  and not request.use_beam_search)
+
+        # Streaming response
+        if stream:
+            return completion_stream_generator(request, result_generator,
+                                               echo_without_generation,
+                                               self._create_logprobs,
+                                               request_id, created_time,
+                                               model_name)
+
+        # Non-streaming response
+        final_res: RequestOutput = None
+        async for res in result_generator:
+            if await raw_request.is_disconnected():
+                # Abort the request if the client disconnects.
+                await self.engine.abort(request_id)
+                return self.create_error_response("Client disconnected")
+            final_res = res
+        response = request_output_to_completion_response(
+            final_res, request, echo_without_generation, self._create_logprobs,
+            request_id, created_time, model_name)
+
+        # When user requests streaming but we don't stream, we still need to
+        # return a streaming response with a single event.
+        if request.stream:
+            response_json = response.model_dump_json()
+
+            async def fake_stream_generator() -> AsyncGenerator[str, None]:
+                yield f"data: {response_json}\n\n"
+                yield "data: [DONE]\n\n"
+
+            return fake_stream_generator()
+
+        return response
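`parse_prompt_format` is what lets `/api/v1/completions` accept the four prompt shapes the OpenAI spec allows; the flag it returns decides whether the text is tokenized again or the token IDs are passed straight to the engine. A small illustration of the mapping, assuming the module imports cleanly (i.e. vLLM is installed):

```python
# Sketch: the four accepted prompt shapes and what parse_prompt_format returns.
from serving_completion import parse_prompt_format

print(parse_prompt_format("Hello"))             # (False, ["Hello"])
print(parse_prompt_format(["Hello", "World"]))  # (False, ["Hello", "World"])
print(parse_prompt_format([1, 2, 3]))           # (True, [[1, 2, 3]])
print(parse_prompt_format([[1, 2], [3, 4]]))    # (True, [[1, 2], [3, 4]])
# Note: create_completion later rejects more than one prompt per request.
```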
serving_engine.py ADDED
@@ -0,0 +1,133 @@
+import asyncio
+from http import HTTPStatus
+from typing import Dict, List, Optional, Union
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from protocol import (CompletionRequest,
+                      ChatCompletionRequest,
+                      ErrorResponse, LogProbs,
+                      ModelCard, ModelList,
+                      ModelPermission)
+
+logger = init_logger(__name__)
+
+
+class OpenAIServing:
+
+    def __init__(self, engine: AsyncLLMEngine, served_model: str):
+        self.engine = engine
+        self.served_model = served_model
+
+        self.max_model_len = 0
+        self.tokenizer = None
+
+        try:
+            event_loop = asyncio.get_running_loop()
+        except RuntimeError:
+            event_loop = None
+
+        if event_loop is not None and event_loop.is_running(
+        ):  # If instantiated from Ray Serve, there is already a running event loop
+            event_loop.create_task(self._post_init())
+        else:  # When using single vLLM without engine_use_ray
+            asyncio.run(self._post_init())
+
+    async def _post_init(self):
+        engine_model_config = await self.engine.get_model_config()
+        self.max_model_len = engine_model_config.max_model_len
+
+        # A separate tokenizer to map token IDs to strings.
+        self.tokenizer = get_tokenizer(
+            engine_model_config.tokenizer,
+            tokenizer_mode=engine_model_config.tokenizer_mode,
+            trust_remote_code=engine_model_config.trust_remote_code)
+
+    async def show_available_models(self) -> ModelList:
+        """Show available models. Right now we only have one model."""
+        model_cards = [
+            ModelCard(id=self.served_model,
+                      root=self.served_model,
+                      permission=[ModelPermission()])
+        ]
+        return ModelList(data=model_cards)
+
+    def _create_logprobs(
+        self,
+        token_ids: List[int],
+        top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
+        num_output_top_logprobs: Optional[int] = None,
+        initial_text_offset: int = 0,
+    ) -> LogProbs:
+        """Create OpenAI-style logprobs."""
+        logprobs = LogProbs()
+        last_token_len = 0
+        if num_output_top_logprobs:
+            logprobs.top_logprobs = []
+        for i, token_id in enumerate(token_ids):
+            step_top_logprobs = top_logprobs[i]
+            if step_top_logprobs is not None:
+                token_logprob = step_top_logprobs[token_id]
+            else:
+                token_logprob = None
+            token = self.tokenizer.convert_ids_to_tokens(token_id)
+            logprobs.tokens.append(token)
+            logprobs.token_logprobs.append(token_logprob)
+            if len(logprobs.text_offset) == 0:
+                logprobs.text_offset.append(initial_text_offset)
+            else:
+                logprobs.text_offset.append(logprobs.text_offset[-1] +
+                                            last_token_len)
+            last_token_len = len(token)
+
+            if num_output_top_logprobs:
+                logprobs.top_logprobs.append({
+                    self.tokenizer.convert_ids_to_tokens(i): p
+                    for i, p in step_top_logprobs.items()
+                } if step_top_logprobs else None)
+        return logprobs
+
+    def create_error_response(
+            self,
+            message: str,
+            err_type: str = "BadRequestError",
+            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
+        return ErrorResponse(message=message,
+                             type=err_type,
+                             code=status_code.value)
+
+    async def _check_model(self, request) -> Optional[ErrorResponse]:
+        if request.model == self.served_model:
+            return
+        return self.create_error_response(
+            message=f"The model `{request.model}` does not exist.",
+            err_type="NotFoundError",
+            status_code=HTTPStatus.NOT_FOUND)
+
+    def _validate_prompt_and_tokenize(
+            self,
+            request: Union[ChatCompletionRequest, CompletionRequest],
+            prompt: Optional[str] = None,
+            prompt_ids: Optional[List[int]] = None) -> List[int]:
+        if not (prompt or prompt_ids):
+            raise ValueError("Either prompt or prompt_ids should be provided.")
+        if (prompt and prompt_ids):
+            raise ValueError(
+                "Only one of prompt or prompt_ids should be provided.")
+
+        input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
+            prompt).input_ids
+        token_num = len(input_ids)
+
+        if request.max_tokens is None:
+            request.max_tokens = self.max_model_len - token_num
+
+        if token_num + request.max_tokens > self.max_model_len:
+            raise ValueError(
+                f"This model's maximum context length is {self.max_model_len} tokens. "
+                f"However, you requested {request.max_tokens + token_num} tokens "
+                f"({token_num} in the messages, "
+                f"{request.max_tokens} in the completion). "
+                f"Please reduce the length of the messages or completion.", )
+        else:
+            return input_ids
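The context-length check in `_validate_prompt_and_tokenize` is simple arithmetic: if `max_tokens` is omitted it becomes `max_model_len - prompt_tokens`, and any request where `prompt_tokens + max_tokens` exceeds `max_model_len` is rejected. A standalone sketch of that rule, with illustrative numbers not tied to any particular model:

```python
# Sketch of the length rule enforced by _validate_prompt_and_tokenize,
# using a hypothetical 4096-token context window.
from typing import Optional


def check_budget(prompt_tokens: int,
                 max_tokens: Optional[int],
                 max_model_len: int = 4096) -> int:
    if max_tokens is None:
        # Default: let the completion use whatever context remains.
        max_tokens = max_model_len - prompt_tokens
    if prompt_tokens + max_tokens > max_model_len:
        raise ValueError(
            f"Requested {prompt_tokens + max_tokens} tokens, "
            f"but the maximum context length is {max_model_len}.")
    return max_tokens


print(check_budget(1000, None))  # -> 3096 tokens left for the completion
try:
    check_budget(4000, 200)      # 4200 > 4096
except ValueError as e:
    print(e)
```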