import base64
import time
from typing import AsyncIterator, List, Optional, Tuple, cast

import numpy as np
from fastapi import Request

from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
                                              EmbeddingResponse,
                                              EmbeddingResponseData,
                                              UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid

logger = init_logger(__name__)

TypeTokenIDs = List[int]


def request_output_to_embedding_response(
        final_res_batch: List[EmbeddingRequestOutput], request_id: str,
        created_time: int, model_name: str,
        encoding_format: str) -> EmbeddingResponse:
    data: List[EmbeddingResponseData] = []
    num_prompt_tokens = 0
    for idx, final_res in enumerate(final_res_batch):
        prompt_token_ids = final_res.prompt_token_ids
        embedding = final_res.outputs.embedding
        if encoding_format == "base64":
            embedding_bytes = np.array(embedding).tobytes()
            embedding = base64.b64encode(embedding_bytes).decode("utf-8")

        embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
        data.append(embedding_data)

        num_prompt_tokens += len(prompt_token_ids)

    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        total_tokens=num_prompt_tokens,
    )

    return EmbeddingResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        data=data,
        usage=usage,
    )


class OpenAIServingEmbedding(OpenAIServing):

    def __init__(
        self,
        engine: AsyncLLMEngine,
        model_config: ModelConfig,
        served_model_names: List[str],
        *,
        request_logger: Optional[RequestLogger],
    ):
        super().__init__(engine=engine,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=None,
                         prompt_adapters=None,
                         request_logger=request_logger)
        self._check_embedding_mode(model_config.embedding_mode)

    async def create_embedding(self, request: EmbeddingRequest,
                               raw_request: Request):
        """Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        encoding_format = (request.encoding_format
                           if request.encoding_format else "float")
        if request.dimensions is not None:
            return self.create_error_response(
                "dimensions is currently not supported")

        model_name = request.model
        request_id = f"embd-{random_uuid()}"
        # Use wall-clock time: the OpenAI API's "created" field is a Unix
        # timestamp, which time.monotonic() does not provide.
        created_time = int(time.time())

        # Schedule the request and get the result generator.
        generators: List[AsyncIterator[EmbeddingRequestOutput]] = []
        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine.get_tokenizer(lora_request)

            pooling_params = request.to_pooling_params()

            prompts = list(
                self._tokenize_prompt_input_or_inputs(
                    request,
                    tokenizer,
                    request.input,
                ))

            for i, prompt_inputs in enumerate(prompts):
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 prompt_inputs,
                                 params=pooling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                if prompt_adapter_request is not None:
                    raise NotImplementedError(
                        "Prompt adapter is not supported "
                        "for embedding models")

                generator = self.engine.encode(
                    {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator: AsyncIterator[Tuple[
            int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)

        # Non-streaming response
        final_res_batch: List[Optional[EmbeddingRequestOutput]]
        final_res_batch = [None] * len(prompts)
        try:
            async for i, res in result_generator:
                if await raw_request.is_disconnected():
                    # Abort the request if the client disconnects.
                    await self.engine.abort(f"{request_id}-{i}")
                    return self.create_error_response("Client disconnected")
                final_res_batch[i] = res

            for final_res in final_res_batch:
                assert final_res is not None

            final_res_batch_checked = cast(List[EmbeddingRequestOutput],
                                           final_res_batch)

            response = request_output_to_embedding_response(
                final_res_batch_checked, request_id, created_time, model_name,
                encoding_format)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response

    def _check_embedding_mode(self, embedding_mode: bool):
        if not embedding_mode:
            logger.warning(
                "embedding_mode is False. Embedding API will not work.")
        else:
            logger.info("Activating the server engine with embedding enabled.")
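
# Usage sketch (not part of this module): a minimal client-side example of the
# /v1/embeddings route that OpenAIServingEmbedding handles, assuming a vLLM
# OpenAI-compatible server is already running on localhost:8000 with an
# embedding model loaded. The model name and server address are placeholders;
# the snippet uses the official `openai` Python client.
#
#     from openai import OpenAI
#
#     client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
#     resp = client.embeddings.create(
#         model="intfloat/e5-mistral-7b-instruct",
#         input=["Hello world", "vLLM can serve embeddings too"],
#         encoding_format="float",
#     )
#     print(len(resp.data), len(resp.data[0].embedding))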