# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
import urllib.request
import uuid
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

import gradio as gr

from xinference.core.api import SyncSupervisorAPI
from xinference.locale.utils import Locale
from xinference.model import MODEL_FAMILIES, ModelSpec

if TYPE_CHECKING:
    from xinference.types import ChatCompletionChunk, ChatCompletionMessage

MODEL_TO_FAMILIES = dict(
    (model_family.model_name, model_family)
    for model_family in MODEL_FAMILIES
    if model_family.model_name != "baichuan"
)


class GradioApp:
    def __init__(
        self,
        supervisor_address: str,
        gladiator_num: int = 2,
        max_model_num: int = 2,
        use_launched_model: bool = False,
    ):
        self._api = SyncSupervisorAPI(supervisor_address)
        self._gladiator_num = gladiator_num
        self._max_model_num = max_model_num
        self._use_launched_model = use_launched_model
        self._locale = Locale()

    def _create_model(
        self,
        model_name: str,
        model_size_in_billions: Optional[int] = None,
        model_format: Optional[str] = None,
        quantization: Optional[str] = None,
    ):
        model_uid = str(uuid.uuid1())
        models = self._api.list_models()
        # Evict the oldest launched model once the cap is reached.
        if len(models) >= self._max_model_num:
            self._api.terminate_model(models[0][0])
        return self._api.launch_model(
            model_uid, model_name, model_size_in_billions, model_format, quantization
        )

    async def generate(
        self,
        model: str,
        message: str,
        chat: List[List[str]],
        max_token: int,
        temperature: float,
        top_p: float,
        window_size: int,
        show_finish_reason: bool,
    ):
        if not message:
            yield message, chat
        else:
            try:
                model_ref = self._api.get_model(model)
            except KeyError:
                raise gr.Error(self._locale("Please create model first"))

            history: "List[ChatCompletionMessage]" = []
            for c in chat:
                history.append({"role": "user", "content": c[0]})
                out = c[1]
                # Strip the stop-reason suffix appended by a previous round, if any.
                finish_reason_idx = out.find(f"[{self._locale('stop reason')}: ")
                if finish_reason_idx != -1:
                    out = out[:finish_reason_idx]
                history.append({"role": "assistant", "content": out})

            if window_size != 0:
                history = history[-(window_size // 2) :]

            # chatglm only supports an even number of conversation history entries.
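            # Dropping the oldest entry keeps the history aligned as complete
            # (user, assistant) pairs after the window truncation above.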
            if len(history) % 2 != 0:
                history = history[1:]

            generate_config = dict(
                max_tokens=max_token,
                temperature=temperature,
                top_p=top_p,
                stream=True,
            )
            chat += [[message, ""]]
            chat_generator = await model_ref.chat(
                message,
                chat_history=history,
                generate_config=generate_config,
            )
            chunk: Optional["ChatCompletionChunk"] = None
            async for chunk in chat_generator:
                assert chunk is not None
                delta = chunk["choices"][0]["delta"]
                if "content" not in delta:
                    continue
                chat[-1][1] += delta["content"]
                yield "", chat
            if show_finish_reason and chunk is not None:
                chat[-1][
                    1
                ] += f"[{self._locale('stop reason')}: {chunk['choices'][0]['finish_reason']}]"
                yield "", chat

    def _build_chatbot(self, model_uid: str, model_name: str):
        with gr.Accordion(self._locale("Parameters"), open=False):
            max_token = gr.Slider(
                128,
                1024,
                value=128,
                step=1,
                label=self._locale("Max tokens"),
                info=self._locale("The maximum number of tokens to generate."),
            )
            temperature = gr.Slider(
                0.2,
                1,
                value=0.8,
                step=0.01,
                label=self._locale("Temperature"),
                info=self._locale("The temperature to use for sampling."),
            )
            top_p = gr.Slider(
                0.2,
                1,
                value=0.95,
                step=0.01,
                label=self._locale("Top P"),
                info=self._locale("The top-p value to use for sampling."),
            )
            window_size = gr.Slider(
                0,
                50,
                value=10,
                step=1,
                label=self._locale("Window size"),
                info=self._locale("Window size of chat history."),
            )
            show_finish_reason = gr.Checkbox(label=self._locale("Show stop reason"))
        chat = gr.Chatbot(label=model_name)
        text = gr.Textbox(visible=False)
        model_uid = gr.Textbox(model_uid, visible=False)
        text.change(
            self.generate,
            [
                model_uid,
                text,
                chat,
                max_token,
                temperature,
                top_p,
                window_size,
                show_finish_reason,
            ],
            [text, chat],
        )
        return (
            text,
            chat,
            max_token,
            temperature,
            top_p,
            show_finish_reason,
            window_size,
            model_uid,
        )

    def _build_chat_column(self):
        with gr.Column():
            with gr.Row():
                model_name = gr.Dropdown(
                    choices=list(MODEL_TO_FAMILIES.keys()),
                    label=self._locale("model name"),
                    scale=2,
                )
                model_format = gr.Dropdown(
                    choices=[],
                    interactive=False,
                    label=self._locale("model format"),
                    scale=2,
                )
                model_size_in_billions = gr.Dropdown(
                    choices=[],
                    interactive=False,
                    label=self._locale("model size in billions"),
                    scale=1,
                )
                quantization = gr.Dropdown(
                    choices=[],
                    interactive=False,
                    label=self._locale("quantization"),
                    scale=1,
                )
            create_model = gr.Button(value=self._locale("create"))

            def select_model_name(model_name: str):
                if model_name:
                    model_family = MODEL_TO_FAMILIES[model_name]
                    formats = [model_family.model_format]
                    model_sizes_in_billions = [
                        str(b) for b in model_family.model_sizes_in_billions
                    ]
                    quantizations = model_family.quantizations
                    return (
                        gr.Dropdown.update(
                            choices=formats,
                            interactive=True,
                            value=model_family.model_format,
                        ),
                        gr.Dropdown.update(
                            choices=model_sizes_in_billions[:1],
                            interactive=True,
                            value=model_sizes_in_billions[0],
                        ),
                        gr.Dropdown.update(
                            choices=quantizations,
                            interactive=True,
                            value=quantizations[0],
                        ),
                    )
                else:
                    return (
                        gr.Dropdown.update(),
                        gr.Dropdown.update(),
                        gr.Dropdown.update(),
                    )

            model_name.change(
                select_model_name,
                inputs=[model_name],
                outputs=[model_format, model_size_in_billions, quantization],
            )

            components = self._build_chatbot("", "")
            model_text = components[0]
            chat, model_uid = components[1], components[-1]

            def select_model(
                _model_name: str,
                _model_format: str,
                _model_size_in_billions: str,
                _quantization: str,
                progress=gr.Progress(),
            ):
                model_family = MODEL_TO_FAMILIES[_model_name]
                cache_path, meta_path = model_family.generate_cache_path(
                    int(_model_size_in_billions), _quantization
                )
                if not (os.path.exists(cache_path) and os.path.exists(meta_path)):
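                    # The meta file is only written after a complete download, so a
                    # cache file without it marks an interrupted download that must
                    # be discarded and fetched again.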
                    if os.path.exists(cache_path):
                        os.remove(cache_path)
                    url = model_family.url_generator(
                        int(_model_size_in_billions), _quantization
                    )
                    full_name = (
                        f"{str(model_family)}-{_model_size_in_billions}b-{_quantization}"
                    )
                    try:
                        urllib.request.urlretrieve(
                            url,
                            cache_path,
                            reporthook=lambda block_num, block_size, total_size: progress(
                                block_num * block_size / total_size,
                                desc=self._locale("Downloading"),
                            ),
                        )
                        # Write a meta file to record that the download finished.
                        with open(meta_path, "w") as f:
                            f.write(full_name)
                    except Exception:
                        # Drop the partial download so the next attempt starts clean.
                        if os.path.exists(cache_path):
                            os.remove(cache_path)

                model_uid = self._create_model(
                    _model_name,
                    int(_model_size_in_billions),
                    _model_format,
                    _quantization,
                )
                return gr.Chatbot.update(
                    label="-".join(
                        [
                            _model_name,
                            _model_size_in_billions,
                            _model_format,
                            _quantization,
                        ]
                    ),
                    value=[],
                ), gr.Textbox.update(value=model_uid)

            def clear_chat(
                _model_name: str,
                _model_format: str,
                _model_size_in_billions: str,
                _quantization: str,
            ):
                full_name = "-".join(
                    [_model_name, _model_size_in_billions, _model_format, _quantization]
                )
                return str(uuid.uuid4()), gr.Chatbot.update(
                    label=full_name,
                    value=[],
                )

            invisible_text = gr.Textbox(visible=False)
            create_model.click(
                clear_chat,
                inputs=[model_name, model_format, model_size_in_billions, quantization],
                outputs=[invisible_text, chat],
            )
            invisible_text.change(
                select_model,
                inputs=[model_name, model_format, model_size_in_billions, quantization],
                outputs=[chat, model_uid],
                postprocess=False,
            )
            return chat, model_text

    def _build_arena(self):
        with gr.Box():
            with gr.Row():
                chat_and_text = [
                    self._build_chat_column() for _ in range(self._gladiator_num)
                ]
                chats = [c[0] for c in chat_and_text]
                texts = [c[1] for c in chat_and_text]

            msg = gr.Textbox(label=self._locale("Input"))

            def update_message(text_in: str):
                return "", text_in, text_in

            msg.submit(update_message, inputs=[msg], outputs=[msg] + texts)
            gr.ClearButton(components=[msg] + chats + texts)

    def _build_single(self):
        chat, model_text = self._build_chat_column()
        msg = gr.Textbox(label=self._locale("Input"))

        def update_message(text_in: str):
            return "", text_in

        msg.submit(update_message, inputs=[msg], outputs=[msg, model_text])
        gr.ClearButton(components=[chat, msg, model_text])

    def _build_single_with_launched(
        self, models: List[Tuple[str, ModelSpec]], default_index: int
    ):
        uid_to_model_spec: Dict[str, ModelSpec] = dict((m[0], m[1]) for m in models)
        choices = [
            "-".join(
                [
                    s.model_name,
                    str(s.model_size_in_billions),
                    s.model_format,
                    s.quantization,
                ]
            )
            for s in uid_to_model_spec.values()
        ]
        choice_to_uid = dict(zip(choices, uid_to_model_spec.keys()))
        model_selection = gr.Dropdown(
            label=self._locale("select model"),
            choices=choices,
            value=choices[default_index],
        )
        components = self._build_chatbot(
            models[default_index][0], choices[default_index]
        )
        model_text = components[0]
        model_uid = components[-1]
        chat = components[1]

        def select_model(model_name):
            uid = choice_to_uid[model_name]
            return gr.Chatbot.update(label=model_name), uid

        model_selection.change(
            select_model, inputs=[model_selection], outputs=[chat, model_uid]
        )
        return chat, model_text

    def _build_arena_with_launched(self, models: List[Tuple[str, ModelSpec]]):
        chat_and_text = []
        with gr.Row():
            for i in range(self._gladiator_num):
                with gr.Column():
                    chat_and_text.append(self._build_single_with_launched(models, i))
        chats = [c[0] for c in chat_and_text]
        texts = [c[1] for c in chat_and_text]

        msg = gr.Textbox(label=self._locale("Input"))

        def update_message(text_in: str):
            return "", text_in, text_in
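        # Fan the shared input box out to every gladiator's hidden textbox.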
        msg.submit(update_message, inputs=[msg], outputs=[msg] + texts)
        gr.ClearButton(components=[msg] + chats + texts)

    def build(self):
        if self._use_launched_model:
            models = self._api.list_models()
            with gr.Blocks() as blocks:
                # The arena needs at least two launched models to compare.
                if len(models) >= 2:
                    with gr.Tab(self._locale("Arena")):
                        self._build_arena_with_launched(models)
                with gr.Tab(self._locale("Chat")):
                    chat, model_text = self._build_single_with_launched(models, 0)
                    msg = gr.Textbox(label=self._locale("Input"))

                    def update_message(text_in: str):
                        return "", text_in

                    msg.submit(
                        update_message, inputs=[msg], outputs=[msg, model_text]
                    )
                    gr.ClearButton(components=[chat, msg, model_text])
        else:
            with gr.Blocks() as blocks:
                with gr.Tab(self._locale("Chat")):
                    self._build_single()
                with gr.Tab(self._locale("Arena")):
                    self._build_arena()
        blocks.queue(concurrency_count=40)
        return blocks


async def launch_xinference():
    import xoscar as xo

    from xinference.core.api import AsyncSupervisorAPI
    from xinference.core.service import SupervisorActor
    from xinference.deploy.worker import start_worker_components

    # Start a supervisor and a worker in the same process for the demo.
    pool = await xo.create_actor_pool(address="0.0.0.0", n_process=0)
    supervisor_address = pool.external_address
    await xo.create_actor(
        SupervisorActor, address=supervisor_address, uid=SupervisorActor.uid()
    )
    await start_worker_components(
        address=supervisor_address, supervisor_address=supervisor_address
    )
    api = AsyncSupervisorAPI(supervisor_address)
    supported_models = ["chatglm2", "chatglm", "vicuna-v1.3", "orca"]
    for model in supported_models:
        await api.launch_model(str(uuid.uuid4()), model)
    gradio_block = GradioApp(supervisor_address, use_launched_model=True).build()
    gradio_block.launch()


if __name__ == "__main__":
    loop = asyncio.get_event_loop()
    task = loop.create_task(launch_xinference())
    try:
        loop.run_until_complete(task)
    except KeyboardInterrupt:
        task.cancel()
        loop.run_until_complete(task)
        # Avoid "exception was never retrieved" warnings.
        task.exception()
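# A minimal usage sketch for pointing this app at an already-running
# xinference cluster instead of spawning one in-process; the address and
# port below are hypothetical examples, not defaults taken from this file:
#
#     app = GradioApp("127.0.0.1:9997", use_launched_model=False)
#     app.build().launch(server_name="0.0.0.0", server_port=7860)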