import asyncio
import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence

from ..extras.misc import torch_gc
from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine
from .vllm_engine import VllmEngine


if TYPE_CHECKING:
    from ..data.mm_plugin import ImageInput, VideoInput
    from .base_engine import BaseEngine, Response


def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
    r"""
    Runs an event loop forever in a background thread, so that the sync methods of
    `ChatModel` can submit coroutines to it via `asyncio.run_coroutine_threadsafe`.
    """
    asyncio.set_event_loop(loop)
    loop.run_forever()


class ChatModel:
    r"""
    General class for chat models. Backed by the HuggingFace or vLLM engine.

    Supports both sync and async methods.
    Sync methods: chat(), stream_chat() and get_scores().
    Async methods: achat(), astream_chat() and aget_scores().
    A short usage sketch follows the class definition.
    """

    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
        model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
        self.engine_type = model_args.infer_backend
        if model_args.infer_backend == "huggingface":
            self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
        elif model_args.infer_backend == "vllm":
            self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
        else:
            raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")

        # Run a dedicated event loop in a daemon thread so the sync methods can
        # block on coroutines scheduled with `asyncio.run_coroutine_threadsafe`.
        self._loop = asyncio.new_event_loop()
        self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
        self._thread.start()

    def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["ImageInput"] = None,
        video: Optional["VideoInput"] = None,
        **input_kwargs,
    ) -> List["Response"]:
        r"""
        Gets a list of responses from the chat model.
        """
        task = asyncio.run_coroutine_threadsafe(
            self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
        )
        return task.result()

    async def achat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["ImageInput"] = None,
        video: Optional["VideoInput"] = None,
        **input_kwargs,
    ) -> List["Response"]:
        r"""
        Asynchronously gets a list of responses from the chat model.
        """
        return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)

    def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["ImageInput"] = None,
        video: Optional["VideoInput"] = None,
        **input_kwargs,
    ) -> Generator[str, None, None]:
        r"""
        Gets the response token by token from the chat model.
        """
        generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
        while True:
            try:
                # Pull the next token out of the async generator running on the background loop.
                task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
                yield task.result()
            except StopAsyncIteration:
                break

    async def astream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["ImageInput"] = None,
        video: Optional["VideoInput"] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        r"""
        Asynchronously gets the response token by token from the chat model.
        """
        async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
            yield new_token

    def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        r"""
        Gets a list of scores from the reward model.
        """
        task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
        return task.result()

    async def aget_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        r"""
        Asynchronously gets a list of scores from the reward model.
        """
        return await self.engine.get_scores(batch_input, **input_kwargs)
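

# A minimal usage sketch for `ChatModel` (assumptions: the argument keys below,
# e.g. "model_name_or_path" and "template", follow the inference arguments parsed
# by `get_infer_args`, and `Response.response_text` holds the generated text):
#
#     chat_model = ChatModel({"model_name_or_path": "path/to/model", "template": "default"})
#     messages = [{"role": "user", "content": "Hello!"}]
#     print(chat_model.chat(messages)[0].response_text)  # sync, full responses
#     for token in chat_model.stream_chat(messages):  # sync, streamed tokens
#         print(token, end="", flush=True)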


def run_chat() -> None:
    if os.name != "nt":
        try:
            import readline  # noqa: F401  # enables line editing and input history
        except ImportError:
            print("Install `readline` for a better experience.")

    chat_model = ChatModel()
    messages: List[Dict[str, str]] = []
    print("Welcome to the CLI application. Use `clear` to remove the history and `exit` to quit the application.")

    while True:
        try:
            query = input("\nUser: ")
        except UnicodeDecodeError:
            print("Detected a decoding error in the input, please set the terminal encoding to UTF-8.")
            continue

        if query.strip() == "exit":
            break

        if query.strip() == "clear":
            messages = []
            torch_gc()  # release cached GPU memory along with the history
            print("History has been removed.")
            continue

        messages.append({"role": "user", "content": query})
        print("Assistant: ", end="", flush=True)

        response = ""
        for new_text in chat_model.stream_chat(messages):
            print(new_text, end="", flush=True)
            response += new_text
        print()
        messages.append({"role": "assistant", "content": response})
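

# A minimal sketch for running this module directly (assumption: in the full
# project, `run_chat` is normally wired up through the package CLI instead).
if __name__ == "__main__":
    run_chat()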