import base64
import time
from typing import AsyncIterator, List, Optional, Tuple, cast

import numpy as np
from fastapi import Request

from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
                                              EmbeddingResponse,
                                              EmbeddingResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid

logger = init_logger(__name__)

TypeTokenIDs = List[int]
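

# The helper below flattens a batch of per-prompt EmbeddingRequestOutput
# objects into a single OpenAI-style EmbeddingResponse. When the client asks
# for encoding_format="base64", each float vector is serialized with numpy and
# base64-encoded instead of being returned as a JSON list of floats.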
def request_output_to_embedding_response(
        final_res_batch: List[EmbeddingRequestOutput], request_id: str,
        created_time: int, model_name: str,
        encoding_format: str) -> EmbeddingResponse:
    data: List[EmbeddingResponseData] = []
    num_prompt_tokens = 0
    for idx, final_res in enumerate(final_res_batch):
        prompt_token_ids = final_res.prompt_token_ids
        embedding = final_res.outputs.embedding
        if encoding_format == "base64":
            embedding_bytes = np.array(embedding).tobytes()
            embedding = base64.b64encode(embedding_bytes).decode("utf-8")

        embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
        data.append(embedding_data)

        num_prompt_tokens += len(prompt_token_ids)

    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        total_tokens=num_prompt_tokens,
    )

    return EmbeddingResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        data=data,
        usage=usage,
    )
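
# NOTE: a client that requested encoding_format="base64" can recover the vector
# roughly like this (a sketch, not part of this module; b64_embedding stands
# for the string in the response's data[i].embedding field, and the dtype is
# whatever np.array inferred above -- typically float64 for Python floats):
#   np.frombuffer(base64.b64decode(b64_embedding), dtype=np.float64)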


class OpenAIServingEmbedding(OpenAIServing):

    def __init__(
        self,
        engine: AsyncLLMEngine,
        model_config: ModelConfig,
        served_model_names: List[str],
        *,
        request_logger: Optional[RequestLogger],
    ):
        super().__init__(engine=engine,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=None,
                         prompt_adapters=None,
                         request_logger=request_logger)
        self._check_embedding_mode(model_config.embedding_mode)

    async def create_embedding(self, request: EmbeddingRequest,
                               raw_request: Request):
        """Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        encoding_format = (request.encoding_format
                           if request.encoding_format else "float")
        if request.dimensions is not None:
            return self.create_error_response(
                "dimensions is currently not supported")

        model_name = request.model
        request_id = f"embd-{random_uuid()}"
        # The OpenAI spec expects `created` to be a Unix timestamp in seconds.
        created_time = int(time.time())

        # Schedule the request and get the result generator.
        generators: List[AsyncIterator[EmbeddingRequestOutput]] = []
        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine.get_tokenizer(lora_request)

            pooling_params = request.to_pooling_params()

            prompts = list(
                self._tokenize_prompt_input_or_inputs(
                    request,
                    tokenizer,
                    request.input,
                ))
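
            # One encode request is scheduled per tokenized prompt; the
            # per-prompt result streams are merged back together below.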
            for i, prompt_inputs in enumerate(prompts):
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 prompt_inputs,
                                 params=pooling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                if prompt_adapter_request is not None:
                    raise NotImplementedError(
                        "Prompt adapter is not supported "
                        "for embedding models")

                generator = self.engine.encode(
                    {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))
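
        # merge_async_iterators interleaves the per-prompt streams and yields
        # (prompt index, output) pairs, so results may arrive out of order and
        # are slotted into final_res_batch by index.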
        result_generator: AsyncIterator[Tuple[
            int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)

        # Non-streaming response
        final_res_batch: List[Optional[EmbeddingRequestOutput]]
        final_res_batch = [None] * len(prompts)
        try:
            async for i, res in result_generator:
                if await raw_request.is_disconnected():
                    # Abort the request if the client disconnects.
                    await self.engine.abort(f"{request_id}-{i}")
                    return self.create_error_response("Client disconnected")
                final_res_batch[i] = res

            for final_res in final_res_batch:
                assert final_res is not None

            final_res_batch_checked = cast(List[EmbeddingRequestOutput],
                                           final_res_batch)

            response = request_output_to_embedding_response(
                final_res_batch_checked, request_id, created_time, model_name,
                encoding_format)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response

    def _check_embedding_mode(self, embedding_mode: bool):
        if not embedding_mode:
            logger.warning(
                "embedding_mode is False. Embedding API will not work.")
        else:
            logger.info("Activating the server engine with embedding enabled.")