# test-docker / serving_engine.py
# Last change: sofianhw, commit 5491dc5 - "change to 0.5.3 support llama 3.1"
import json
import pathlib
from dataclasses import dataclass
from http import HTTPStatus
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
from pydantic import Field
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Annotated
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest,
DetokenizeRequest,
EmbeddingRequest, ErrorResponse,
ModelCard, ModelList,
ModelPermission,
TokenizeChatRequest,
TokenizeCompletionRequest,
TokenizeRequest)
# yapf: enable
from vllm.inputs import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob
# Module-level logger for this file, created through vLLM's `init_logger`.
logger = init_logger(__name__)
@dataclass
class PromptAdapterPath:
    """Name and on-disk location of one prompt adapter to be served."""

    # Identifier of the adapter; OpenAIServing uses it as the adapter name.
    name: str
    # Directory of the adapter; OpenAIServing reads "adapter_config.json"
    # from inside it.
    local_path: str
@dataclass
class LoRAModulePath:
    """Name and location of one LoRA module to be served."""

    # Identifier of the module; OpenAIServing uses it as the LoRA name.
    name: str
    # Location of the LoRA weights, forwarded unchanged as `lora_path`.
    path: str
# Union of the request models accepted by the OpenAIServing helpers below.
AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
EmbeddingRequest, TokenizeRequest]
# Either of the two Hugging Face tokenizer base classes imported above.
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
class TextTokensPrompt(TypedDict):
    """A prompt carried both as text and as its token ids."""

    # Prompt text.
    prompt: str
    # Token ids that go with `prompt`.
    prompt_token_ids: List[int]
class OpenAIServing:
    """Common plumbing for the OpenAI-protocol request handlers.

    Holds the engine, the model configuration and the served model names,
    builds the LoRA / prompt-adapter requests that can be selected through a
    request's ``model`` field, and provides helpers for model-name
    resolution, error payloads, prompt tokenization/validation and request
    logging.
    """

    def __init__(
        self,
        engine: AsyncLLMEngine,
        model_config: ModelConfig,
        served_model_names: List[str],
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        prompt_adapters: Optional[List[PromptAdapterPath]],
        request_logger: Optional[RequestLogger],
    ):
        """Store the serving configuration and build adapter requests.

        Args:
            engine: Engine instance kept on ``self.engine``.
            model_config: Model configuration; its ``max_model_len`` is
                cached on ``self.max_model_len``.
            served_model_names: Names under which the base model is exposed.
                The first entry is used as ``root`` in the model listing.
            lora_modules: Optional LoRA modules; each becomes a
                ``LoRARequest`` with ids assigned from 1 in list order.
            prompt_adapters: Optional prompt adapters; each becomes a
                ``PromptAdapterRequest`` with ids assigned from 1. The number
                of virtual tokens is read from ``adapter_config.json`` in the
                adapter's ``local_path``.
            request_logger: Optional logger used by :meth:`_log_inputs`.

        Raises:
            OSError, ValueError, KeyError: Propagated unchanged when an
                adapter config cannot be opened, is not valid JSON, or has
                no ``"num_virtual_tokens"`` entry.
        """
        super().__init__()

        self.engine = engine
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        self.served_model_names = served_model_names

        self.lora_requests = []
        if lora_modules is not None:
            self.lora_requests = [
                LoRARequest(
                    lora_name=lora.name,
                    lora_int_id=i,
                    lora_path=lora.path,
                ) for i, lora in enumerate(lora_modules, start=1)
            ]

        self.prompt_adapter_requests = []
        if prompt_adapters is not None:
            for i, prompt_adapter in enumerate(prompt_adapters, start=1):
                config_path = pathlib.Path(prompt_adapter.local_path,
                                           "adapter_config.json")
                # JSON is UTF-8 text; do not depend on the platform's
                # locale-specific default encoding.
                with config_path.open(encoding="utf-8") as f:
                    adapter_config = json.load(f)
                num_virtual_tokens = adapter_config["num_virtual_tokens"]
                self.prompt_adapter_requests.append(
                    PromptAdapterRequest(
                        prompt_adapter_name=prompt_adapter.name,
                        prompt_adapter_id=i,
                        prompt_adapter_local_path=prompt_adapter.local_path,
                        prompt_adapter_num_virtual_tokens=num_virtual_tokens))

        self.request_logger = request_logger

    async def show_available_models(self) -> ModelList:
        """List every selectable model name.

        The result contains one card per served model name, followed by one
        card per LoRA module and one per prompt adapter. Every card uses the
        first served model name as its ``root``; only the served-model cards
        carry ``max_model_len``.
        """
        model_cards = [
            ModelCard(id=served_model_name,
                      max_model_len=self.max_model_len,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for served_model_name in self.served_model_names
        ]
        lora_cards = [
            ModelCard(id=lora.lora_name,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for lora in self.lora_requests
        ]
        prompt_adapter_cards = [
            ModelCard(id=prompt_adapter.prompt_adapter_name,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for prompt_adapter in self.prompt_adapter_requests
        ]
        model_cards.extend(lora_cards)
        model_cards.extend(prompt_adapter_cards)
        return ModelList(data=model_cards)

    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        """Build an ``ErrorResponse`` from a message, type and HTTP status.

        The numeric value of ``status_code`` is stored as the response code.
        """
        return ErrorResponse(message=message,
                             type=err_type,
                             code=status_code.value)

    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
        """Return the error as a JSON string of the form ``{"error": ...}``.

        The payload is the dumped :meth:`create_error_response` result.
        """
        json_str = json.dumps({
            "error":
            self.create_error_response(message=message,
                                       err_type=err_type,
                                       status_code=status_code).model_dump()
        })
        return json_str

    async def _check_model(
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
        """Check that ``request.model`` names something this server offers.

        Returns:
            ``None`` when the name matches a served model name, a LoRA name
            or a prompt-adapter name; otherwise a 404 ``NotFoundError``
            response.
        """
        if request.model in self.served_model_names:
            return None
        if any(request.model == lora.lora_name
               for lora in self.lora_requests):
            return None
        if any(request.model == prompt_adapter.prompt_adapter_name
               for prompt_adapter in self.prompt_adapter_requests):
            return None
        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)

    def _maybe_get_adapters(
        self, request: AnyRequest
    ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
            None, PromptAdapterRequest]]:
        """Resolve ``request.model`` to a ``(lora, prompt_adapter)`` pair.

        Served model names win over LoRA names, which win over
        prompt-adapter names. At most one element of the pair is not
        ``None``.

        Raises:
            ValueError: If the name matches nothing.
        """
        if request.model in self.served_model_names:
            return None, None
        for lora in self.lora_requests:
            if request.model == lora.lora_name:
                return lora, None
        for prompt_adapter in self.prompt_adapter_requests:
            if request.model == prompt_adapter.prompt_adapter_name:
                return None, prompt_adapter
        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt: str,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        """Tokenize ``prompt`` and validate the result against the request.

        When ``truncate_prompt_tokens`` is given, the tokenizer is asked to
        truncate to that many tokens. The returned prompt text is always the
        original ``prompt``, even if its token ids were truncated.

        Raises:
            ValueError: From :meth:`_validate_input` on length violations.
        """
        if truncate_prompt_tokens is None:
            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
        else:
            encoded = tokenizer(prompt,
                                add_special_tokens=add_special_tokens,
                                truncation=True,
                                max_length=truncate_prompt_tokens)

        input_ids = encoded.input_ids

        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

    def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_ids: List[int],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
    ) -> TextTokensPrompt:
        """Decode ``prompt_ids`` to text and validate against the request.

        When ``truncate_prompt_tokens`` is given, only the last that many
        ids are kept. The prompt text is decoded from the kept ids.

        Raises:
            ValueError: From :meth:`_validate_input` on length violations.
        """
        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        input_text = tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: List[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Enforce the context-length limit and package the prompt.

        * Embedding requests: the prompt may use at most ``max_model_len``
          tokens.
        * Tokenize / detokenize requests: no length check.
        * Everything else: prompt tokens plus ``request.max_tokens`` must fit
          in ``max_model_len``. If ``request.max_tokens`` is ``None`` it is
          set **on the request** to the remaining budget, which requires the
          prompt to be strictly shorter than ``max_model_len``.

        Raises:
            ValueError: If the applicable limit is exceeded.
        """
        token_num = len(input_ids)

        # Note: EmbeddingRequest doesn't have max_tokens
        if isinstance(request, EmbeddingRequest):
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for embedding "
                    f"generation. Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
        # and does not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        if request.max_tokens is None:
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages, "
                    f"Please reduce the length of the messages.")
            # Default the completion budget to whatever context is left.
            request.max_tokens = self.max_model_len - token_num

        if token_num + request.max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{request.max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{request.max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    def _tokenize_prompt_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, List[int]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes single input.
        """
        # Wrap the single input in a list and take the only item produced.
        return next(
            self._tokenize_prompt_inputs(
                request,
                tokenizer,
                [prompt_input],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            ))

    def _tokenize_prompt_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes multiple inputs.

        Strings are tokenized; anything else is treated as a list of token
        ids. Items are produced lazily, one per input, in input order.
        """
        for text in prompt_inputs:
            if isinstance(text, str):
                yield self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )

    def _tokenize_prompt_input_or_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        Tokenize/detokenize depending on the input format.

        According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
        , each input can be a string or array of tokens. Note that each request
        can pass one or more inputs.
        """
        for prompt_input in parse_and_batch_prompt(input_or_inputs):
            # Although our type checking is based on mypy,
            # VSCode Pyright extension should still work properly
            # An explicit "is False" comparison (rather than plain
            # truthiness) is required for Pyright to perform type narrowing
            # See: https://github.com/microsoft/pyright/issues/7672
            if prompt_input["is_tokens"] is False:
                yield self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=prompt_input["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=prompt_input["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )

    def _log_inputs(
        self,
        request_id: str,
        inputs: Union[str, List[int], TextTokensPrompt],
        params: Optional[Union[SamplingParams, PoolingParams]],
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
        """Forward a request's inputs to the request logger, if one is set.

        ``inputs`` may be prompt text, a list of token ids, or a mapping with
        ``"prompt"`` and ``"prompt_token_ids"``; the part that is not
        available is logged as ``None``.
        """
        if self.request_logger is None:
            return

        if isinstance(inputs, str):
            prompt = inputs
            prompt_token_ids = None
        elif isinstance(inputs, list):
            prompt = None
            prompt_token_ids = inputs
        else:
            prompt = inputs["prompt"]
            prompt_token_ids = inputs["prompt_token_ids"]

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            params=params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

    @staticmethod
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
        tokenizer: AnyTokenizer,
    ) -> str:
        """Return the text of a token.

        Uses ``logprob.decoded_token`` when it is set and only falls back to
        decoding ``token_id`` with the tokenizer when it is ``None``.
        """
        if logprob.decoded_token is not None:
            return logprob.decoded_token
        return tokenizer.decode(token_id)