File size: 5,785 Bytes
f04732f
a3db70a
c36d5bb
58c422e
a3db70a
 
 
 
 
f04732f
08bc998
a3db70a
 
88b9346
a3db70a
 
 
 
 
9304c73
a3db70a
 
 
 
c36d5bb
 
 
 
 
 
 
 
 
 
 
 
 
 
88b9346
 
c36d5bb
 
 
 
 
 
 
 
 
 
 
 
88b9346
 
c36d5bb
 
 
 
 
 
a3db70a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b699eb0
c36d5bb
 
 
 
a3db70a
b699eb0
53eb154
a3db70a
 
 
 
4be8019
a3db70a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteria
from modeling_llava_qwen2 import LlavaQwen2ForCausalLM
from threading import Thread
import re
import time 
from PIL import Image
import torch
import spaces
import subprocess
# Install flash-attn at startup, before the model is loaded.
# Function-scope style imports are kept next to their only use.
import os
import sys

subprocess.run(
    # Run pip through the current interpreter so the package lands in the
    # environment this script is actually running in; an argument list avoids
    # going through a shell.
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    # Extend the inherited environment instead of replacing it: a bare dict
    # passed as `env` would drop PATH, HOME, etc. for the child process.
    # NOTE(review): FLASH_ATTENTION_SKIP_CUDA_BUILD presumably makes the
    # flash-attn build skip compiling CUDA kernels -- confirm against its docs.
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    # Best effort, as before: a failed install does not abort startup.
    check=False,
)

# Make CUDA the default device for tensors created without an explicit device.
torch.set_default_device('cuda')

# Hub repository that both the tokenizer and the model weights are loaded from.
_MODEL_REPO = 'qnguyen3/nanoLLaVA'

tokenizer = AutoTokenizer.from_pretrained(_MODEL_REPO, trust_remote_code=True)

# Weights are loaded in half precision.
model = LlavaQwen2ForCausalLM.from_pretrained(
    _MODEL_REPO,
    trust_remote_code=True,
    torch_dtype=torch.float16,
)

class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any of the given keywords has been produced.

    A keyword is detected in two ways: its token ids match the tail of the
    output ids, or its string occurs in the decoded tail of the output.

    Args:
        keywords: Strings that should end generation when they appear.
        tokenizer: Tokenizer used to encode the keywords and decode outputs.
        input_ids: Prompt ids; only ``input_ids.shape[1]`` is read and stored
            as the prompt length.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Strip a leading BOS token so only the keyword's own ids are compared.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.max_keyword_len = max(self.max_keyword_len, len(cur_keyword_ids))
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    @spaces.GPU
    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if a single sequence ends with / contains a keyword.

        Args:
            output_ids: Ids of one sequence with a leading batch dimension of 1
                (``__call__`` passes ``output_ids[i].unsqueeze(0)``).
            scores: Unused; accepted for interface compatibility.

        Returns:
            True when a keyword was found, False otherwise.
        """
        seq_len = output_ids.shape[1]
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]

        # 1) Exact token-id match against the tail of the sequence.
        for keyword_id in self.keyword_ids:
            keyword_len = keyword_id.shape[0]
            if keyword_len == 0:
                # `-0:` would select the whole sequence rather than an empty tail.
                continue
            if torch.equal(output_ids[0, -keyword_len:], keyword_id):
                return True

        # 2) String match on the decoded tail of the *generated* tokens.
        generated = seq_len - self.start_len
        if generated < 0:
            # The sequence is shorter than the prompt, so it cannot contain the
            # prompt; treat every token in it as generated.
            # NOTE(review): this happens if generate() is driven by embeddings
            # and only returns new tokens -- confirm for this model.
            generated = seq_len
        window = min(generated, self.max_keyword_len)
        if window <= 0:
            # Nothing generated yet. Without this guard the slice `[:, -0:]`
            # would decode the entire sequence, prompt included.
            return False
        outputs = self.tokenizer.batch_decode(output_ids[:, -window:], skip_special_tokens=True)[0]
        return any(keyword in outputs for keyword in self.keywords)

    @spaces.GPU
    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True only when every sequence in the batch has hit a keyword."""
        return all(
            self.call_for_batch(row.unsqueeze(0), scores)
            for row in output_ids
        )


@spaces.GPU
def bot_streaming(message, history):
    """Stream the model's answer to a multimodal chat message.

    Args:
        message: Mapping with a ``"text"`` entry and a ``"files"`` list; the
            last file is used and is read as a mapping with a ``"path"`` key.
        history: Sequence of ``(user, assistant)`` pairs. Entries whose user
            part is a tuple are treated as image uploads, with the image path
            as the tuple's first element.

    Yields:
        The text generated so far, growing as new chunks arrive.

    Raises:
        gr.Error: If neither the message nor the history provides an image.
    """
    # Pick the image: the newest upload in this message, otherwise the most
    # recent image found in the history.
    image = None
    if message["files"]:
        image = message["files"][-1]["path"]
    else:
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]

    if image is None:
        # Every prompt built below needs an image; fail with a readable
        # message instead of crashing inside Image.open(None).
        raise gr.Error("You need to upload an image for LLaVA to work.")

    messages = []
    if history:
        # NOTE(review): assumes history[0] is the image-upload entry and
        # history[1] the first text turn -- confirm against the chat UI.
        messages.append({"role": "user", "content": f'<image>\n{history[1][0]}'})
        messages.append({"role": "assistant", "content": history[1][1]})
        for human, assistant in history[2:]:
            messages.append({"role": "user", "content": human})
            messages.append({"role": "assistant", "content": assistant})
        messages.append({"role": "user", "content": message['text']})
    else:
        messages.append({"role": "user", "content": f"<image>\n{message['text']}"})

    image = Image.open(image).convert("RGB")
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True)
    # Tokenize the text around the '<image>' marker separately and splice the
    # id -200 in its place.
    # NOTE(review): -200 is presumably the image placeholder index the model
    # expects -- confirm against the model code.
    text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
    input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)

    stop_str = '<|im_end|>'
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)
    generation_kwargs = dict(
        input_ids=input_ids,
        images=image_tensor,
        streamer=streamer,
        max_new_tokens=100,
        stopping_criteria=[stopping_criteria],
    )
    # Generate in a background thread so this generator can consume the
    # streamer while tokens are still being produced.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.04)
        # The streamer is created with skip_prompt=True, so the buffer holds
        # only generated text; there is no prompt prefix to strip from it.
        yield buffer


# Chat UI wired to bot_streaming. The title and description name the model
# this script actually loads ('qnguyen3/nanoLLaVA').
demo = gr.ChatInterface(
    fn=bot_streaming,
    title="nanoLLaVA",
    examples=[
        {"text": "What is on the flower?", "files": ["./bee.jpg"]},
        {"text": "How to make this pastry?", "files": ["./baklava.png"]},
    ],
    description="Try [nanoLLaVA](https://huggingface.co/qnguyen3/nanoLLaVA) in this demo. Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
    stop_btn="Stop Generation",
    multimodal=True,
)
demo.launch(debug=True)