import time
from typing import AsyncGenerator, AsyncIterator

from fastapi import Request

from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.outputs import RequestOutput

from protocol import (
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    LogProbs,
    UsageInfo,
)
from serving_engine import OpenAIServing

logger = init_logger(__name__)


async def completion_stream_generator(
        request: CompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        echo_without_generation, create_logprobs_fn, request_id, created_time,
        model_name) -> AsyncGenerator[str, None]:
    previous_texts = [""] * request.n
    previous_num_tokens = [0] * request.n
    has_echoed = [False] * request.n

    async for res in result_generator:
        # TODO: handle client disconnect for streaming
        for output in res.outputs:
            i = output.index
            delta_text = output.text[len(previous_texts[i]):]
            token_ids = output.token_ids[previous_num_tokens[i]:]
            if request.logprobs is not None:
                top_logprobs = output.logprobs[previous_num_tokens[i]:]
            else:
                top_logprobs = None
            offsets = len(previous_texts[i])
            if request.echo and not has_echoed[i]:
                if not echo_without_generation:
                    delta_text = res.prompt + delta_text
                    token_ids = res.prompt_token_ids + token_ids
                    if top_logprobs:
                        top_logprobs = res.prompt_logprobs + top_logprobs
                else:  # only just return the prompt
                    delta_text = res.prompt
                    token_ids = res.prompt_token_ids
                    if top_logprobs:
                        top_logprobs = res.prompt_logprobs
                has_echoed[i] = True
            if request.logprobs is not None:
                logprobs = create_logprobs_fn(
                    token_ids=token_ids,
                    top_logprobs=top_logprobs,
                    num_output_top_logprobs=request.logprobs,
                    initial_text_offset=offsets,
                )
            else:
                logprobs = None
            previous_texts[i] = output.text
            previous_num_tokens[i] = len(output.token_ids)
            finish_reason = output.finish_reason
            response_json = CompletionStreamResponse(
                id=request_id,
                created=created_time,
                model=model_name,
                choices=[
                    CompletionResponseStreamChoice(
                        index=i,
                        text=delta_text,
                        logprobs=logprobs,
                        finish_reason=finish_reason,
                    )
                ]).model_dump_json(exclude_unset=True)
            yield f"data: {response_json}\n\n"

            if output.finish_reason is not None:
                logprobs = LogProbs() if request.logprobs is not None else None
                prompt_tokens = len(res.prompt_token_ids)
                completion_tokens = len(output.token_ids)
                final_usage = UsageInfo(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=prompt_tokens + completion_tokens,
                )
                response_json = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[
                        CompletionResponseStreamChoice(
                            index=i,
                            text="",
                            logprobs=logprobs,
                            finish_reason=output.finish_reason,
                        )
                    ],
                    usage=final_usage,
                ).model_dump_json(exclude_unset=True)
                yield f"data: {response_json}\n\n"

    yield "data: [DONE]\n\n"


def parse_prompt_format(prompt) -> tuple[bool, list]:
    # get the prompt, openai supports the following
    # "a string, array of strings, array of tokens, or array of token arrays."
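    # Illustrative mapping (added for clarity, not from the original source):
    #   "Hello"            -> (False, ["Hello"])
    #   ["Hi", "Bye"]      -> (False, ["Hi", "Bye"])
    #   [1, 2, 3]          -> (True,  [[1, 2, 3]])
    #   [[1, 2], [3, 4]]   -> (True,  [[1, 2], [3, 4]])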
    prompt_is_tokens = False
    prompts = [prompt]  # case 1: a string
    if isinstance(prompt, list):
        if len(prompt) == 0:
            raise ValueError("please provide at least one prompt")
        elif isinstance(prompt[0], str):
            prompt_is_tokens = False
            prompts = prompt  # case 2: array of strings
        elif isinstance(prompt[0], int):
            prompt_is_tokens = True
            prompts = [prompt]  # case 3: array of tokens
        elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
            prompt_is_tokens = True
            prompts = prompt  # case 4: array of token arrays
        else:
            raise ValueError("prompt must be a string, array of strings, "
                             "array of tokens, or array of token arrays")

    return prompt_is_tokens, prompts


def request_output_to_completion_response(final_res: RequestOutput, request,
                                          echo_without_generation,
                                          create_logprobs_fn, request_id,
                                          created_time,
                                          model_name) -> CompletionResponse:
    assert final_res is not None
    choices = []
    prompt_token_ids = final_res.prompt_token_ids
    prompt_logprobs = final_res.prompt_logprobs
    prompt_text = final_res.prompt
    for output in final_res.outputs:
        if request.logprobs is not None:
            if not echo_without_generation:
                token_ids = output.token_ids
                top_logprobs = output.logprobs
                if request.echo:
                    token_ids = prompt_token_ids + token_ids
                    top_logprobs = prompt_logprobs + top_logprobs
            else:
                token_ids = prompt_token_ids
                top_logprobs = prompt_logprobs
            logprobs = create_logprobs_fn(
                token_ids=token_ids,
                top_logprobs=top_logprobs,
                num_output_top_logprobs=request.logprobs,
            )
        else:
            logprobs = None
        if not echo_without_generation:
            output_text = output.text
            if request.echo:
                output_text = prompt_text + output_text
        else:
            output_text = prompt_text
        choice_data = CompletionResponseChoice(
            index=output.index,
            text=output_text,
            logprobs=logprobs,
            finish_reason=output.finish_reason,
        )
        choices.append(choice_data)

    num_prompt_tokens = len(final_res.prompt_token_ids)
    num_generated_tokens = sum(
        len(output.token_ids) for output in final_res.outputs)
    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
    )

    return CompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=usage,
    )


class OpenAIServingCompletion(OpenAIServing):

    def __init__(self, engine: AsyncLLMEngine, served_model: str):
        super().__init__(engine=engine, served_model=served_model)

    async def create_completion(self, request: CompletionRequest,
                                raw_request: Request):
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following features:
            - suffix (the language models we currently support do not support
              suffix)
            - logit_bias (to be supported by vLLM engine)
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # OpenAI API supports echoing the prompt when max_tokens is 0.
        echo_without_generation = request.echo and request.max_tokens == 0

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response(
                "suffix is not currently supported")

        if request.logit_bias is not None and len(request.logit_bias) > 0:
            return self.create_error_response(
                "logit_bias is not currently supported")

        model_name = request.model
        request_id = f"cmpl-{random_uuid()}"
        # The OpenAI "created" field is a Unix timestamp, so use wall-clock
        # time rather than time.monotonic().
        created_time = int(time.time())

        # Schedule the request and get the result generator.
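        # ``engine.generate`` yields a stream of ``RequestOutput`` snapshots
        # for this request id; the streaming path below forwards incremental
        # deltas from it, while the non-streaming path drains it and converts
        # only the final snapshot into a ``CompletionResponse``.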
        try:
            sampling_params = request.to_sampling_params()
            prompt_is_tokens, prompts = parse_prompt_format(request.prompt)

            if len(prompts) > 1:
                raise ValueError(
                    "Batching in completion API is not supported.")
            prompt = prompts[0]

            if prompt_is_tokens:
                input_ids = self._validate_prompt_and_tokenize(
                    request, prompt_ids=prompt)
            else:
                input_ids = self._validate_prompt_and_tokenize(request,
                                                               prompt=prompt)

            result_generator = self.engine.generate(
                None, sampling_params, request_id, prompt_token_ids=input_ids)
        except ValueError as e:
            return self.create_error_response(str(e))

        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. In addition, we do not stream the results when using beam
        # search.
        stream = (request.stream
                  and (request.best_of is None or request.n == request.best_of)
                  and not request.use_beam_search)

        # Streaming response
        if stream:
            return completion_stream_generator(request, result_generator,
                                               echo_without_generation,
                                               self._create_logprobs,
                                               request_id, created_time,
                                               model_name)

        # Non-streaming response
        final_res: RequestOutput = None
        async for res in result_generator:
            if await raw_request.is_disconnected():
                # Abort the request if the client disconnects.
                await self.engine.abort(request_id)
                return self.create_error_response("Client disconnected")
            final_res = res
        response = request_output_to_completion_response(
            final_res, request, echo_without_generation,
            self._create_logprobs, request_id, created_time, model_name)

        # When the user requests streaming but we don't stream, we still need
        # to return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response
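

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration, not part of the serving logic above).
# It shows one way an application *might* mount ``OpenAIServingCompletion``
# on a FastAPI app. The model name, route path, and ``build_app`` helper are
# assumptions for this example, and error responses are returned the same way
# as regular responses for brevity.
# ---------------------------------------------------------------------------
def build_app(model: str = "facebook/opt-125m"):
    from fastapi import FastAPI
    from fastapi.responses import JSONResponse, StreamingResponse
    from vllm.engine.arg_utils import AsyncEngineArgs

    engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model=model))
    serving = OpenAIServingCompletion(engine, served_model=model)
    app = FastAPI()

    @app.post("/v1/completions")
    async def completions(request: CompletionRequest, raw_request: Request):
        result = await serving.create_completion(request, raw_request)
        # Streaming requests come back as an async generator of SSE lines;
        # everything else is a pydantic model that can be serialized directly.
        if request.stream:
            return StreamingResponse(result, media_type="text/event-stream")
        return JSONResponse(content=result.model_dump())

    return app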