import asyncio
from typing import List, Union

from lagent.llms.base_llm import AsyncBaseLLM, BaseLLM
from lagent.utils.util import filter_suffix


def asdict_completion(output):
    """Convert a vLLM ``CompletionOutput`` into a plain dictionary."""
    return {
        key: getattr(output, key)
        for key in [
            'text', 'token_ids', 'cumulative_logprob', 'logprobs',
            'finish_reason', 'stop_reason'
        ]
    }
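

# Illustrative sketch (not part of the module): the dict returned by
# asdict_completion mirrors the attributes of vLLM's CompletionOutput, e.g.
#
#     {
#         'text': ' Paris.',
#         'token_ids': [3681, 13],
#         'cumulative_logprob': -1.3,
#         'logprobs': None,
#         'finish_reason': 'stop',
#         'stop_reason': None,
#     }
#
# The concrete values above are invented placeholders.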


class VllmModel(BaseLLM):
    """A wrapper around a vLLM model.

    Args:
        path (str): The path to the model. It could be one of the following
            options:

            - i) A local directory path of a huggingface model.
            - ii) The model_id of a model hosted inside a model repo
              on huggingface.co, such as "internlm/internlm-chat-7b",
              "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
              and so on.
        tp (int): tensor parallel size. Defaults to 1.
        vllm_cfg (dict): Other kwargs for vLLM model initialization.
    """

    def __init__(self, path: str, tp: int = 1, vllm_cfg=dict(), **kwargs):
        super().__init__(path=path, **kwargs)
        from vllm import LLM
        self.model = LLM(
            model=self.path,
            trust_remote_code=True,
            tensor_parallel_size=tp,
            **vllm_cfg)
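
    # Usage sketch (illustrative only). The model id and the extra vLLM
    # option below are assumptions chosen for the example; vllm_cfg is
    # forwarded verbatim to vllm.LLM(...):
    #
    #     llm = VllmModel(
    #         path='internlm/internlm-chat-7b',  # HF repo id or local dir
    #         tp=1,
    #         vllm_cfg=dict(dtype='float16'),
    #     )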

    def generate(self,
                 inputs: Union[str, List[str]],
                 do_preprocess: bool = None,
                 skip_special_tokens: bool = False,
                 return_dict: bool = False,
                 **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_preprocess (bool): whether to pre-process the messages.
                Defaults to True, which means chat_template will be applied.
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.
            return_dict (bool): Whether to return the full completion
                attributes as a dict rather than only the generated text.
                Defaults to False.

        Returns:
            (a list of/batched) text/chat completion
        """
        from vllm import SamplingParams

        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False
        prompt = inputs
        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        stop_words = gen_params.pop('stop_words')

        sampling_config = SamplingParams(
            skip_special_tokens=skip_special_tokens,
            max_tokens=max_new_tokens,
            stop=stop_words,
            **gen_params)
        response = self.model.generate(prompt, sampling_params=sampling_config)
        texts = [resp.outputs[0].text for resp in response]
        # strip trailing stop words from the generated texts
        texts = filter_suffix(texts, stop_words)
        for resp, text in zip(response, texts):
            resp.outputs[0].text = text
        if batched:
            return [asdict_completion(resp.outputs[0])
                    for resp in response] if return_dict else texts
        return asdict_completion(
            response[0].outputs[0]) if return_dict else texts[0]
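

# Usage sketch (illustrative only), assuming an `llm` instance created as in
# the sketch above; the sampling kwargs are examples merged into
# SamplingParams through update_gen_params:
#
#     texts = llm.generate(
#         ['Hello, my name is', 'The capital of France is'],
#         skip_special_tokens=True,
#         max_new_tokens=64,
#         temperature=0.7,
#     )
#     detailed = llm.generate('Hello', return_dict=True)
#     print(detailed['finish_reason'])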


class AsyncVllmModel(AsyncBaseLLM):
    """An asynchronous wrapper around a vLLM model.

    Args:
        path (str): The path to the model. It could be one of the following
            options:

            - i) A local directory path of a huggingface model.
            - ii) The model_id of a model hosted inside a model repo
              on huggingface.co, such as "internlm/internlm-chat-7b",
              "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
              and so on.
        tp (int): tensor parallel size. Defaults to 1.
        vllm_cfg (dict): Other kwargs for vLLM async engine initialization.
    """

    def __init__(self, path: str, tp: int = 1, vllm_cfg=dict(), **kwargs):
        super().__init__(path=path, **kwargs)
        from vllm import AsyncEngineArgs, AsyncLLMEngine
        engine_args = AsyncEngineArgs(
            model=self.path,
            trust_remote_code=True,
            tensor_parallel_size=tp,
            **vllm_cfg)
        self.model = AsyncLLMEngine.from_engine_args(engine_args)
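
    # Usage sketch (illustrative only). The extra engine option below is an
    # assumption for the example; vllm_cfg is forwarded verbatim to
    # vllm.AsyncEngineArgs(...):
    #
    #     allm = AsyncVllmModel(
    #         path='internlm/internlm-chat-7b',
    #         tp=1,
    #         vllm_cfg=dict(gpu_memory_utilization=0.8),
    #     )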

    async def generate(self,
                       inputs: Union[str, List[str]],
                       session_ids: Union[int, List[int]] = None,
                       do_preprocess: bool = None,
                       skip_special_tokens: bool = False,
                       return_dict: bool = False,
                       **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            session_ids (Union[int, List[int]]): request ids passed to the
                engine, one per input. Defaults to None, in which case the
                input indices are used.
            do_preprocess (bool): whether to pre-process the messages.
                Defaults to True, which means chat_template will be applied.
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.
            return_dict (bool): Whether to return the full completion
                attributes as a dict rather than only the generated text.
                Defaults to False.

        Returns:
            (a list of/batched) text/chat completion
        """
        from vllm import SamplingParams

        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False
        if session_ids is None:
            session_ids = list(range(len(inputs)))
        elif isinstance(session_ids, (int, str)):
            session_ids = [session_ids]
        assert len(inputs) == len(session_ids)
        prompt = inputs
        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        stop_words = gen_params.pop('stop_words')

        sampling_config = SamplingParams(
            skip_special_tokens=skip_special_tokens,
            max_tokens=max_new_tokens,
            stop=stop_words,
            **gen_params)

        async def _inner_generate(uid, text):
            # consume the streaming generator and keep the final output
            resp, generator = '', self.model.generate(
                text, sampling_params=sampling_config, request_id=str(uid))
            async for out in generator:
                resp = out.outputs[0]
            return resp

        response = await asyncio.gather(*[
            _inner_generate(sid, inp) for sid, inp in zip(session_ids, prompt)
        ])
        texts = [resp.text for resp in response]
        # strip trailing stop words from the generated texts
        texts = filter_suffix(texts, stop_words)
        for resp, text in zip(response, texts):
            resp.text = text
        if batched:
            return [asdict_completion(resp)
                    for resp in response] if return_dict else texts
        return asdict_completion(response[0]) if return_dict else texts[0]
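

# Usage sketch (illustrative only), assuming an `allm` instance created as in
# the sketch above; session_ids should be unique among in-flight requests:
#
#     async def main():
#         texts = await allm.generate(
#             ['Hello, my name is', 'The capital of France is'],
#             session_ids=[0, 1],
#             max_new_tokens=64,
#         )
#         print(texts)
#
#     asyncio.run(main())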