test-docker / serving_embedding.py
sofianhw's picture
update some code to comply with 0.5.1
8f99309
raw
history blame
5.53 kB
import base64
import time
from typing import AsyncIterator, List, Optional, Tuple
import numpy as np
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_completion import parse_prompt_format
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid
# Alias for a plain list of integer token ids (not used by the code visible
# in this module section).
TypeTokenIDs = List[int]

# Module-scoped logger obtained from vLLM's logging helper.
logger = init_logger(__name__)
def request_output_to_embedding_response(
        final_res_batch: List[Optional[EmbeddingRequestOutput]],
        request_id: str, created_time: int, model_name: str,
        encoding_format: str) -> EmbeddingResponse:
    """Convert finished engine outputs into an OpenAI-style EmbeddingResponse.

    Args:
        final_res_batch: One finished output per input prompt, ordered by
            prompt position. The list position becomes the ``index`` field
            of the corresponding response item. Every entry must be set
            (the ``Optional`` reflects how the caller pre-allocates it).
        request_id: Echoed back as the response ``id``.
        created_time: Echoed back as the response ``created`` field.
        model_name: Echoed back as the response ``model`` field.
        encoding_format: ``"base64"`` returns each embedding as a base64
            string of little-endian float32 values. Any other value returns
            the embedding list unchanged.

    Returns:
        An EmbeddingResponse whose usage reports prompt tokens only; the
        total equals the summed prompt-token count.
    """
    data: List[EmbeddingResponseData] = []
    num_prompt_tokens = 0
    for idx, final_res in enumerate(final_res_batch):
        # Every sub-request must have delivered a result before we get here.
        assert final_res is not None
        embedding = final_res.outputs.embedding
        if encoding_format == "base64":
            # OpenAI-compatible clients decode base64 embeddings as float32.
            # numpy's default float64 would produce twice as many bogus
            # values. Pin the byte order so the wire format does not depend
            # on the server's architecture, and return str rather than bytes.
            embedding_bytes = np.asarray(embedding, dtype="<f4").tobytes()
            embedding = base64.b64encode(embedding_bytes).decode("ascii")
        data.append(EmbeddingResponseData(index=idx, embedding=embedding))
        num_prompt_tokens += len(final_res.prompt_token_ids)

    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        total_tokens=num_prompt_tokens,
    )

    return EmbeddingResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        data=data,
        usage=usage,
    )
class OpenAIServingEmbedding(OpenAIServing):
    """Serves the OpenAI-compatible embeddings API on top of an AsyncLLMEngine.

    Each request is fanned out into one engine ``encode`` call per input
    prompt, and the results are merged back into a single response.
    """

    def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
                 served_model_names: List[str]):
        """Initialize the handler and log whether embedding mode is on.

        Args:
            engine: Engine used to run the ``encode`` requests.
            model_config: Model configuration; its ``embedding_mode`` flag
                is checked (a warning is logged when it is off).
            served_model_names: Model names this server answers to.
        """
        # No LoRA modules are registered for the embedding endpoint.
        super().__init__(engine=engine,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=None)
        self._check_embedding_mode(model_config.embedding_mode)

    async def create_embedding(self, request: EmbeddingRequest,
                               raw_request: Request):
        """Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.

        Each input prompt is submitted to the engine as its own sub-request
        with id ``"{request_id}-{index}"``. The results are gathered back in
        input order.

        Returns:
            An EmbeddingResponse on success. Otherwise returns the error
            response from ``_check_model`` or ``create_error_response``
            (unsupported ``dimensions``, validation ``ValueError``, or client
            disconnect).
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        encoding_format = (request.encoding_format
                           if request.encoding_format else "float")
        if request.dimensions is not None:
            return self.create_error_response(
                "dimensions is currently not supported")

        model_name = request.model
        request_id = f"cmpl-{random_uuid()}"
        # OpenAI's `created` field is a Unix timestamp. time.monotonic() has
        # an arbitrary reference point and must not be used for it.
        created_time = int(time.time())

        # Schedule one engine request per prompt and collect the generators.
        generators = []
        try:
            prompt_is_tokens, prompts = parse_prompt_format(request.input)
            pooling_params = request.to_pooling_params()

            for i, prompt in enumerate(prompts):
                if prompt_is_tokens:
                    prompt_formats = self._validate_prompt_and_tokenize(
                        request, prompt_ids=prompt)
                else:
                    prompt_formats = self._validate_prompt_and_tokenize(
                        request, prompt=prompt)
                prompt_ids, prompt_text = prompt_formats

                generator = self.engine.encode(
                    {
                        "prompt": prompt_text,
                        "prompt_token_ids": prompt_ids
                    },
                    pooling_params,
                    f"{request_id}-{i}",
                )
                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator: AsyncIterator[Tuple[
            int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)

        # Non-streaming response: keep the latest output for each prompt index.
        final_res_batch: List[Optional[EmbeddingRequestOutput]]
        final_res_batch = [None] * len(prompts)
        try:
            async for i, res in result_generator:
                if await raw_request.is_disconnected():
                    # Abort the request if the client disconnects. Abort the
                    # sub-request that just yielded and every one that has
                    # not produced a result yet. Aborting only the current
                    # one would leave the rest running for a client that is
                    # gone.
                    for j, pending in enumerate(final_res_batch):
                        if j == i or pending is None:
                            await self.engine.abort(f"{request_id}-{j}")
                    # TODO: Use a vllm-specific Validation Error
                    return self.create_error_response("Client disconnected")
                final_res_batch[i] = res
            response = request_output_to_embedding_response(
                final_res_batch, request_id, created_time, model_name,
                encoding_format)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response

    def _check_embedding_mode(self, embedding_mode: bool):
        """Log whether the engine was started with embedding mode enabled.

        Only logs; it does not raise when embedding mode is off.
        """
        if not embedding_mode:
            logger.warning(
                "embedding_mode is False. Embedding API will not work.")
        else:
            logger.info("Activating the server engine with embedding enabled.")