"""SeaLMM v0 multimodal engine plus keyword / language safety filters.

Safety filtering is configured through environment variables read at import:

* ``BLOCK_LANGS`` -- ``;``-separated language codes; text whose language
  ``langdetect`` detects as one of them is rejected.
* ``LANG_BLOCK_HISTORY`` -- ``"1"`` to keep rejecting a conversation once a
  previous reply in its history contains ``LANG_BLOCK_MESSAGE``.
* ``KEYWORDS`` -- ``;``-separated keywords; any case-insensitive substring
  match rejects the text.
"""
import argparse
import contextlib
import glob
import json
import os
import re
import sys
import time
import types
from typing import (
    Any,
    AsyncGenerator,
    Callable,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
    cast,
)

import anyio
import filelock
import gradio as gr
import numpy as np
import torch
from gradio.components import Button
from gradio.events import Dependency, EventListenerMethod
from gradio.helpers import special_args
from gradio.routes import Request
from gradio.utils import SyncToAsyncIterator, async_iteration
from gradio_client.documentation import document, set_documentation_group
from huggingface_hub import snapshot_download
from tqdm.auto import tqdm
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

from .base_engine import BaseEngine
from .transformers_engine import TransformersEngine, NewGenerationMixin
from ..configs import (
    STREAM_CHECK_MULTIPLE,
    STREAM_YIELD_MULTIPLE,
)

CODE_PATH = os.environ.get("CODE_PATH", "")
MODEL_PATH = os.environ.get("MODEL_PATH", "")

IMAGE_TOKEN = "[IMAGE]<|image|>[/IMAGE]"
# Number of tokens each image patch is counted as in `get_multimodal_tokens`.
IMAGE_LENGTH = 576
# Number of patches counted per image (name kept as-is for compatibility).
MAX_PACHES = 1

# Blank entries are dropped (e.g. from a trailing ";").
BLOCK_LANGS = [
    lang.strip()
    for lang in str(os.environ.get("BLOCK_LANGS", "")).split(";")
    if lang.strip()
]
LANG_BLOCK_HISTORY = bool(int(os.environ.get("LANG_BLOCK_HISTORY", "0")))

# Blank / whitespace-only entries are dropped: an empty keyword is a substring
# of every text and would otherwise block all requests. Non-blank entries keep
# their exact content (including any deliberate surrounding spaces).
KEYWORDS = [
    keyword.lower()
    for keyword in os.environ.get("KEYWORDS", "").strip().split(";")
    if keyword.strip()
]

LANG_BLOCK_MESSAGE = """Unsupported language."""
KEYWORD_BLOCK_MESSAGE = "Invalid request."

# Default chat-template markers used to split a rendered conversation into turns.
_DEFAULT_TURN_DELIMITERS = (
    r"<\|im_start\|>user\n",
    r"<\|im_start\|>assistant\n",
    r"<\|im_start\|>system\n",
)


def _detect_lang(text):
    """Return the language code that ``langdetect`` assigns to *text*.

    Returns ``"en"`` when langdetect reports "No features in text." and
    ``"zh"`` for any other detection error.
    """
    # Disable language that may have safety risk
    from langdetect import detect as detect_lang

    try:
        return detect_lang(text)
    except Exception as e:
        if "No features in text." in str(e):
            return "en"
        return "zh"


def _contains_blocked_keyword(text: str) -> bool:
    """Return True if any configured ``KEYWORDS`` entry occurs in lower-cased *text*."""
    lowered = text.lower()
    return any(keyword in lowered for keyword in KEYWORDS)


def block_lang(
    message: str,
    history: Optional[List[Tuple[str, str]]] = None,
) -> bool:
    """Return True if *message* should be rejected by the language filter.

    Args:
        message: the text to check.
        history: optional ``(user, reply)`` pairs. When ``LANG_BLOCK_HISTORY``
            is enabled and any reply already contains ``LANG_BLOCK_MESSAGE``,
            the conversation stays blocked regardless of *message*.

    Returns:
        False when no languages are blocked; otherwise whether the history
        rule applies or the detected language of *message* is in ``BLOCK_LANGS``.
    """
    # relieve history base block
    if len(BLOCK_LANGS) == 0:
        return False
    if LANG_BLOCK_HISTORY and history is not None and any(
        # Replies may be None / non-string (e.g. pending or media turns).
        isinstance(turn[1], str) and LANG_BLOCK_MESSAGE in turn[1].strip()
        for turn in history
    ):
        return True
    return _detect_lang(message) in BLOCK_LANGS


def safety_check(text, history=None, ) -> Optional[str]:
    """
    Despite our effort in safety tuning and red teaming, our models may still generate harmful or illegal content.
    This provides an additional security measure to enhance safety and compliance with local regulations.

    Returns the block message to show the user, or None if *text* passes.
    """
    if _contains_blocked_keyword(text):
        return KEYWORD_BLOCK_MESSAGE

    if len(BLOCK_LANGS) > 0 and block_lang(text, history):
        return LANG_BLOCK_MESSAGE

    return None


def safety_check_conversation_string(text, delimiter=None) -> Optional[str]:
    """Check a rendered conversation string against the safety filters.

    Keywords are matched against the whole string. For the language filter,
    the string is split into turns on the regex alternatives in *delimiter*
    (default: the ``<|im_start|>`` user/assistant/system markers), and each
    non-blank turn is checked separately.

    Returns the block message, or None if the text passes.
    """
    if _contains_blocked_keyword(text):
        return KEYWORD_BLOCK_MESSAGE

    if len(BLOCK_LANGS) > 0:
        delimiter = delimiter or _DEFAULT_TURN_DELIMITERS
        turns = [t for t in re.split(r"|".join(delimiter), text) if t.strip() != '']
        if any(block_lang(t) for t in turns):
            return LANG_BLOCK_MESSAGE

    return None


def is_check_safety():
    """Return True if any keyword or language filter is configured."""
    return len(KEYWORDS) > 0 or len(BLOCK_LANGS) > 0


def safety_check_conversation(conversation) -> Optional[str]:
    """
    Despite our effort in safety tuning and red teaming, our models may still generate harmful or illegal content.
    This provides an additional security measure to enhance safety and compliance with local regulations.

    *conversation* is an iterable of dicts with a ``'content'`` key. Returns
    the first block message triggered by any message, or None.
    """
    for text in (c['content'] for c in conversation):
        if _contains_blocked_keyword(text):
            return KEYWORD_BLOCK_MESSAGE

        if len(BLOCK_LANGS) > 0 and block_lang(text):
            return LANG_BLOCK_MESSAGE

    return None


class SeaLMMMv0Engine(TransformersEngine):
    """Streaming engine for the SeaLMM multimodal (text + image) model."""

    @property
    def image_token(self):
        return IMAGE_TOKEN

    @property
    def max_position_embeddings(self) -> int:
        return self._model.config.max_position_embeddings

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def processor(self):
        return self._processor

    def load_model(self):
        """Load the processor and bfloat16 SeaLMM model from ``MODEL_PATH`` onto CUDA.

        If only ``pytorch_model_fsdp.bin`` exists, ``pytorch_model.bin`` is
        created as a symlink to it. The model's ``sample`` method is replaced
        with ``NewGenerationMixin.sample_stream`` (the original is kept as
        ``sample_old``), so ``generate`` yields tokens one at a time.
        """
        from transformers import AutoProcessor

        # NOTE: `modeling_sealmm` is imported as a top-level module, so it must
        # already be on sys.path (presumably via CODE_PATH; confirm deployment).
        from modeling_sealmm import (SeaLMMForCausalLM, )

        model_path = MODEL_PATH
        print(f'Loading model from {model_path}')
        print(f'model_path={model_path}')

        fsdp_weights = f"{model_path}/pytorch_model_fsdp.bin"
        std_weights = f"{model_path}/pytorch_model.bin"
        if os.path.exists(fsdp_weights) and not os.path.exists(std_weights):
            os.symlink("pytorch_model_fsdp.bin", std_weights)

        self._processor = AutoProcessor.from_pretrained(model_path)
        self._model = SeaLMMForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda"
        ).eval()

        self._model.sample_old = self._model.sample
        self._model.sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)

        self._tokenizer = self._processor.tokenizer
        print(self._model)
        print(f"{self.max_position_embeddings=}")

    def get_multimodal_tokens(self, full_prompt, image_paths=None):
        """Estimate the token count of *full_prompt* plus its images.

        Each image counts as ``IMAGE_LENGTH * MAX_PACHES`` tokens. *image_paths*
        may be None (no images).
        """
        image_paths = image_paths or []
        num_tokens = len(self.tokenizer.encode(full_prompt))
        return num_tokens + IMAGE_LENGTH * MAX_PACHES * len(image_paths)

    def maybe_raise_safety(self, message, gen_index=-1):
        """Raise ``gr.Error`` if *message* fails the safety filters.

        With ``gen_index < 0`` the check always runs. While streaming
        (``gen_index >= 0``) it runs only when ``gen_index`` is a multiple of
        ``STREAM_CHECK_MULTIPLE``, and never if that setting is not positive.
        """
        if not is_check_safety():
            return
        if gen_index >= 0 and not (STREAM_CHECK_MULTIPLE > 0 and gen_index % STREAM_CHECK_MULTIPLE == 0):
            return
        message_safety = safety_check_conversation_string(message)
        if message_safety is not None:
            raise gr.Error(message_safety)

    def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
        """Stream a sampled response to *prompt*, optionally conditioned on images.

        Yields ``(response_so_far, num_tokens)`` after each generated token, where
        ``num_tokens`` is the prompt-plus-images estimate. If anything was
        generated, yields a final ``(response, num_tokens)`` with ``num_tokens``
        recomputed for prompt + response. ``stop_strings`` is accepted for
        interface compatibility but unused here.

        Keyword Args:
            image_paths: optional list of image file paths fed to the processor.

        Raises:
            gr.Error: if the prompt, a periodic streaming check, or the final
                response fails the safety filters.
        """
        from PIL import Image

        image_paths = kwargs.get("image_paths", None) or []

        with torch.no_grad():
            # Images are only needed until the processor has produced tensors;
            # the ExitStack closes every opened file handle afterwards.
            with contextlib.ExitStack() as stack:
                images = [stack.enter_context(Image.open(path)) for path in image_paths] or None
                inputs = self.processor(prompt, images, return_tensors='pt')
            inputs = {k: v.to("cuda") for k, v in inputs.items() if v is not None}
            num_tokens = self.get_multimodal_tokens(prompt, image_paths)

            self.maybe_raise_safety(prompt)

            # `generate` yields one token at a time because load_model patched
            # `sample` with `NewGenerationMixin.sample_stream`.
            generator = self._model.generate(
                **inputs,
                do_sample=True,
                temperature=temperature,
                max_new_tokens=max_tokens,
                pad_token_id=self.processor.tokenizer.pad_token_id,
            )

            out_tokens = []
            response = None
            for index, token in enumerate(generator):
                out_tokens.append(token.item())
                response = self.processor.tokenizer.decode(out_tokens)
                self.maybe_raise_safety(response, gen_index=index)
                yield response, num_tokens

            del generator

            if response is not None:
                # Streaming checks only run every STREAM_CHECK_MULTIPLE tokens,
                # so screen the complete response once before the final yield.
                self.maybe_raise_safety(response)
                num_tokens = self.get_multimodal_tokens(prompt + response, image_paths)
                yield response, num_tokens