import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
                    Optional)
from typing import Sequence as GenericSequence
from typing import Tuple, cast

from fastapi import Request
from transformers import PreTrainedTokenizer

from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                              CompletionRequest,
                                              CompletionResponse,
                                              CompletionResponseChoice,
                                              CompletionResponseStreamChoice,
                                              CompletionStreamResponse,
                                              UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                    OpenAIServing,
                                                    PromptAdapterPath)
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
from vllm.utils import merge_async_iterators, random_uuid

logger = init_logger(__name__)

TypeTokenIDs = List[int]
TypeTopLogProbs = List[Optional[Dict[int, float]]]
TypeCreateLogProbsFn = Callable[
    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]


class OpenAIServingCompletion(OpenAIServing):

    def __init__(
        self,
        engine: AsyncLLMEngine,
        model_config: ModelConfig,
        served_model_names: List[str],
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        prompt_adapters: Optional[List[PromptAdapterPath]],
        request_logger: Optional[RequestLogger],
    ):
        super().__init__(engine=engine,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=lora_modules,
                         prompt_adapters=prompt_adapters,
                         request_logger=request_logger)

    async def create_completion(self, request: CompletionRequest,
                                raw_request: Request):
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
              suffix)
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response(
                "suffix is not currently supported")

        model_name = self.served_model_names[0]
        request_id = f"cmpl-{random_uuid()}"
        created_time = int(time.time())

        # Schedule the request and get the result generator.
        generators: List[AsyncIterator[RequestOutput]] = []
        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine.get_tokenizer(lora_request)

            sampling_params = request.to_sampling_params()
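            # If the request constrains generation (e.g. a JSON schema, regex
            # or choice list), attach the corresponding guided-decoding logits
            # processor; the backend falls back to the engine-wide decoding
            # config when the request does not name one.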
            decoding_config = await self.engine.get_decoding_config()
            guided_decoding_backend = request.guided_decoding_backend \
                or decoding_config.guided_decoding_backend
            guided_decode_logit_processor = (
                await
                get_guided_decoding_logits_processor(guided_decoding_backend,
                                                     request, tokenizer))
            if guided_decode_logit_processor is not None:
                if sampling_params.logits_processors is None:
                    sampling_params.logits_processors = []
                sampling_params.logits_processors.append(
                    guided_decode_logit_processor)

            prompts = list(
                self._tokenize_prompt_input_or_inputs(
                    request,
                    tokenizer,
                    request.prompt,
                    truncate_prompt_tokens=sampling_params.
                    truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                ))
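            # Fan out one engine request per prompt; each sub-request gets a
            # suffixed id ("{request_id}-{i}") so it can be logged and
            # aborted individually.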
            for i, prompt_inputs in enumerate(prompts):
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 prompt_inputs,
                                 params=sampling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                is_tracing_enabled = await self.engine.is_tracing_enabled()
                trace_headers = None
                if is_tracing_enabled:
                    trace_headers = extract_trace_headers(raw_request.headers)
                if not is_tracing_enabled and contains_trace_headers(
                        raw_request.headers):
                    log_tracing_disabled_warning()

                generator = self.engine.generate(
                    {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
                    sampling_params,
                    request_id_item,
                    lora_request=lora_request,
                    prompt_adapter_request=prompt_adapter_request,
                    trace_headers=trace_headers,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

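        # merge_async_iterators interleaves all per-prompt generators into a
        # single stream of (prompt index, RequestOutput) pairs, so outputs can
        # be consumed as soon as any prompt makes progress.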
        result_generator: AsyncIterator[Tuple[
            int, RequestOutput]] = merge_async_iterators(*generators)

        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. In addition, we do not stream the results when beam search
        # is used.
        stream = (request.stream
                  and (request.best_of is None or request.n == request.best_of)
                  and not request.use_beam_search)

        # Streaming response
        if stream:
            return self.completion_stream_generator(request,
                                                    raw_request,
                                                    result_generator,
                                                    request_id,
                                                    created_time,
                                                    model_name,
                                                    num_prompts=len(prompts),
                                                    tokenizer=tokenizer)

        # Non-streaming response
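        # Outputs may complete in any order across prompts, so pre-allocate
        # one slot per prompt and fill each by index as results arrive.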
        final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
        try:
            async for i, res in result_generator:
                if await raw_request.is_disconnected():
                    # Abort the request if the client disconnects.
                    await self.engine.abort(f"{request_id}-{i}")
                    return self.create_error_response("Client disconnected")
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text. We did not pass it
                # into the vLLM engine to avoid being redundant with the input
                # token IDs.
                if final_res.prompt is None:
                    final_res.prompt = prompts[i]["prompt"]

            final_res_batch_checked = cast(List[RequestOutput],
                                           final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
            )
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # When the user requests streaming but we don't stream, we still need
        # to return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response

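    # Streaming responses below use the OpenAI server-sent-events framing:
    # each chunk is emitted as a "data: <json>\n\n" event and the stream is
    # terminated with "data: [DONE]\n\n".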
    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        raw_request: Request,
        result_generator: AsyncIterator[Tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: PreTrainedTokenizer,
    ) -> AsyncGenerator[str, None]:
        num_choices = 1 if request.n is None else request.n
        previous_texts = [""] * num_choices * num_prompts
        previous_num_tokens = [0] * num_choices * num_prompts
        has_echoed = [False] * num_choices * num_prompts
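        # Per-choice state is kept in flat lists indexed by
        # i = output.index + prompt_idx * num_choices, covering every
        # (prompt, choice) pair across the merged result stream.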
        try:
            async for prompt_idx, res in result_generator:

                # Abort the request if the client disconnects.
                if await raw_request.is_disconnected():
                    await self.engine.abort(f"{request_id}-{prompt_idx}")
                    raise StopAsyncIteration()

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices
                    # TODO(simon): optimize the performance by avoiding full
                    # text O(n^2) sending.

                    assert request.max_tokens is not None
                    if request.echo and request.max_tokens == 0:
                        # only return the prompt
                        delta_text = res.prompt
                        delta_token_ids = res.prompt_token_ids
                        out_logprobs = res.prompt_logprobs
                        has_echoed[i] = True
                    elif (request.echo and request.max_tokens > 0
                          and not has_echoed[i]):
                        # echo the prompt and first token
                        delta_text = res.prompt + output.text
                        delta_token_ids = (res.prompt_token_ids +
                                           output.token_ids)
                        out_logprobs = res.prompt_logprobs + (output.logprobs
                                                              or [])
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text[len(previous_texts[i]):]
                        delta_token_ids = output.token_ids[
                            previous_num_tokens[i]:]
                        out_logprobs = output.logprobs[previous_num_tokens[
                            i]:] if output.logprobs else None

                    if request.logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            initial_text_offset=len(previous_texts[i]),
                        )
                    else:
                        logprobs = None

                    previous_texts[i] = output.text
                    previous_num_tokens[i] = len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                            )
                        ])
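                    # Usage accounting: with stream_options.include_usage set,
                    # per-chunk usage is attached only when
                    # continuous_usage_stats is enabled; otherwise a single
                    # usage-only chunk is emitted after the final choice
                    # (see the block after this loop).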
                    if (request.stream_options
                            and request.stream_options.include_usage):
                        if (request.stream_options.continuous_usage_stats
                                or output.finish_reason is not None):
                            prompt_tokens = len(res.prompt_token_ids)
                            completion_tokens = len(output.token_ids)
                            usage = UsageInfo(
                                prompt_tokens=prompt_tokens,
                                completion_tokens=completion_tokens,
                                total_tokens=prompt_tokens + completion_tokens,
                            )
                        if request.stream_options.continuous_usage_stats:
                            chunk.usage = usage
                        else:
                            chunk.usage = None

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            if (request.stream_options
                    and request.stream_options.include_usage):
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=usage,
                )
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: List[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: PreTrainedTokenizer,
    ) -> CompletionResponse:
        choices: List[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0

        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            prompt_logprobs = final_res.prompt_logprobs
            prompt_text = final_res.prompt

            for output in final_res.outputs:
                assert request.max_tokens is not None
                if request.echo and request.max_tokens == 0:
                    token_ids = prompt_token_ids
                    out_logprobs = prompt_logprobs
                    output_text = prompt_text
                elif request.echo and request.max_tokens > 0:
                    token_ids = prompt_token_ids + list(output.token_ids)
                    out_logprobs = (prompt_logprobs + output.logprobs
                                    if request.logprobs is not None else None)
                    output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                )
                choices.append(choice_data)

            num_prompt_tokens += len(prompt_token_ids)
            num_generated_tokens += sum(
                len(output.token_ids) for output in final_res.outputs)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
        )

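    # CompletionLogProbs mirrors the OpenAI logprobs layout: four parallel
    # lists (text_offset, token_logprobs, tokens, top_logprobs), one entry
    # per returned token.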
    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        num_output_top_logprobs: int,
        tokenizer: PreTrainedTokenizer,
        initial_text_offset: int = 0,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API."""
        out_text_offset: List[int] = []
        out_token_logprobs: List[Optional[float]] = []
        out_tokens: List[str] = []
        out_top_logprobs: List[Optional[Dict[str, float]]] = []

        last_token_len = 0
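        # text_offset[i] is the character offset of token i within the
        # returned text; initial_text_offset shifts it when streamed deltas
        # are appended to text that was already sent.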
        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                token = tokenizer.decode(token_id)
                out_tokens.append(token)
                out_token_logprobs.append(None)
                out_top_logprobs.append(None)
            else:
                token = self._get_decoded_token(step_top_logprobs[token_id],
                                                token_id, tokenizer)
                token_logprob = max(step_top_logprobs[token_id].logprob,
                                    -9999.0)
                out_tokens.append(token)
                out_token_logprobs.append(token_logprob)

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                out_top_logprobs.append({
                    # Convert float("-inf") to the
                    # JSON-serializable float that OpenAI uses
                    self._get_decoded_token(top_lp[1], top_lp[0], tokenizer):
                    max(top_lp[1].logprob, -9999.0)
                    for i, top_lp in enumerate(step_top_logprobs.items())
                    if num_output_top_logprobs >= i
                })

            if len(out_text_offset) == 0:
                out_text_offset.append(initial_text_offset)
            else:
                out_text_offset.append(out_text_offset[-1] + last_token_len)
            last_token_len = len(token)

        return CompletionLogProbs(
            text_offset=out_text_offset,
            token_logprobs=out_token_logprobs,
            tokens=out_tokens,
            top_logprobs=out_top_logprobs,
        )