Spaces:
Running
on
Zero
Running
on
Zero
# A mirror to gradio launch stream | |
# By Xuan Phi Nguyen at DAMO Academy, Alibaba Group | |
""" | |
Load FasterLlama with original VLLM codebase, | |
require changing config names to LlamaForCausalLM | |
tensor_parallel must == 1 | |
""" | |
import os | |
import numpy as np | |
import argparse | |
import torch | |
import gradio as gr | |
from typing import Any, Iterator | |
from typing import Iterator, List, Optional, Tuple | |
import filelock | |
import glob | |
import json | |
from gradio_client.documentation import document, set_documentation_group | |
from typing import List, Optional, Union, Dict, Tuple | |
from tqdm.auto import tqdm | |
from huggingface_hub import snapshot_download | |
# @@ constants ================ | |
def _env_flag(name: str, default: str) -> bool:
    """Read environment variable *name* as an integer and convert it to a bool."""
    return bool(int(os.getenv(name, default)))


# Runtime switches; every value can be overridden through the environment.
DEBUG = _env_flag("DEBUG", "1")
BLOCK_ZH = _env_flag("BLOCK_ZH", "1")
TENSOR_PARALLEL = int(os.getenv("TENSOR_PARALLEL", "1"))
DTYPE = os.getenv("DTYPE", "bfloat16")

# (non-debug only) whether to download HF_MODEL_NAME and save it to MODEL_PATH
DOWNLOAD_SNAPSHOT = _env_flag("DOWNLOAD_SNAPSHOT", "0")
# Hub repository of the uploaded model; it is downloaded into MODEL_PATH
HF_MODEL_NAME = os.getenv("HF_MODEL_NAME", "DAMO-NLP-SG/seal-13b-chat-a")
MODEL_PATH = os.getenv("MODEL_PATH", "./seal-13b-chat-a")

# Gradio / generation defaults
PORT = int(os.getenv("PORT", "7860"))
STREAM_YIELD_MULTIPLE = int(os.getenv("STREAM_YIELD_MULTIPLE", "1"))
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "2048"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.1"))
FREQUENCE_PENALTY = float(os.getenv("FREQUENCE_PENALTY", "0.4"))
""" | |
TODO: | |
need to upload the model as hugginface/models/seal_13b_a | |
# https://huggingface.co/docs/hub/spaces-overview#managing-secrets | |
set | |
MODEL_REPO_ID=hugginface/models/seal_13b_a | |
# if persistent, then export the following | |
HF_HOME=/data/.huggingface | |
TRANSFORMERS_CACHE=/data/.huggingface | |
MODEL_PATH=/data/.huggingface/seal-13b-chat-a | |
HF_MODEL_NAME=DAMO-NLP-SG/seal-13b-chat-a | |
# if not persistent | |
MODEL_PATH=./seal-13b-chat-a | |
HF_MODEL_NAME=DAMO-NLP-SG/seal-13b-chat-a | |
# download will auto detect and get the most updated one | |
if DOWNLOAD_SNAPSHOT: | |
print(f'Download from HF_MODEL_NAME={HF_MODEL_NAME} -> {MODEL_PATH}') | |
snapshot_download(HF_MODEL_NAME, local_dir=MODEL_PATH) | |
elif not DEBUG: | |
assert os.path.exists(MODEL_PATH), f'{MODEL_PATH} not found and no snapshot download' | |
""" | |
# ============================== | |
print(f'DEBUG mode: {DEBUG}')
# bfloat16 needs compute capability >= 8.0; outside debug mode, fall back to
# float16 when the visible GPU reports a lower major version.
if not DEBUG and DTYPE == "bfloat16":
    try:
        _cc_major, _cc_minor = torch.cuda.get_device_capability()
        if _cc_major < 8:
            _gpu_name = torch.cuda.get_device_name()
            _notice = (
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {_gpu_name} GPU has compute capability "
                f"{_cc_major}.{_cc_minor}. --> Move to FLOAT16"
            )
            print(_notice)
            DTYPE = "float16"
    except Exception as e:
        # Best effort: keep the configured DTYPE when the query fails.
        print(f'Unable to obtain compute_capability: {e}')
# @@ constants ================ | |
# vLLM (and the transformers tokenizer classes) are only imported outside
# debug mode, so the debug echo UI can start without vLLM installed.
if not DEBUG:
    # vllm import
    from vllm import LLM, SamplingParams
    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
    from vllm.engine.arg_utils import EngineArgs
    from vllm.engine.llm_engine import LLMEngine
    from vllm.outputs import RequestOutput
    from vllm.sampling_params import SamplingParams
    from vllm.utils import Counter
    from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
                               SequenceGroupMetadata, SequenceOutputs,
                               SequenceStatus)
    # ! reconfigure vllm to faster llama
    # Register the 'FasterLlamaForCausalLM' architecture name as an alias of
    # vLLM's LlamaForCausalLM in its model registry.
    from vllm.model_executor.model_loader import _MODEL_REGISTRY
    from vllm.model_executor.models import LlamaForCausalLM
    _MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM
def _detect_lang(text):
    """Return the language code that `langdetect` reports for *text*.

    Any detection failure is printed and mapped to a fallback: ``"en"`` when
    the error text contains "No features in text." and ``"zh"`` for every
    other error.
    """
    from langdetect import detect as detect_lang
    from langdetect.detector import LangDetectException
    try:
        return detect_lang(text)
    except Exception as e:
        print(f'Error: {e}')
        # "No features in text." is the only failure treated as English.
        return "en" if "No features in text." in str(e) else "zh"
def hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_np_cache: bool = False,
) -> Iterator[Tuple[str, torch.Tensor]]:
    """Yield ``(name, tensor)`` pairs for every weight of a checkpoint.

    Args:
        model_name_or_path: Local checkpoint directory, or a Hub repo id that
            is resolved with ``snapshot_download(..., local_files_only=True)``.
        cache_dir: Passed to ``snapshot_download``; also the directory of the
            file lock (``/tmp`` when None).
        use_np_cache: Convert the ``*.bin`` weights to ``.npy`` files under
            ``<checkpoint>/np`` once and read from that cache afterwards.

    Yields:
        Weight name and tensor. ``*model*.bin`` files take precedence over
        ``*model*.safetensors`` files.

    Raises:
        ValueError: No weight files were found, or a numpy cache has to be
            built but there are no ``.bin`` files to build it from.
    """
    from vllm.model_executor.weight_utils import Disabledtqdm
    # Prepare file lock directory to prevent multiple processes from
    # downloading the same model weights at the same time.
    lock_dir = cache_dir if cache_dir is not None else "/tmp"
    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))

    # Resolve the checkpoint folder (local directory or cached Hub snapshot).
    if os.path.isdir(model_name_or_path):
        hf_folder = model_name_or_path
    else:
        with lock:
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns="*.bin",
                                          cache_dir=cache_dir,
                                          local_files_only=True,
                                          tqdm_class=Disabledtqdm)

    hf_bin_files = [
        x for x in glob.glob(os.path.join(hf_folder, "*model*.bin"))
        if not x.endswith("training_args.bin")
    ]
    # A "*.safetensors" path can never end with "training_args.bin", so no
    # extra filter is needed here.
    hf_safetensors_files = glob.glob(
        os.path.join(hf_folder, "*model*.safetensors"))

    if use_np_cache:
        # Convert the model weights from torch tensors to numpy arrays for
        # faster loading.
        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, "weight_names.json")
        with lock:
            if not os.path.exists(weight_names_file):
                if not hf_bin_files:
                    # Without this guard an empty cache index would be
                    # written and no weight would ever be yielded.
                    raise ValueError(
                        f'use_np_cache requires *model*.bin files, none found in {hf_folder}')
                weight_names = []
                for bin_file in hf_bin_files:
                    # NOTE: torch.load unpickles the file; only use trusted checkpoints.
                    state = torch.load(bin_file, map_location="cpu")
                    for name, param in state.items():
                        param_path = os.path.join(np_folder, name)
                        with open(param_path, "wb") as f:
                            np.save(f, param.cpu().detach().numpy())
                        weight_names.append(name)
                with open(weight_names_file, "w") as f:
                    json.dump(weight_names, f)
        with open(weight_names_file, "r") as f:
            weight_names = json.load(f)
        for name in weight_names:
            param_path = os.path.join(np_folder, name)
            with open(param_path, "rb") as f:
                param = np.load(f)
            yield name, torch.from_numpy(param)
    elif hf_bin_files:
        print(F'Load bin files: {hf_bin_files}')
        for bin_file in hf_bin_files:
            # NOTE: torch.load unpickles the file; only use trusted checkpoints.
            state = torch.load(bin_file, map_location="cpu")
            for name, param in state.items():
                yield name, param
            # Release the shard before loading the next one.
            del state
            torch.cuda.empty_cache()
    elif hf_safetensors_files:
        print(F'Load safetensor files: {hf_safetensors_files}')
        from safetensors.torch import load_file
        for safe_file in hf_safetensors_files:
            state = load_file(safe_file)
            for name, param in state.items():
                yield name, param
            del state
            torch.cuda.empty_cache()
    else:
        raise ValueError('no files available either bin or safe')
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
    """Materialise a lazily sliced safetensors object.

    A ``PySafeSlice`` can be indexed before the data is read, which limits
    how much is loaded into memory, but it lacks tensor operations such as
    ``.view()`` or ``.t()``. Anything that is not already a ``torch.Tensor``
    is therefore sliced in full (``x[:]``); tensors are returned unchanged.
    """
    return x if isinstance(x, torch.Tensor) else x[:]
def load_padded_tensor_parallel_vocab(
    param: torch.Tensor,
    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
    tensor_model_parallel_rank: int,
) -> None:
    """Copy this rank's row shard of *loaded_weight* into *param* in place.

    The shard covers ``param.shape[0]`` rows starting at
    ``rank * param.shape[0]``. Only as many leading rows of *param* as the
    shard actually contains are overwritten; the rest is left untouched.
    """
    rows_per_shard = param.shape[0]
    first_row = tensor_model_parallel_rank * rows_per_shard
    shard = convert_pyslice_to_tensor(
        loaded_weight[first_row:first_row + rows_per_shard])
    param[:shard.shape[0]].copy_(shard)
def llama_load_weights(
    self,
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_np_cache: bool = False,
    load_format: str = "auto",
    revision: Optional[str] = None
):
    """Load checkpoint weights into this model's ``state_dict`` in place.

    Replacement for ``LlamaForCausalLM.load_weights`` that accepts both
    split (``q_proj``/``k_proj``/``v_proj``, ``gate_proj``/``up_proj``) and
    already fused (``qkv_proj``, ``gate_up_proj``) checkpoint tensors, and
    pads the embedding / LM head rows to the size of the model parameter.

    Args:
        model_name_or_path: Forwarded to ``hf_model_weights_iterator``.
        cache_dir: Forwarded to ``hf_model_weights_iterator``.
        use_np_cache: Forwarded to ``hf_model_weights_iterator``.
        load_format: Accepted for signature compatibility; not used here.
        revision: Accepted for signature compatibility; not used here.
    """
    from vllm.model_executor.weight_utils import (
        load_tensor_parallel_weights
    )
    from vllm.model_executor.parallel_utils.parallel_state import (
        get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
    tp_size = get_tensor_model_parallel_world_size()
    tensor_model_parallel_rank = get_tensor_model_parallel_rank()

    q_proj_shard_size = (self.config.hidden_size // tp_size)
    kv_proj_shard_size = (self.config.hidden_size //
                          self.config.num_attention_heads *
                          getattr(self.config, "num_key_value_heads", self.config.num_attention_heads) // tp_size)
    attention_weight_specs = [
        # (weight_name, shard_size, offset)
        ("q_proj", q_proj_shard_size, 0),
        ("k_proj", kv_proj_shard_size, q_proj_shard_size),
        ("v_proj", kv_proj_shard_size,
         q_proj_shard_size + kv_proj_shard_size),
    ]
    state_dict = self.state_dict()
    need_to_load = len(state_dict)
    # Counted in fractions: each third of qkv_proj adds 1/3, each half of
    # gate_up_proj adds 1/2, so a fully loaded fused parameter counts as 1.
    loaded = 0
    iterator = hf_model_weights_iterator(model_name_or_path, cache_dir, use_np_cache)

    for name, loaded_weight in iterator:
        if "rotary_emb.inv_freq" in name:
            continue

        if "embed_tokens" in name or "lm_head" in name:
            param = state_dict[name]
            # Consider padding in the vocab size.
            padded_vocab_size = (param.shape[0] * tp_size)
            num_extra_rows = padded_vocab_size - loaded_weight.size(0)
            load_size = loaded_weight.size()
            # The padding rows are uninitialised memory (torch.empty).
            extra_rows = torch.empty(num_extra_rows,
                                     loaded_weight.shape[1])
            extra_rows = extra_rows.to(loaded_weight)
            loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
            if num_extra_rows > 0:
                print(f'Add empty to {num_extra_rows} extra row for {name}')
            print(f'Load: {name} | {padded_vocab_size=} | {self.config.vocab_size=} | {num_extra_rows=} | {param.size()=} | {loaded_weight.size()=} | {load_size=}')
            # Falls through to the generic tensor-parallel load at the bottom.

        # Split q/k/v tensors are copied into their slice of the fused qkv_proj.
        is_attention_weight = False
        for weight_name, shard_size, offset in attention_weight_specs:
            if weight_name not in name or "qkv_proj" in name:
                continue
            param = state_dict[name.replace(weight_name, "qkv_proj")]
            loaded_weight = loaded_weight[
                shard_size * tensor_model_parallel_rank:shard_size *
                (tensor_model_parallel_rank + 1)]
            param_slice = param.data[offset:offset + shard_size]
            assert param_slice.shape == loaded_weight.shape
            param_slice.copy_(loaded_weight)
            loaded += 1.0 / 3
            is_attention_weight = True
            break
        if is_attention_weight:
            continue

        # ! qkv_proj is sharded differently if concatenated into qkv
        # qkv:          qqqq kkkk vvvv
        # lweight:      qq0qq1 kk0kk1 vv0vv1
        # q_shard_size: hidden_size // tp_size = qq
        # qkv_s0:       qq0_kk0_vv0
        # qkv_s1:       qq1_kk1_vv1
        if "qkv_proj" in name:
            param = state_dict[name]
            qsize = self.config.hidden_size
            kvsize = self.config.hidden_size // self.config.num_attention_heads * getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
            q_offsets = (
                q_proj_shard_size * tensor_model_parallel_rank,
                q_proj_shard_size * (tensor_model_parallel_rank + 1)
            )
            k_offsets = (
                qsize + kv_proj_shard_size * tensor_model_parallel_rank,
                qsize + kv_proj_shard_size * (tensor_model_parallel_rank + 1)
            )
            v_offsets = (
                qsize + kvsize + kv_proj_shard_size * tensor_model_parallel_rank,
                qsize + kvsize + kv_proj_shard_size * (tensor_model_parallel_rank + 1)
            )
            _loaded_weight = torch.cat(
                [
                    loaded_weight[q_offsets[0]:q_offsets[1]],
                    loaded_weight[k_offsets[0]:k_offsets[1]],
                    loaded_weight[v_offsets[0]:v_offsets[1]],
                ], 0
            )
            assert param.shape == _loaded_weight.shape, f'{param.shape=} != {_loaded_weight.shape=}'
            param.data.copy_(_loaded_weight)
            loaded += 1.0
            is_attention_weight = True
        if is_attention_weight:
            continue

        # Split gate/up tensors are copied into their half of gate_up_proj.
        is_gate_up_weight = False
        for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
            if weight_name not in name or "gate_up_proj" in name:
                continue
            param = state_dict[name.replace(weight_name, "gate_up_proj")]
            shard_size = param.shape[0] // 2
            loaded_weight = loaded_weight[
                shard_size * tensor_model_parallel_rank:shard_size *
                (tensor_model_parallel_rank + 1)]
            param_slice = param.data[shard_size * stride_id:shard_size *
                                     (stride_id + 1)]
            assert param_slice.shape == loaded_weight.shape
            param_slice.copy_(loaded_weight)
            loaded += 1.0 / 2
            is_gate_up_weight = True
            break
        if is_gate_up_weight:
            continue

        # Already fused gate_up_proj: pick this rank's gate and up slices.
        if "gate_up_proj" in name:
            param = state_dict[name]
            shard_size = param.shape[0] // 2
            intermediate_size = self.config.intermediate_size
            g_offsets = (
                shard_size * tensor_model_parallel_rank,
                shard_size * (tensor_model_parallel_rank + 1)
            )
            u_offsets = (
                intermediate_size + shard_size * tensor_model_parallel_rank,
                intermediate_size + shard_size * (tensor_model_parallel_rank + 1)
            )
            _loaded_weight = torch.cat(
                [
                    loaded_weight[g_offsets[0]:g_offsets[1]],
                    loaded_weight[u_offsets[0]:u_offsets[1]],
                ], 0
            )
            assert param.shape == _loaded_weight.shape
            param.data.copy_(_loaded_weight)
            loaded += 1.0
            is_gate_up_weight = True
        if is_gate_up_weight:
            continue

        # Everything else goes through vLLM's generic tensor-parallel loader.
        param = state_dict[name]
        load_tensor_parallel_weights(param, loaded_weight, name,
                                     self._column_parallel_weights,
                                     self._row_parallel_weights,
                                     tensor_model_parallel_rank)
        loaded += 1

    # `loaded` is a sum of fractions, so compare with a tolerance. The warning
    # belongs to the mismatch case (the original condition was inverted).
    if np.abs(loaded - need_to_load) >= 0.01:
        print(f'WARNING: only {loaded} params loaded out of {need_to_load}')
    else:
        print(f'Loaded all {loaded} params loaded out of {need_to_load}')


# Reassign LlamaForCausalLM.load_weights with llama_load_weights
if not DEBUG:
    LlamaForCausalLM.load_weights = llama_load_weights
# ! ================================================================== | |
set_documentation_group("component")
# Maps dtype names to the corresponding torch dtypes.
DTYPES = {
    'float16': torch.float16,
    'bfloat16': torch.bfloat16
}
# Module-level singletons, assigned in launch().
llm = None
demo = None
# Chat template markers used by the prompt constructors below.
BOS_TOKEN = '<s>'
EOS_TOKEN = '</s>'
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
# Default system prompt. A trailing backslash joins the next line without a
# newline.
# NOTE(review): "Please ensure\" has no space before the continuation, so the
# resulting text reads "ensurethat" - confirm whether the model expects this
# exact string before changing it.
SYSTEM_PROMPT_1 = """You are a multilingual, helpful, respectful and honest assistant. Your name is SeaL and you are built by DAMO Academy, Alibaba Group. Always answer as helpfully as possible, while being safe. Your \
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
correct. If you don't know the answer to a question, please don't share false information.
As a multilingual assistant, you must respond and follow instructions in the native language of the user by default, unless told otherwise. \
Your response should adapt to the norms and customs of the respective language and culture.
"""
# Only referenced by commented-out logging code in the visible part of the file.
RES_PRINTED = False
def llama_chat_sys_input_seq_constructor(text, sys_prompt=SYSTEM_PROMPT_1, bos_token=BOS_TOKEN, eos_token=EOS_TOKEN):
    """Build a single-turn prompt: ``<bos>[INST] <<SYS>> sys <</SYS>> text [/INST]``.

    ``eos_token`` is accepted but does not appear in the result.
    """
    system_block = f"{B_SYS} {sys_prompt} {E_SYS}"
    return f"{bos_token}{B_INST} {system_block} {text} {E_INST}"
def llama_chat_multiturn_sys_input_seq_constructor(
    message: str,
    history: List[Tuple[str, str]],
    sys_prompt=SYSTEM_PROMPT_1,
    bos_token=BOS_TOKEN,
    eos_token=EOS_TOKEN,
):
    """Build the full multi-turn prompt for *message* given *history*.

    Layout:
    ```
    <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos>
    <bos>[INST] Prompt [/INST] Answer <eos>
    <bos>[INST] Prompt [/INST]
    ```
    The system prompt is attached to the first turn only. History turns
    whose answer is None contribute the prompt part without an answer.
    """
    pieces = []
    for turn_idx, (prompt, res) in enumerate(history):
        if turn_idx == 0:
            pieces.append(f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {prompt} {E_INST}")
        else:
            pieces.append(f"{bos_token}{B_INST} {prompt} {E_INST}")
        if res is not None:
            pieces.append(f" {res} {eos_token} ")
    text = "".join(pieces)

    # No usable history: the incoming message is the first turn and carries
    # the system prompt itself.
    if len(history) == 0 or text.strip() == '':
        return f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {message} {E_INST}"
    return text + f"{bos_token}{B_INST} {message} {E_INST}"
class ChatBot(gr.Chatbot):
    """`gr.Chatbot` variant that renders newlines in text messages as ``<br>``."""

    def _postprocess_chat_messages(
        self, chat_message
    ):
        """Run the parent post-processing, then strip and ``<br>``-join string results."""
        processed = super()._postprocess_chat_messages(chat_message)
        if not isinstance(processed, str):
            # Non-string results are passed through untouched.
            return processed
        return processed.strip().replace("\n", "<br>")
# gr.ChatInterface | |
from gradio.components import Button | |
from gradio.events import Dependency, EventListenerMethod | |
def _setup_stop_events(
    self, event_triggers: list[EventListenerMethod], event_to_cancel: Dependency
) -> None:
    """Wire button state changes around a generation event.

    Accepts a single trigger or a list/tuple of triggers. With a stop button
    and a generator function, the submit/stop buttons swap visibility while
    the event runs and the stop button cancels it. Otherwise, if there is a
    submit button, it is made non-interactive while the event runs.
    """
    if not isinstance(event_triggers, (list, tuple)):
        event_triggers = [event_triggers]

    def _toggle(on_start, on_finish, targets):
        # Register `on_start` on every trigger first, then `on_finish` after
        # the cancellable event.
        for trigger in event_triggers:
            trigger(on_start, None, list(targets), api_name=False, queue=False)
        event_to_cancel.then(
            on_finish, None, list(targets), api_name=False, queue=False
        )

    if self.stop_btn and self.is_generator:
        if self.submit_btn:
            _toggle(
                lambda: (
                    Button.update(visible=False),
                    Button.update(visible=True),
                ),
                lambda: (Button.update(visible=True), Button.update(visible=False)),
                [self.submit_btn, self.stop_btn],
            )
        else:
            _toggle(
                lambda: Button.update(visible=True),
                lambda: Button.update(visible=False),
                [self.stop_btn],
            )
        self.stop_btn.click(
            None,
            None,
            None,
            cancels=event_to_cancel,
            api_name=False,
        )
    elif self.submit_btn:
        _toggle(
            lambda: Button.update(interactive=False),
            lambda: Button.update(interactive=True),
            [self.submit_btn],
        )


# Replace gr.ChatInterface's implementation with the variant above.
gr.ChatInterface._setup_stop_events = _setup_stop_events
def chat_response(message, history, temperature: float, max_tokens: int, system_prompt: str = '') -> str:
    """Generate one non-streaming completion for *message* with the global ``llm``.

    When *system_prompt* is non-blank, the stripped message is wrapped with
    ``llama_chat_sys_input_seq_constructor``; otherwise it is sent as is.
    *history* is not used.
    """
    global llm
    assert llm is not None
    temperature = float(temperature)
    max_tokens = int(max_tokens)

    prompt = message
    if system_prompt.strip() != '':
        # chat version, add system prompt
        prompt = llama_chat_sys_input_seq_constructor(
            message.strip(), sys_prompt=system_prompt)

    params = SamplingParams(temperature=temperature, max_tokens=max_tokens)
    first_request = llm.generate(prompt, params)[0]
    return f'{first_request.outputs[0].text}'
def vllm_abort(self: Any):
    """Abort every request held by the engine's scheduler.

    Empties the scheduler's ``waiting``, ``running`` and ``swapped`` queues
    and frees each unfinished sequence with
    ``SequenceStatus.FINISHED_ABORTED``.

    Args:
        self: Object exposing ``llm_engine.scheduler`` (the vLLM ``LLM``).
    """
    scheduler = self.llm_engine.scheduler
    for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]:
        # Iterate over a snapshot: removing items from the queue while
        # iterating over it directly skips every other sequence group.
        for seq_group in list(state_queue):
            # Remove the sequence group from the state queue.
            state_queue.remove(seq_group)
            for seq in seq_group.seqs:
                if seq.is_finished():
                    continue
                scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
# def _vllm_run_engine(self: LLM, use_tqdm: bool = False) -> Dict[str, RequestOutput]: | |
def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]: | |
# Initialize tqdm. | |
if use_tqdm: | |
num_requests = self.llm_engine.get_num_unfinished_requests() | |
pbar = tqdm(total=num_requests, desc="Processed prompts") | |
# Run the engine. | |
outputs: Dict[str, RequestOutput] = {} | |
while self.llm_engine.has_unfinished_requests(): | |
step_outputs = self.llm_engine.step() | |
for output in step_outputs: | |
outputs[output.request_id] = output | |
# outputs = sorted(outputs, key=lambda x: int(x.request_id)) | |
if len(outputs) > 0: | |
yield outputs | |
# if use_tqdm: | |
# pbar.close() | |
# Sort the outputs by request ID. | |
# This is necessary because some requests may be finished earlier than | |
# its previous requests. | |
# outputs = sorted(outputs, key=lambda x: int(x.request_id)) | |
# return outputs | |
def vllm_generate_stream(
    self: Any,
    prompts: Optional[Union[str, List[str]]] = None,
    sampling_params: Optional[Any] = None,
    prompt_token_ids: Optional[List[List[int]]] = None,
    use_tqdm: bool = False,
) -> Iterator[Dict[str, Any]]:
    """Queue the prompts on the engine and stream the outputs.

    Streaming counterpart of ``LLM.generate``: this is a generator, so
    nothing (including argument validation) runs until it is iterated.

    Args:
        self: The vLLM ``LLM`` instance (needs ``_add_request`` and
            ``llm_engine``).
        prompts: A prompt or a list of prompts to generate completions for.
        sampling_params: The sampling parameters for text generation. If
            None, the default ``SamplingParams()`` is used.
        prompt_token_ids: Token IDs for the prompts. Entries are forwarded
            to ``_add_request``; None is forwarded when not given.
        use_tqdm: Whether to display a progress bar.

    Yields:
        The mapping ``request_id -> latest output`` produced by
        ``_vllm_run_engine`` after each engine step.

    Raises:
        ValueError: Neither prompts nor prompt_token_ids is given, or their
            lengths differ.
    """
    if prompts is None and prompt_token_ids is None:
        raise ValueError("Either prompts or prompt_token_ids must be "
                         "provided.")
    if isinstance(prompts, str):
        # Convert a single prompt to a list.
        prompts = [prompts]
    if prompts is not None and prompt_token_ids is not None:
        if len(prompts) != len(prompt_token_ids):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")
    if sampling_params is None:
        # Use default sampling params.
        sampling_params = SamplingParams()

    # Add requests to the engine.
    if prompts is not None:
        num_requests = len(prompts)
    else:
        num_requests = len(prompt_token_ids)
    for i in range(num_requests):
        prompt = prompts[i] if prompts is not None else None
        token_ids = None if prompt_token_ids is None else prompt_token_ids[i]
        self._add_request(prompt, sampling_params, token_ids)

    yield from _vllm_run_engine(self, use_tqdm)
# def chat_response_stream( | |
# message: str, | |
# history: List[Tuple[str, str]], | |
# temperature: float, | |
# max_tokens: int, | |
# frequency_penalty: float, | |
# system_prompt: str | |
# ) -> str: | |
# global llm, RES_PRINTED | |
# assert llm is not None | |
# # force removing all | |
# vllm_abort(llm) | |
# temperature = float(temperature) | |
# frequency_penalty = float(frequency_penalty) | |
# max_tokens = int(max_tokens) | |
# if system_prompt.strip() != '': | |
# # chat version, add system prompt | |
# message = llama_chat_sys_input_seq_constructor( | |
# message.strip(), | |
# sys_prompt=system_prompt | |
# ) | |
# sampling_params = SamplingParams( | |
# temperature=temperature, max_tokens=max_tokens, | |
# frequency_penalty=frequency_penalty, | |
# ) | |
# cur_out = None | |
# for j, gen in enumerate(vllm_generate_stream(llm, message, sampling_params)): | |
# if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0: | |
# yield cur_out | |
# assert len(gen) == 1, f'{gen}' | |
# item = next(iter(gen.values())) | |
# cur_out = item.outputs[0].text | |
# if not RES_PRINTED: | |
# print(f'{message}<<<{cur_out}>>>') | |
# RES_PRINTED = True | |
# if cur_out is not None: | |
# yield cur_out | |
# Reply shown instead of a model answer when Chinese is detected.
BLOCK_MESSAGE = """Sorry, Chinese is not currently supported. Please clear the chat box for a new conversation.
抱歉,目前不支持中文。 请清除聊天框以进行新对话。"""


def block_zh(
    message: str,
    history: List[Tuple[str, str]]
) -> bool:
    """Decide whether the incoming message must be answered with BLOCK_MESSAGE.

    Returns True when an earlier reply in *history* already contains
    BLOCK_MESSAGE (the conversation stays blocked until cleared), or when
    ``_detect_lang`` reports a language code containing ``'zh'`` for
    *message*. Turns whose reply is None are ignored.
    """
    already_blocked = any(
        turn[1] is not None and BLOCK_MESSAGE in turn[1].strip()
        for turn in history
    )
    if already_blocked:
        return True
    if 'zh' in _detect_lang(message):
        print(f'Detect zh: {message}')
        return True
    # ! optionally detect every responses message
    return False
# 抱歉,目前不支持中文。 | |
def chat_response_stream_multiturn(
    message: str,
    history: List[Tuple[str, str]],
    temperature: float,
    max_tokens: int,
    frequency_penalty: float,
    system_prompt: Optional[str] = SYSTEM_PROMPT_1
) -> Iterator[str]:
    """Stream the model's reply to *message* within a multi-turn chat.

    Prompt layout:
        <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos>
        <bos>[INST] Prompt [/INST] Answer <eos>
        <bos>[INST] Prompt [/INST]

    Args:
        message: The incoming prompt.
        history: Earlier ``(prompt, answer)`` turns; it does not contain the
            current message.
        temperature: Sampling temperature (converted with ``float``).
        max_tokens: Maximum number of generated tokens (converted with ``int``).
        frequency_penalty: Frequency penalty (converted with ``float``).
        system_prompt: System prompt; must not be None or blank.

    Yields:
        The text generated so far (every ``STREAM_YIELD_MULTIPLE`` steps and
        once at the end). BLOCK_MESSAGE is yielded instead of, or after, the
        reply when ``BLOCK_ZH`` is set and Chinese is detected.
    """
    assert llm is not None
    assert system_prompt is not None and system_prompt.strip() != '', 'system prompt is empty'
    # force removing all
    vllm_abort(llm)

    temperature = float(temperature)
    frequency_penalty = float(frequency_penalty)
    max_tokens = int(max_tokens)
    message = message.strip()

    # ! lang detect
    if BLOCK_ZH and block_zh(message, history):
        yield BLOCK_MESSAGE
        return

    # history does not contain the current message; the constructor appends it
    full_prompt = llama_chat_multiturn_sys_input_seq_constructor(
        message, history, sys_prompt=system_prompt
    )
    sampling_params = SamplingParams(
        temperature=temperature, max_tokens=max_tokens,
        frequency_penalty=frequency_penalty,
    )

    cur_out = None
    for j, gen in enumerate(vllm_generate_stream(llm, full_prompt, sampling_params)):
        # Emit the previous partial text on every STREAM_YIELD_MULTIPLE-th step
        # (every step when the multiple is < 1).
        if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
            yield cur_out
        assert len(gen) == 1, f'{gen}'
        item = next(iter(gen.values()))
        cur_out = item.outputs[0].text

    print(f'{full_prompt}<<<{cur_out}>>>\n')

    if cur_out is None:
        # Nothing was generated: there is no reply to emit or to inspect.
        # Running language detection on None would hit _detect_lang's error
        # fallback and be misreported as Chinese.
        return
    yield cur_out
    if BLOCK_ZH and "zh" in _detect_lang(cur_out):
        yield BLOCK_MESSAGE
def debug_chat_response_echo(
    message: str,
    history: List[Tuple[str, str]],
    temperature: float = 0.0,
    max_tokens: int = 4096,
    frequency_penalty: float = 0.4,
    system_prompt: str = SYSTEM_PROMPT_1,
) -> str:
    """Debug stand-in for the model: wait half a second, then echo *message*.

    Generator with the same parameters as the real response function; all
    arguments other than *message* are ignored.
    """
    from time import sleep
    sleep(0.5)
    yield "repeat: {}".format(message)
# ============ CONSTANT ============ | |
# https://github.com/gradio-app/gradio/issues/884 | |
# Label of the chatbot component.
MODEL_NAME = "SeaL-13B"
# Page title of the chat interface.
MODEL_TITLE = "SeaL-13B - An Assistant for South East Asian Languages"
# ! add icon: "<img src='file/lion.jpg' alt='image One'>"
# HTML description shown under the title (surrounding whitespace stripped).
MODEL_DESC = """
<span style="font-size: larger">
This is a DAMO SeaL-13B chatbot assistant built by DAMO Academy, Alibaba Group. It can produce helpful responses in English 🇬🇧, Vietnamese 🇻🇳, Indonesian 🇮🇩 and Thai 🇹🇭.
</span>
""".strip()
# <br>
# Markdown blocks rendered below the chat interface by launch().
cite_markdown = """
### Citation
If you find our project useful, hope you can star our repo and cite our paper as follows:
```
@article{damonlpsg2023seallm,
author = {???},
title = {SeaL: A language model for South East Asian Languages},
year = 2023,
}
```
"""
warning_markdown = """
### Warning:
<span style="color: red">The chatbot may produce inaccurate and harmful information about people, places, or facts.</span>
<span style="color: red">We strongly advise against misuse of the chatbot to knowingly generate harmful or unethical content, \
or content that violates locally applicable and international laws or regulations, including hate speech, violence, pornography, deception, etc!</span>
"""
# Template; launch() fills {model_path} via str.format.
path_markdown = """
#### Model path:
{model_path}
"""
def check_model_path(model_path) -> str:
    """Return the contents of ``<model_path>/info.txt``, or ``"None"``.

    Asserts that *model_path* exists. For a directory, the checkpoint info
    (or a note that ``info.txt`` is missing) and the directory listing are
    printed. For a plain file nothing is printed and ``"None"`` is returned.
    """
    assert os.path.exists(model_path), f'{model_path} not found'
    if not os.path.isdir(model_path):
        return "None"

    ckpt_info = "None"
    info_file = os.path.join(model_path, 'info.txt')
    if not os.path.exists(info_file):
        print(f'info.txt not found in {model_path}')
    else:
        with open(info_file, 'r') as f:
            ckpt_info = f.read()
        print(f'Checkpoint info:\n{ckpt_info}\n-----')
    print(f'model path dir: {list(os.listdir(model_path))}')
    return ckpt_info
def launch():
    """Build the Gradio chat interface and start serving it on PORT.

    In DEBUG mode the UI is backed by ``debug_chat_response_echo``.
    Otherwise the model at MODEL_PATH is (optionally downloaded and then)
    loaded into the global ``llm`` with vLLM, and the UI is backed by
    ``chat_response_stream_multiturn``. The interface is stored in the
    global ``demo``.
    """
    global demo, llm, DEBUG
    # Snapshot the module-level configuration into locals.
    model_desc = MODEL_DESC
    model_path = MODEL_PATH
    model_title = MODEL_TITLE
    hf_model_name = HF_MODEL_NAME
    tensor_parallel = TENSOR_PARALLEL
    assert tensor_parallel > 0 , f'{tensor_parallel} invalid'
    dtype = DTYPE
    sys_prompt = SYSTEM_PROMPT_1
    max_tokens = MAX_TOKENS
    temperature = TEMPERATURE
    frequence_penalty = FREQUENCE_PENALTY
    ckpt_info = "None"
    print(
        f'Launch config: {model_path=} / {model_title=} / {tensor_parallel=} / {dtype=} / {max_tokens} | {BLOCK_ZH=} '
        f'\n| STREAM_YIELD_MULTIPLE={STREAM_YIELD_MULTIPLE} '
        f'\n| frequence_penalty={frequence_penalty} '
        f'\n| temperature={temperature} '
        f'\n| hf_model_name={hf_model_name} '
        f'\n| DOWNLOAD_SNAPSHOT={DOWNLOAD_SNAPSHOT} '
        f'\nsys={SYSTEM_PROMPT_1}'
        f'\ndesc={model_desc}'
    )
    if DEBUG:
        # No model is loaded; replies just echo the message.
        model_desc += "\n<br>!!!!! This is in debug mode, responses will copy original"
        response_fn = debug_chat_response_echo
        print(f'Creating in DEBUG MODE')
    else:
        # ! load the model
        import vllm
        print(F'VLLM: {vllm.__version__}')
        if DOWNLOAD_SNAPSHOT:
            print(f'Downloading from HF_MODEL_NAME={hf_model_name} -> {model_path}')
            snapshot_download(hf_model_name, local_dir=model_path)
        # The path must exist by now, downloaded or provided.
        assert os.path.exists(model_path), f'{model_path} not found and no snapshot download'
        ckpt_info = check_model_path(model_path)
        print(f'Load path: {model_path} | {ckpt_info}')
        llm = LLM(model=model_path, dtype=dtype, tensor_parallel_size=tensor_parallel)
        print(f'Use system prompt:\n{sys_prompt}')
        response_fn = chat_response_stream_multiturn
        print(F'respond: {response_fn}')
    demo = gr.ChatInterface(
        response_fn,
        chatbot=ChatBot(
            label=MODEL_NAME,
            bubble_full_width=False,
            latex_delimiters=[
                { "left": "$", "right": "$", "display": False},
                { "left": "$$", "right": "$$", "display": True},
            ]
        ),
        textbox=gr.Textbox(placeholder='Type message', lines=8, max_lines=128, min_width=200),
        submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
        # ! consider preventing the stop button
        stop_btn=None,
        title=f"{model_title}",
        description=f"{model_desc}",
        # ! decide if can change the system prompt.
        # Passed to response_fn after (message, history), in this order:
        # temperature, max_tokens, frequency_penalty.
        additional_inputs=[
            gr.Number(value=temperature, label='Temperature (higher -> more random)'),
            gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
            gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens)'),
            # gr.Textbox(value=sys_prompt, label='System prompt', lines=8)
        ],
    )
    # Append the warning, citation and model-path notes below the chat UI.
    with demo:
        gr.Markdown(warning_markdown)
        gr.Markdown(cite_markdown)
        gr.Markdown(path_markdown.format(model_path=model_path))
    demo.queue()
    demo.launch(server_port=PORT)
def main():
    """Script entry point: start the demo with the environment-derived config."""
    # launch(parser.parse_args())
    launch()


# Only start the app when executed as a script, not when imported.
if __name__ == "__main__":
    main()
""" | |
export CUDA_VISIBLE_DEVICES=0 | |
export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/merlion13s108Hi8kPretFlCW8k.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.Vi.FSePlCq13M.FSePlCq13M.m4k.b8.lr1e5.linear.wa0k.ms858k.grac1.se1.8g.v4c.zfsdp/step_4000 | |
export MODEL_PATH=${dataroot}/llama-2-7b-lxxp-faster | |
export MODEL_PATH=${dataroot}/llama-2-7b-chat-xp | |
export DEBUG=0 | |
export CUDA_VISIBLE_DEVICES=0 | |
export MODEL_PATH=seal_13b_a | |
export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/merlion13s108Hi8kPretFlCW12k.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.Vi.SeaV2Cq13M.SeaV2Cq13M.m4k.b8.lr1e5.linear.wa0k.ms858k.grac1.se1.8g.v4c.zfsdp/step_6000 | |
export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/mer13s108Hi16kPretFlCWNLP12k_SFT2.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.Vi.Sft2Censor.Sft2Censor.m4k.b8.lr1e5.linear.wa0k.ms1144k.grac1.se1.6g.v4c.zfsdp/step_2000 | |
export PORT=8799 | |
export BLOCK_ZH=1 | |
python app.py | |
DEBUG=1 python app.py | |
""" |