# lagent/llms/vllm_wrapper.py
import asyncio
from typing import List, Union
from lagent.llms.base_llm import AsyncBaseLLM, BaseLLM
from lagent.utils.util import filter_suffix


def asdict_completion(output):
    return {
        key: getattr(output, key)
        for key in [
            'text', 'token_ids', 'cumulative_logprob', 'logprobs',
            'finish_reason', 'stop_reason'
        ]
    }
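
# For reference, a call like ``asdict_completion(resp.outputs[0])`` yields a
# plain dict keyed by the fields above; the example values below are
# illustrative only (actual contents come from vLLM's CompletionOutput):
#
#   {
#       'text': 'Hello! How can I help you today?',
#       'token_ids': [9906, 0, ...],
#       'cumulative_logprob': -3.21,
#       'logprobs': None,
#       'finish_reason': 'stop',
#       'stop_reason': None,
#   }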


class VllmModel(BaseLLM):
    """
    A wrapper of vLLM model.

    Args:
        path (str): The path to the model.
            It could be one of the following options:

            - i) A local directory path of a huggingface model.
            - ii) The model_id of a model hosted inside a model repo
              on huggingface.co, such as "internlm/internlm-chat-7b",
              "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
              and so on.
        tp (int): tensor parallel size. Defaults to 1.
        vllm_cfg (dict): Other kwargs for vllm model initialization.
    """

    def __init__(self, path: str, tp: int = 1, vllm_cfg=dict(), **kwargs):
        super().__init__(path=path, **kwargs)
        from vllm import LLM

        self.model = LLM(
            model=self.path,
            trust_remote_code=True,
            tensor_parallel_size=tp,
            **vllm_cfg)

    def generate(self,
                 inputs: Union[str, List[str]],
                 do_preprocess: bool = None,
                 skip_special_tokens: bool = False,
                 return_dict: bool = False,
                 **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_preprocess (bool): whether to pre-process the messages.
                Defaults to True, which means chat_template will be applied.
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.
            return_dict (bool): Whether to return the full completion dicts
                built by ``asdict_completion`` instead of plain texts.
                Defaults to False.

        Returns:
            (a list of/batched) text/chat completion
        """
        from vllm import SamplingParams

        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False
        prompt = inputs
        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        stop_words = gen_params.pop('stop_words')

        sampling_config = SamplingParams(
            skip_special_tokens=skip_special_tokens,
            max_tokens=max_new_tokens,
            stop=stop_words,
            **gen_params)
        response = self.model.generate(prompt, sampling_params=sampling_config)
        texts = [resp.outputs[0].text for resp in response]
        # remove stop_words
        texts = filter_suffix(texts, self.gen_params.get('stop_words'))
        for resp, text in zip(response, texts):
            resp.outputs[0].text = text
        if batched:
            return [asdict_completion(resp.outputs[0])
                    for resp in response] if return_dict else texts
        return asdict_completion(
            response[0].outputs[0]) if return_dict else texts[0]
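
# A minimal synchronous usage sketch (the model path, stop words, and
# generation parameters below are illustrative assumptions, not part of this
# module; any gen_params accepted by BaseLLM/SamplingParams may be passed):
#
#   model = VllmModel(
#       path='internlm/internlm2-chat-7b',
#       tp=1,
#       vllm_cfg=dict(gpu_memory_utilization=0.9),
#       max_new_tokens=256,
#       stop_words=['<|im_end|>'],
#   )
#   print(model.generate('Hello, who are you?'))
#   # Batched call returning full completion dicts:
#   outs = model.generate(['Hi', 'What is vLLM?'], return_dict=True)
#   print([o['text'] for o in outs])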


class AsyncVllmModel(AsyncBaseLLM):
    """
    An asynchronous wrapper of vLLM model.

    Args:
        path (str): The path to the model.
            It could be one of the following options:

            - i) A local directory path of a huggingface model.
            - ii) The model_id of a model hosted inside a model repo
              on huggingface.co, such as "internlm/internlm-chat-7b",
              "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
              and so on.
        tp (int): tensor parallel size. Defaults to 1.
        vllm_cfg (dict): Other kwargs for vllm model initialization.
    """

    def __init__(self, path: str, tp: int = 1, vllm_cfg=dict(), **kwargs):
        super().__init__(path=path, **kwargs)
        from vllm import AsyncEngineArgs, AsyncLLMEngine

        engine_args = AsyncEngineArgs(
            model=self.path,
            trust_remote_code=True,
            tensor_parallel_size=tp,
            **vllm_cfg)
        self.model = AsyncLLMEngine.from_engine_args(engine_args)

    async def generate(self,
                       inputs: Union[str, List[str]],
                       session_ids: Union[int, List[int]] = None,
                       do_preprocess: bool = None,
                       skip_special_tokens: bool = False,
                       return_dict: bool = False,
                       **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            session_ids (Union[int, List[int]]): request ids passed to the
                async engine, one per input. Defaults to
                ``range(len(inputs))``.
            do_preprocess (bool): whether to pre-process the messages.
                Defaults to True, which means chat_template will be applied.
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.
            return_dict (bool): Whether to return the full completion dicts
                built by ``asdict_completion`` instead of plain texts.
                Defaults to False.

        Returns:
            (a list of/batched) text/chat completion
        """
        from vllm import SamplingParams

        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False
        if session_ids is None:
            session_ids = list(range(len(inputs)))
        elif isinstance(session_ids, (int, str)):
            session_ids = [session_ids]
        assert len(inputs) == len(session_ids)

        prompt = inputs
        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        stop_words = gen_params.pop('stop_words')

        sampling_config = SamplingParams(
            skip_special_tokens=skip_special_tokens,
            max_tokens=max_new_tokens,
            stop=stop_words,
            **gen_params)

        async def _inner_generate(uid, text):
            # Drain the streaming generator and keep only the final output.
            resp, generator = '', self.model.generate(
                text, sampling_params=sampling_config, request_id=uid)
            async for out in generator:
                resp = out.outputs[0]
            return resp

        response = await asyncio.gather(*[
            _inner_generate(sid, inp) for sid, inp in zip(session_ids, prompt)
        ])
        texts = [resp.text for resp in response]
        # remove stop_words
        texts = filter_suffix(texts, self.gen_params.get('stop_words'))
        for resp, text in zip(response, texts):
            resp.text = text
        if batched:
            return [asdict_completion(resp)
                    for resp in response] if return_dict else texts
        return asdict_completion(response[0]) if return_dict else texts[0]
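
# A minimal asynchronous usage sketch (again with an illustrative model path;
# the wrapper itself only requires a running asyncio event loop):
#
#   async def main():
#       model = AsyncVllmModel(path='internlm/internlm2-chat-7b', tp=1)
#       outs = await model.generate(
#           ['Hi', 'What is vLLM?'], session_ids=[0, 1], return_dict=True)
#       print([o['text'] for o in outs])
#
#   asyncio.run(main())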